TradingAgents/docs/guides/adding-llm-provider.md

Guide: Adding a New LLM Provider

This guide shows you how to add support for a new LLM provider to TradingAgents.

Overview

Adding a new LLM provider involves:

  1. Installing the provider's LangChain integration
  2. Adding initialization logic
  3. Configuring API keys
  4. Testing the integration
  5. Updating documentation

Step 1: Install LangChain Integration

Most providers have official LangChain integrations:

# Example: Adding Cohere
pip install langchain-cohere

# Example: Adding Mistral
pip install langchain-mistralai

# Example: Adding HuggingFace
pip install langchain-huggingface

Add the dependency to requirements.txt:

langchain-cohere>=0.1.0
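Before wiring the provider into the graph, it can help to confirm that the package is importable in the active environment. The following is a minimal sketch using only the standard library; the helper name `is_provider_installed` is hypothetical and not part of TradingAgents.

```python
import importlib.util


def is_provider_installed(module_name: str) -> bool:
    """Return True if the top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # The pip package "langchain-cohere" is imported as "langchain_cohere".
    for name in ("langchain_cohere", "langchain_mistralai", "langchain_huggingface"):
        status = "installed" if is_provider_installed(name) else "missing"
        print(f"{name}: {status}")
```

Note that pip package names use hyphens while the importable module names use underscores.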

Step 2: Add Initialization Logic

Modify tradingagents/graph/trading_graph.py:

# Add imports at top of file
import os  # needed for os.getenv below, if not already imported

from langchain_cohere import ChatCohere  # Example for Cohere

class TradingAgentsGraph:
    def __init__(self, selected_analysts=None, debug=False, config=None):
        # ... existing initialization ...

        # ... existing if/elif branches for other providers ...
        # Add your provider as a new branch of that chain
        elif config["llm_provider"].lower() == "cohere":
            self.deep_thinking_llm = ChatCohere(
                model=config["deep_think_llm"],
                cohere_api_key=os.getenv("COHERE_API_KEY")
            )
            self.quick_thinking_llm = ChatCohere(
                model=config["quick_think_llm"],
                cohere_api_key=os.getenv("COHERE_API_KEY")
            )

        # ... rest of initialization ...
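As the number of providers grows, the `elif` chain can also be expressed as a lookup table. The sketch below shows one possible shape and is not how TradingAgents is organized today: `PROVIDER_FACTORIES`, `build_llms`, and `FakeChat` are hypothetical names, and `FakeChat` stands in for a real chat class such as `ChatCohere` so the snippet runs without any provider package installed.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass
class FakeChat:
    """Stand-in for a LangChain chat model; records only the model name."""
    model: str


# Maps a lowercase provider name to a callable that builds one chat model.
PROVIDER_FACTORIES: Dict[str, Callable[[str], FakeChat]] = {
    "cohere": lambda model: FakeChat(model=model),
}


def build_llms(config: dict) -> Tuple[FakeChat, FakeChat]:
    """Return (deep_thinking_llm, quick_thinking_llm) for the configured provider."""
    provider = config["llm_provider"].lower()
    try:
        factory = PROVIDER_FACTORIES[provider]
    except KeyError:
        raise ValueError(f"Unsupported LLM provider: {provider}") from None
    return factory(config["deep_think_llm"]), factory(config["quick_think_llm"])
```

With this layout, an unknown provider fails with an explicit error instead of silently falling through the chain.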

Step 3: Configure API Keys

Add Environment Variable

Update .env.example:

# LLM Provider API Keys
OPENAI_API_KEY=your_openai_key_here
ANTHROPIC_API_KEY=your_anthropic_key_here
COHERE_API_KEY=your_cohere_key_here  # NEW

Validate API Key

Add validation in initialization:

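elif config["llm_provider"].lower() == "cohere":
    # Fail fast with an actionable message instead of a later authentication error
    cohere_key = os.getenv("COHERE_API_KEY")
    if not cohere_key:
        raise ValueError(
            "COHERE_API_KEY environment variable is required when using cohere provider. "
            "Set it with: export COHERE_API_KEY=your_key_here"
        )

    self.deep_thinking_llm = ChatCohere(
        model=config["deep_think_llm"],
        cohere_api_key=cohere_key
    )
    self.quick_thinking_llm = ChatCohere(
        model=config["quick_think_llm"],
        cohere_api_key=cohere_key
    )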
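The same check tends to repeat for every provider, so it can be factored into a small helper. This is a sketch only: `require_api_key` is a hypothetical function, and the optional `env` argument exists purely to make the helper easy to test.

```python
import os
from typing import Mapping, Optional


def require_api_key(env_var: str, provider: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Return the key stored in env_var, or raise ValueError with a fix-it hint."""
    source = os.environ if env is None else env
    key = source.get(env_var, "").strip()
    if not key:
        raise ValueError(
            f"{env_var} environment variable is required when using {provider} provider. "
            f"Set it with: export {env_var}=your_key_here"
        )
    return key
```

Inside the provider branch this would be called as `cohere_key = require_api_key("COHERE_API_KEY", "cohere")`.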

Step 4: Update Configuration

Add default configuration for your provider:

# In a configuration example or documentation

config = {
    "llm_provider": "cohere",
    "deep_think_llm": "command-r-plus",  # Cohere model for deep thinking
    "quick_think_llm": "command-r",      # Cohere model for quick tasks
    "backend_url": None  # If provider doesn't need custom endpoint
}
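The dictionary above can also be derived from the project's defaults rather than written from scratch. The following is a minimal sketch, assuming the defaults are a plain dict: `DEFAULTS` is a small stand-in for `DEFAULT_CONFIG` with illustrative values, and `provider_config` is a hypothetical helper.

```python
import copy

# Stand-in for tradingagents.default_config.DEFAULT_CONFIG (values are illustrative).
DEFAULTS = {
    "llm_provider": "openai",
    "deep_think_llm": "placeholder-deep-model",
    "quick_think_llm": "placeholder-quick-model",
    "backend_url": None,
}


def provider_config(defaults: dict, provider: str, deep_model: str, quick_model: str) -> dict:
    """Return a new config for the provider without mutating the defaults."""
    config = copy.deepcopy(defaults)  # deep copy so nested values are not shared
    config.update(
        llm_provider=provider,
        deep_think_llm=deep_model,
        quick_think_llm=quick_model,
    )
    return config


cohere = provider_config(DEFAULTS, "cohere", "command-r-plus", "command-r")
```

A deep copy is used here because a shallow `.copy()` would share any nested dictionaries with the defaults.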

Step 5: Handle Provider-Specific Features

Custom Headers

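Some providers require specific headers. The keyword argument that carries them differs between integrations, so confirm it against the provider's chat class before copying this pattern:

elif config["llm_provider"].lower() == "cohere":
    # cohere_key comes from the validation shown in Step 3
    default_headers = {
        "X-Client-Name": "TradingAgents"
    }

    self.deep_thinking_llm = ChatCohere(
        model=config["deep_think_llm"],
        cohere_api_key=cohere_key,
        headers=default_headers  # illustrative; check that your integration accepts this argument
    )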

Model Name Formats

Handle provider-specific model naming:

def _format_model_name(self, provider: str, model: str) -> str:
    """Format model name based on provider conventions."""
    if provider == "openrouter":
        # OpenRouter uses "provider/model" format
        return model if "/" in model else f"default/{model}"
    elif provider == "cohere":
        # Cohere uses simple model names, so strip only a leading "cohere/" prefix
        return model.removeprefix("cohere/")  # requires Python 3.9+
    return model
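To see what the formatting rules produce, the logic can be exercised as a free function. This sketch mirrors the method above for the inputs shown; `format_model_name` is a standalone copy made for experimentation, and the model names are only examples.

```python
def format_model_name(provider: str, model: str) -> str:
    """Standalone version of the formatting rules, for experimentation."""
    if provider == "openrouter":
        # OpenRouter expects "provider/model"; add a default namespace if missing.
        return model if "/" in model else f"default/{model}"
    if provider == "cohere":
        # Strip only a leading "cohere/" prefix.
        prefix = "cohere/"
        return model[len(prefix):] if model.startswith(prefix) else model
    return model


examples = [
    ("openrouter", "gpt-4o"),
    ("openrouter", "openai/gpt-4o"),
    ("cohere", "cohere/command-r"),
    ("cohere", "command-r-plus"),
]
for provider, model in examples:
    print(f"{provider:10} {model!r:22} -> {format_model_name(provider, model)!r}")
```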

Rate Limiting

Implement provider-specific rate limit handling:

from tradingagents.utils.exceptions import LLMRateLimitError

try:
    response = llm.invoke(messages)
except Exception as e:
    # Map provider-specific errors to unified exceptions.
    # Providers word this differently ("rate_limit", "rate limit"), so normalize first.
    if "rate limit" in str(e).lower().replace("_", " "):
        raise LLMRateLimitError(
            provider="cohere",
            message=str(e),
            retry_after=60  # Default retry time
        ) from e  # keep the original error as the cause
    raise
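The mapping above attaches a `retry_after` value to the unified exception, which a caller can use to wait and try again. The following is a minimal sketch under stated assumptions: the local `LLMRateLimitError` class is a stand-in whose constructor is inferred from the keyword arguments used in this guide, `retry_after` is assumed to be in seconds, and `call_with_retry` is a hypothetical helper.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


class LLMRateLimitError(Exception):
    """Local stand-in for tradingagents.utils.exceptions.LLMRateLimitError."""

    def __init__(self, provider: str, message: str, retry_after: float = 60):
        super().__init__(message)
        self.provider = provider
        self.retry_after = retry_after


def call_with_retry(
    call: Callable[[], T],
    max_attempts: int = 3,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Run call(); on a rate-limit error wait retry_after and try again."""
    for attempt in range(1, max_attempts + 1):
        try:
            return call()
        except LLMRateLimitError as err:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error to the caller
            sleep(err.retry_after)
    raise ValueError("max_attempts must be at least 1")
```

Usage would look like `call_with_retry(lambda: llm.invoke(messages))`. The `sleep` parameter is injectable so tests do not have to wait in real time.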

Step 6: Add Error Handling

Create unified error handling for your provider:

# In tradingagents/utils/exceptions.py

class CohereLLMError(LLMError):
    """Cohere-specific LLM errors."""
    provider = "cohere"

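def handle_cohere_error(error):
    """Convert Cohere errors to unified exceptions."""
    # Normalize underscores so both "rate_limit" and "rate limit" match
    if "rate limit" in str(error).lower().replace("_", " "):
        return LLMRateLimitError(
            provider="cohere",
            message=str(error),
            # extract_retry_time is not defined in this guide; supply your own helper
            retry_after=extract_retry_time(error)
        )
    return CohereLLMError(str(error))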
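The handler above calls `extract_retry_time`, which this guide does not define. One possible implementation is sketched below. It assumes the wait time appears in the error message in a form such as "retry after 30 seconds", which is a guess about the message format; a real provider may expose the value elsewhere, for example in a response header.

```python
import re

DEFAULT_RETRY_SECONDS = 60  # same fallback value used in the rate-limiting example


def extract_retry_time(error: Exception, default: int = DEFAULT_RETRY_SECONDS) -> int:
    """Pull a wait time out of an error message, or return the default."""
    match = re.search(
        r"(?:retry|try again)\s+(?:after|in)\s+(\d+)",
        str(error),
        re.IGNORECASE,
    )
    return int(match.group(1)) if match else default
```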

Step 7: Test Integration

Create tests for your provider:

# tests/integration/test_cohere_provider.py

import pytest
import os
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG

@pytest.fixture
def cohere_config():
    """Configuration for Cohere provider."""
    config = DEFAULT_CONFIG.copy()
    config["llm_provider"] = "cohere"
    config["deep_think_llm"] = "command-r-plus"
    config["quick_think_llm"] = "command-r"
    return config

@pytest.fixture
def mock_env_cohere(monkeypatch):
    """Mock Cohere API key."""
    monkeypatch.setenv("COHERE_API_KEY", "test_key")

def test_cohere_initialization(cohere_config, mock_env_cohere):
    """Test Cohere provider can be initialized."""
    ta = TradingAgentsGraph(config=cohere_config)

    assert ta.deep_thinking_llm is not None
    assert ta.quick_thinking_llm is not None

def test_cohere_missing_api_key(cohere_config, monkeypatch):
    """Test error when API key is missing."""
    # Remove the key explicitly so the test does not depend on the caller's environment
    monkeypatch.delenv("COHERE_API_KEY", raising=False)
    with pytest.raises(ValueError, match="COHERE_API_KEY"):
        TradingAgentsGraph(config=cohere_config)

@pytest.mark.integration
def test_cohere_analysis(cohere_config):
    """Test full analysis with Cohere provider."""
    # This requires an actual API key, so the mock_env_cohere fixture is not used here
    if not os.getenv("COHERE_API_KEY"):
        pytest.skip("COHERE_API_KEY not set")

    ta = TradingAgentsGraph(
        selected_analysts=["market"],
        config=cohere_config
    )

    state, decision = ta.propagate("NVDA", "2024-05-10")

    assert decision["action"] in ["BUY", "SELL", "HOLD"]
    assert 0.0 <= decision["confidence_score"] <= 1.0
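The two assertions at the end of the integration test can be pulled into a reusable check so that other provider tests report the same failures in the same way. This is a sketch: `check_decision` is a hypothetical helper, and it assumes only the two fields the test above already relies on.

```python
VALID_ACTIONS = {"BUY", "SELL", "HOLD"}


def check_decision(decision: dict) -> list:
    """Return a list of problems found in a decision dict (empty means valid)."""
    problems = []
    action = decision.get("action")
    if action not in VALID_ACTIONS:
        problems.append(f"unexpected action: {action!r}")
    score = decision.get("confidence_score")
    is_number = isinstance(score, (int, float)) and not isinstance(score, bool)
    if not is_number or not 0.0 <= score <= 1.0:
        problems.append(f"confidence_score out of range: {score!r}")
    return problems
```

In a test this becomes `assert check_decision(decision) == []`, and a failing run prints every problem at once.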

Step 8: Update Documentation

Update Configuration Guide

Add provider details to docs/guides/configuration.md:

### Cohere Configuration

config = {
    "llm_provider": "cohere",
    "deep_think_llm": "command-r-plus",
    "quick_think_llm": "command-r"
}

Environment:

export COHERE_API_KEY=your_key_here

Models Available:

  • command-r-plus: Advanced reasoning
  • command-r: Fast, cost-effective
  • command: Basic model

Update LLM Integration Docs

Add to docs/architecture/llm-integration.md:

### Cohere
- **Models**: Command-R-Plus, Command-R, Command
- **Strengths**: Fast inference, multilingual support
- **Use Case**: Cost-effective alternative with good performance
- **API Key**: `COHERE_API_KEY`
- **Endpoint**: Built-in

Step 9: Add Example Usage

Create example script examples/cohere_example.py:

from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG

# Configure for Cohere
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "cohere"
config["deep_think_llm"] = "command-r-plus"
config["quick_think_llm"] = "command-r"

# Initialize
ta = TradingAgentsGraph(config=config)

# Run analysis
state, decision = ta.propagate("NVDA", "2024-05-10")

print(f"Decision: {decision['action']}")
print(f"Confidence: {decision['confidence_score']:.2%}")

Provider-Specific Considerations

OpenAI-Compatible APIs

For OpenAI-compatible APIs (e.g., local models):

elif config["llm_provider"].lower() == "custom_openai":
    from langchain_openai import ChatOpenAI

    self.deep_thinking_llm = ChatOpenAI(
        model=config["deep_think_llm"],
        base_url=config["backend_url"],  # Custom endpoint
        api_key=os.getenv("CUSTOM_API_KEY")
    )
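Because this branch depends on `config["backend_url"]`, a missing or malformed URL is a likely source of confusing connection errors. The sketch below checks the value up front using the standard library; `validate_backend_url` is a hypothetical helper, not an existing TradingAgents function.

```python
from urllib.parse import urlparse


def validate_backend_url(url) -> str:
    """Return the URL unchanged if it looks like an http(s) endpoint, else raise."""
    if not url:
        raise ValueError("backend_url must be set for an OpenAI-compatible provider")
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError(f"backend_url is not a valid http(s) URL: {url!r}")
    return url
```

It could be applied as `base_url=validate_backend_url(config["backend_url"])` when constructing the chat model.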

Azure OpenAI

elif config["llm_provider"].lower() == "azure":
    from langchain_openai import AzureChatOpenAI

    self.deep_thinking_llm = AzureChatOpenAI(
        deployment_name=config["deep_think_llm"],
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version="2024-02-15-preview"
    )

HuggingFace

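elif config["llm_provider"].lower() == "huggingface":
    from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

    # ChatHuggingFace wraps an underlying LLM object rather than taking a model name directly
    endpoint = HuggingFaceEndpoint(
        repo_id=config["deep_think_llm"],
        huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY")
    )
    self.deep_thinking_llm = ChatHuggingFace(llm=endpoint)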

Best Practices

  1. Follow LangChain Patterns: Use official LangChain integrations when available
  2. Unified Error Handling: Map provider errors to TradingAgents exceptions
  3. Environment Variables: Always use environment variables for API keys
  4. Validation: Validate API keys before usage
  5. Testing: Write comprehensive tests for the integration
  6. Documentation: Update all relevant documentation
  7. Examples: Provide working examples
  8. Defaults: Set sensible default models
  9. Rate Limits: Implement retry logic
  10. Logging: Add debug logging for troubleshooting
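Practices 3 and 10 pull in opposite directions: debug logs are most useful when they describe the provider setup, but they must not leak the API key. The following is a minimal sketch of one way to reconcile them; the logger name and both helper functions are assumptions made for illustration.

```python
import logging

logger = logging.getLogger("tradingagents.llm")  # logger name is an assumption


def mask_secret(value: str, visible: int = 4) -> str:
    """Hide all but the last few characters of a secret for safe logging."""
    if not value:
        return "(unset)"
    if len(value) <= visible:
        return "*" * len(value)
    return "*" * (len(value) - visible) + value[-visible:]


def log_provider_setup(provider: str, model: str, api_key: str) -> None:
    """Emit a debug line describing the provider without leaking the key."""
    logger.debug("provider=%s model=%s api_key=%s", provider, model, mask_secret(api_key))
```

Keeping the last few characters visible is usually enough to tell which key is in use when several are configured.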

Troubleshooting

Import Errors

Issue: ModuleNotFoundError: No module named 'langchain_cohere'

Solution: Install the provider package

pip install langchain-cohere

API Key Errors

Issue: ValueError: COHERE_API_KEY environment variable is required

Solution: Set the API key

export COHERE_API_KEY=your_key_here

Model Name Errors

Issue: Invalid model name: 'command-r-plus'

Solution: Check provider documentation for correct model names

Rate Limit Handling

Issue: Provider rate limits not being handled

Solution: Implement provider-specific error mapping

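try:
    response = llm.invoke(messages)
except ProviderError as e:  # ProviderError is a placeholder for your provider SDK's exception class
    if "rate limit" in str(e).lower():
        raise LLMRateLimitError(provider="cohere", ...)
    raise  # re-raise anything that is not a rate limit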

See Also