Guide: Adding a New LLM Provider
This guide shows you how to add support for a new LLM provider to TradingAgents.
Overview
Adding a new LLM provider involves:
- Installing the provider's LangChain integration
- Adding initialization logic
- Configuring API keys
- Testing the integration
- Updating documentation
Step 1: Install LangChain Integration
Most providers have official LangChain integrations:
# Example: Adding Cohere
pip install langchain-cohere
# Example: Adding Mistral (the package is named langchain-mistralai)
pip install langchain-mistralai
# Example: Adding HuggingFace
pip install langchain-huggingface
Add the dependency to requirements.txt:
langchain-cohere>=0.1.0
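Before touching the graph code, it can help to confirm that the integration is importable in the active environment. The following is a minimal sketch using only the standard library; `integration_installed` is a hypothetical helper name, not part of TradingAgents.

```python
import importlib.util


def integration_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found on the current Python path."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # The import name uses underscores even though the pip package uses hyphens.
    for name in ("langchain_cohere",):
        status = "installed" if integration_installed(name) else "missing"
        print(f"{name}: {status}")
```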
Step 2: Add Initialization Logic
Modify tradingagents/graph/trading_graph.py:
# Add imports at the top of the file
import os
from langchain_cohere import ChatCohere  # Example for Cohere
class TradingAgentsGraph:
    def __init__(self, selected_analysts=None, debug=False, config=None):
        # ... existing initialization ...
        # Add your provider as a new branch of the existing provider if/elif chain
        elif config["llm_provider"].lower() == "cohere":
            self.deep_thinking_llm = ChatCohere(
                model=config["deep_think_llm"],
                cohere_api_key=os.getenv("COHERE_API_KEY")
            )
            self.quick_thinking_llm = ChatCohere(
                model=config["quick_think_llm"],
                cohere_api_key=os.getenv("COHERE_API_KEY")
            )
        # ... rest of initialization ...
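The deep and quick models differ only in their model name, so one helper could build both and keep each provider branch short. This is a sketch under assumptions: `build_llm_pair` and `FakeChat` are hypothetical names, and `FakeChat` only stands in for a real chat model class such as `ChatCohere` so the example runs without any provider package installed.

```python
from typing import Any, Callable, Dict, Tuple


def build_llm_pair(
    config: Dict[str, Any],
    factory: Callable[..., Any],
    **shared_kwargs: Any,
) -> Tuple[Any, Any]:
    """Build the (deep, quick) model pair from a single constructor."""
    deep = factory(model=config["deep_think_llm"], **shared_kwargs)
    quick = factory(model=config["quick_think_llm"], **shared_kwargs)
    return deep, quick


class FakeChat:
    """Stand-in for a chat model class (illustration only)."""

    def __init__(self, model: str, **kwargs: Any) -> None:
        self.model = model
        self.kwargs = kwargs


example_config = {"deep_think_llm": "command-r-plus", "quick_think_llm": "command-r"}
deep_llm, quick_llm = build_llm_pair(example_config, FakeChat, cohere_api_key="test_key")
```

In a provider branch, the same call would pass the real class in place of `FakeChat` and assign the results to `self.deep_thinking_llm` and `self.quick_thinking_llm`.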
Step 3: Configure API Keys
Add Environment Variable
Update .env.example:
# LLM Provider API Keys
OPENAI_API_KEY=your_openai_key_here
ANTHROPIC_API_KEY=your_anthropic_key_here
# NEW: Cohere (kept on its own line because some .env parsers read inline comments as part of the value)
COHERE_API_KEY=your_cohere_key_here
Validate API Key
Add validation in initialization:
elif config["llm_provider"].lower() == "cohere":
    cohere_key = os.getenv("COHERE_API_KEY")
    if not cohere_key:
        raise ValueError(
            "COHERE_API_KEY environment variable is required when using cohere provider. "
            "Set it with: export COHERE_API_KEY=your_key_here"
        )
    # Reuse the validated key for both models
    self.deep_thinking_llm = ChatCohere(
        model=config["deep_think_llm"],
        cohere_api_key=cohere_key
    )
    self.quick_thinking_llm = ChatCohere(
        model=config["quick_think_llm"],
        cohere_api_key=cohere_key
    )
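The same check is needed for every provider, so it could be factored into one function. A minimal sketch, assuming a hypothetical helper named `require_api_key`; the optional `env` argument exists only so the function can be exercised without modifying the real environment.

```python
import os
from typing import Mapping, Optional


def require_api_key(
    var_name: str,
    provider: str,
    env: Optional[Mapping[str, str]] = None,
) -> str:
    """Return the API key stored in var_name, or raise a descriptive ValueError."""
    env = os.environ if env is None else env
    value = env.get(var_name, "").strip()
    if not value:
        raise ValueError(
            f"{var_name} environment variable is required when using {provider} provider. "
            f"Set it with: export {var_name}=your_key_here"
        )
    return value
```

A provider branch could then start with `cohere_key = require_api_key("COHERE_API_KEY", "cohere")`.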
Step 4: Update Configuration
Add default configuration for your provider:
# In a configuration example or documentation
config = {
    "llm_provider": "cohere",
    "deep_think_llm": "command-r-plus",  # Cohere model for deep thinking
    "quick_think_llm": "command-r",      # Cohere model for quick tasks
    "backend_url": None                  # If provider doesn't need custom endpoint
}
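A misspelled or missing key in this dictionary otherwise surfaces as a `KeyError` deep inside initialization. The sketch below checks the configuration up front. `REQUIRED_KEYS` and `find_config_problems` are hypothetical names, and the list of required keys is an assumption based on the keys used in this guide.

```python
from typing import Any, Dict, List

# Assumption: these are the keys every provider branch reads.
REQUIRED_KEYS = ("llm_provider", "deep_think_llm", "quick_think_llm")


def find_config_problems(config: Dict[str, Any]) -> List[str]:
    """Return a list of human-readable problems; an empty list means the config looks usable."""
    problems = []
    for key in REQUIRED_KEYS:
        value = config.get(key)
        if not isinstance(value, str) or not value.strip():
            problems.append(f"'{key}' must be a non-empty string")
    return problems


cohere_example = {
    "llm_provider": "cohere",
    "deep_think_llm": "command-r-plus",
    "quick_think_llm": "command-r",
    "backend_url": None,
}
problems = find_config_problems(cohere_example)
```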
Step 5: Handle Provider-Specific Features
Custom Headers
Some providers require specific headers:
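elif config["llm_provider"].lower() == "cohere":
    default_headers = {
        "X-Client-Name": "TradingAgents"
    }
    # Check your integration's constructor before copying this: not every
    # LangChain chat model accepts a headers argument, and its name varies.
    # cohere_key is the validated key from Step 3.
    self.deep_thinking_llm = ChatCohere(
        model=config["deep_think_llm"],
        cohere_api_key=cohere_key,
        headers=default_headers
    )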
Model Name Formats
Handle provider-specific model naming:
def _format_model_name(self, provider: str, model: str) -> str:
    """Format model name based on provider conventions."""
    if provider == "openrouter":
        # OpenRouter uses "provider/model" format
        return model if "/" in model else f"default/{model}"
    elif provider == "cohere":
        # Cohere uses simple model names
        return model.replace("cohere/", "")
    return model
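To make the behavior concrete, here is a standalone version of the same logic with a few sample inputs. It is a sketch, not the project's method: `format_model_name` is written as a plain function so it runs on its own, and lowercasing the provider name is an added assumption to match the `.lower()` comparison used during initialization.

```python
def format_model_name(provider: str, model: str) -> str:
    """Standalone copy of the formatting rules, for experimentation."""
    provider = provider.lower()  # Assumption: provider names are case-insensitive
    if provider == "openrouter":
        # OpenRouter uses "provider/model" format
        return model if "/" in model else f"default/{model}"
    if provider == "cohere":
        # Cohere uses simple model names
        return model.replace("cohere/", "")
    return model


samples = [
    ("cohere", "cohere/command-r"),
    ("openrouter", "command-r"),
    ("openrouter", "cohere/command-r"),
]
for provider, model in samples:
    print(f"{provider}: {model} -> {format_model_name(provider, model)}")
```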
Rate Limiting
Implement provider-specific rate limit handling:
from tradingagents.utils.exceptions import LLMRateLimitError
try:
    response = llm.invoke(messages)
except Exception as e:
    # Map provider-specific errors to unified exceptions
    error_text = str(e).lower()
    # Providers word this differently, so match both spellings
    if "rate limit" in error_text or "rate_limit" in error_text:
        raise LLMRateLimitError(
            provider="cohere",
            message=str(e),
            retry_after=60  # Default retry time
        ) from e
    raise
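Once rate limit errors are mapped to one exception type, the caller can retry on it. The following is a minimal sketch of that idea. `invoke_with_retry` is a hypothetical helper, and the `LLMRateLimitError` class here is a local stand-in that mirrors the keyword arguments used in this guide; in the project you would import the real class from `tradingagents.utils.exceptions` instead.

```python
import time
from typing import Any, Callable


class LLMRateLimitError(Exception):
    """Local stand-in for tradingagents.utils.exceptions.LLMRateLimitError."""

    def __init__(self, provider: str, message: str, retry_after: float = 60) -> None:
        super().__init__(message)
        self.provider = provider
        self.retry_after = retry_after


def invoke_with_retry(
    call: Callable[[], Any],
    max_attempts: int = 3,
    sleep: Callable[[float], None] = time.sleep,
) -> Any:
    """Run call(), waiting retry_after and retrying when it raises LLMRateLimitError."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return call()
        except LLMRateLimitError as exc:
            if attempt == max_attempts:
                raise  # Out of attempts: surface the rate limit error
            sleep(exc.retry_after)
    raise AssertionError("unreachable")
```

Usage would look like `invoke_with_retry(lambda: llm.invoke(messages))`. The `sleep` parameter is injectable so tests do not have to wait.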
Step 6: Add Error Handling
Create unified error handling for your provider:
# In tradingagents/utils/exceptions.py
class CohereLLMError(LLMError):
    """Cohere-specific LLM errors."""
    provider = "cohere"
def handle_cohere_error(error):
    """Convert Cohere errors to unified exceptions."""
    if "rate limit" in str(error).lower():
        return LLMRateLimitError(
            provider="cohere",
            message=str(error),
            # extract_retry_time is not defined in this guide; supply your own helper
            retry_after=extract_retry_time(error)
        )
    return CohereLLMError(str(error))
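The guide does not define `extract_retry_time`, so here is one possible sketch. It assumes the retry delay appears in the error message as text such as "retry after 30 seconds" or "Retry-After: 12.5"; that message format is an assumption, and real providers may expose the value in response headers or exception attributes instead.

```python
import re

# Matches "retry after 30", "retry_after=30", "Retry-After: 12.5", and similar.
_RETRY_PATTERN = re.compile(r"retry[\s_-]*after\D{0,10}(\d+(?:\.\d+)?)", re.IGNORECASE)


def extract_retry_time(error: Exception, default: float = 60.0) -> float:
    """Parse a retry delay from the error text, falling back to a default."""
    match = _RETRY_PATTERN.search(str(error))
    return float(match.group(1)) if match else default
```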
Step 7: Test Integration
Create tests for your provider:
# tests/integration/test_cohere_provider.py
import pytest
import os
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
@pytest.fixture
def cohere_config():
    """Configuration for Cohere provider."""
    config = DEFAULT_CONFIG.copy()
    config["llm_provider"] = "cohere"
    config["deep_think_llm"] = "command-r-plus"
    config["quick_think_llm"] = "command-r"
    return config
@pytest.fixture
def mock_env_cohere(monkeypatch):
    """Mock Cohere API key."""
    monkeypatch.setenv("COHERE_API_KEY", "test_key")
def test_cohere_initialization(cohere_config, mock_env_cohere):
    """Test Cohere provider can be initialized."""
    ta = TradingAgentsGraph(config=cohere_config)
    assert ta.deep_thinking_llm is not None
    assert ta.quick_thinking_llm is not None
def test_cohere_missing_api_key(cohere_config, monkeypatch):
    """Test error when API key is missing."""
    # Remove COHERE_API_KEY in case it is set in the developer's shell
    monkeypatch.delenv("COHERE_API_KEY", raising=False)
    with pytest.raises(ValueError, match="COHERE_API_KEY"):
        TradingAgentsGraph(config=cohere_config)
@pytest.mark.integration
def test_cohere_analysis(cohere_config):
    """Test full analysis with Cohere provider."""
    # This requires an actual API key, so do not use the mock key fixture here
    # and skip before building the graph
    if not os.getenv("COHERE_API_KEY"):
        pytest.skip("COHERE_API_KEY not set")
    ta = TradingAgentsGraph(
        selected_analysts=["market"],
        config=cohere_config
    )
    state, decision = ta.propagate("NVDA", "2024-05-10")
    assert decision["action"] in ["BUY", "SELL", "HOLD"]
    assert 0.0 <= decision["confidence_score"] <= 1.0
Step 8: Update Documentation
Update Configuration Guide
Add provider details to docs/guides/configuration.md:
### Cohere Configuration
```python
config = {
    "llm_provider": "cohere",
    "deep_think_llm": "command-r-plus",
    "quick_think_llm": "command-r"
}
```
Environment:
export COHERE_API_KEY=your_key_here
Models Available:
- command-r-plus: Advanced reasoning
- command-r: Fast, cost-effective
- command: Basic model
Update LLM Integration Docs
Add to docs/architecture/llm-integration.md:
```markdown
### Cohere
- **Models**: Command-R-Plus, Command-R, Command
- **Strengths**: Fast inference, multilingual support
- **Use Case**: Cost-effective alternative with good performance
- **API Key**: `COHERE_API_KEY`
- **Endpoint**: Built-in
```
Step 9: Add Example Usage
Create example script examples/cohere_example.py:
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
# Configure for Cohere
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "cohere"
config["deep_think_llm"] = "command-r-plus"
config["quick_think_llm"] = "command-r"
# Initialize
ta = TradingAgentsGraph(config=config)
# Run analysis
state, decision = ta.propagate("NVDA", "2024-05-10")
print(f"Decision: {decision['action']}")
print(f"Confidence: {decision['confidence_score']:.2%}")
Provider-Specific Considerations
OpenAI-Compatible APIs
For OpenAI-compatible APIs (e.g., local models):
elif config["llm_provider"].lower() == "custom_openai":
    from langchain_openai import ChatOpenAI
    self.deep_thinking_llm = ChatOpenAI(
        model=config["deep_think_llm"],
        base_url=config["backend_url"],  # Custom endpoint
        api_key=os.getenv("CUSTOM_API_KEY")
    )
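Because this branch depends on `config["backend_url"]`, which may be `None` for other providers, it can be worth validating the URL before constructing the client. A minimal sketch with the standard library; `validate_backend_url` is a hypothetical helper, and the localhost URL is only an example value.

```python
from typing import Optional
from urllib.parse import urlparse


def validate_backend_url(url: Optional[str]) -> str:
    """Return the URL without a trailing slash, or raise ValueError if it is unusable."""
    if not url:
        raise ValueError("backend_url is required for the custom_openai provider")
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError(f"backend_url must be an http(s) URL, got: {url!r}")
    return url.rstrip("/")


# Example value for a locally hosted OpenAI-compatible server (hypothetical)
checked_url = validate_backend_url("http://localhost:8000/v1/")
```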
Azure OpenAI
elif config["llm_provider"].lower() == "azure":
    from langchain_openai import AzureChatOpenAI
    self.deep_thinking_llm = AzureChatOpenAI(
        deployment_name=config["deep_think_llm"],
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version="2024-02-15-preview"
    )
HuggingFace
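elif config["llm_provider"].lower() == "huggingface":
    from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
    # ChatHuggingFace wraps an LLM object rather than taking a model name directly
    deep_endpoint = HuggingFaceEndpoint(
        repo_id=config["deep_think_llm"],
        huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY")
    )
    self.deep_thinking_llm = ChatHuggingFace(llm=deep_endpoint)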
Best Practices
- Follow LangChain Patterns: Use official LangChain integrations when available
- Unified Error Handling: Map provider errors to TradingAgents exceptions
- Environment Variables: Always use environment variables for API keys
- Validation: Validate API keys before usage
- Testing: Write comprehensive tests for the integration
- Documentation: Update all relevant documentation
- Examples: Provide working examples
- Defaults: Set sensible default models
- Rate Limits: Implement retry logic
- Logging: Add debug logging for troubleshooting
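For the logging item, one thing to watch is that debug output never contains a full API key. The sketch below shows one way to log provider setup with the key masked. The logger name `tradingagents.llm` and both function names are assumptions chosen for illustration.

```python
import logging

# Hypothetical logger name; use whatever logger the project already configures.
logger = logging.getLogger("tradingagents.llm")


def mask_secret(secret: str, visible: int = 4) -> str:
    """Hide all but the last few characters of a secret."""
    if len(secret) <= visible:
        return "*" * len(secret)
    return "*" * (len(secret) - visible) + secret[-visible:]


def log_provider_setup(provider: str, model: str, api_key: str) -> None:
    """Emit a debug line describing the provider without leaking the key."""
    logger.debug(
        "Initializing provider=%s model=%s api_key=%s",
        provider,
        model,
        mask_secret(api_key),
    )
```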
Troubleshooting
Import Errors
Issue: ModuleNotFoundError: No module named 'langchain_cohere'
Solution: Install the provider package
pip install langchain-cohere
API Key Errors
Issue: ValueError: COHERE_API_KEY environment variable is required
Solution: Set the API key
export COHERE_API_KEY=your_key_here
Model Name Errors
Issue: Invalid model name: 'command-r-plus'
Solution: Check provider documentation for correct model names
Rate Limit Handling
Issue: Provider rate limits not being handled
Solution: Implement provider-specific error mapping
except ProviderError as e:
    if "rate limit" in str(e).lower():
        raise LLMRateLimitError(provider="cohere", ...)