Compare commits
5 Commits
f22f4a0ff5
...
9350e4544f
| Author | SHA1 | Date |
|---|---|---|
|
|
9350e4544f | |
|
|
fa4d01c23a | |
|
|
b0f6058299 | |
|
|
59d6b2152d | |
|
|
f37c751a3c |
|
|
@ -0,0 +1,5 @@
|
|||
# Azure OpenAI
|
||||
AZURE_OPENAI_API_KEY=
|
||||
AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/
|
||||
AZURE_OPENAI_DEPLOYMENT_NAME=
|
||||
# OPENAI_API_VERSION=2024-10-21 # optional, required for non-v1 API
|
||||
|
|
@ -3,4 +3,7 @@ OPENAI_API_KEY=
|
|||
GOOGLE_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
XAI_API_KEY=
|
||||
DEEPSEEK_API_KEY=
|
||||
DASHSCOPE_API_KEY=
|
||||
ZHIPU_API_KEY=
|
||||
OPENROUTER_API_KEY=
|
||||
|
|
|
|||
|
|
@ -140,10 +140,15 @@ export OPENAI_API_KEY=... # OpenAI (GPT)
|
|||
export GOOGLE_API_KEY=... # Google (Gemini)
|
||||
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
|
||||
export XAI_API_KEY=... # xAI (Grok)
|
||||
export DEEPSEEK_API_KEY=... # DeepSeek
|
||||
export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope)
|
||||
export ZHIPU_API_KEY=... # GLM (Zhipu)
|
||||
export OPENROUTER_API_KEY=... # OpenRouter
|
||||
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
||||
```
|
||||
|
||||
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
|
||||
|
||||
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
|
||||
|
||||
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
||||
|
|
|
|||
43
cli/main.py
43
cli/main.py
|
|
@ -6,8 +6,9 @@ from functools import wraps
|
|||
from rich.console import Console
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
load_dotenv(".env.enterprise", override=False)
|
||||
from rich.panel import Panel
|
||||
from rich.spinner import Spinner
|
||||
from rich.live import Live
|
||||
|
|
@ -79,7 +80,7 @@ class MessageBuffer:
|
|||
self.current_agent = None
|
||||
self.report_sections = {}
|
||||
self.selected_analysts = []
|
||||
self._last_message_id = None
|
||||
self._processed_message_ids = set()
|
||||
|
||||
def init_for_analysis(self, selected_analysts):
|
||||
"""Initialize agent status and report sections based on selected analysts.
|
||||
|
|
@ -114,7 +115,7 @@ class MessageBuffer:
|
|||
self.current_agent = None
|
||||
self.messages.clear()
|
||||
self.tool_calls.clear()
|
||||
self._last_message_id = None
|
||||
self._processed_message_ids.clear()
|
||||
|
||||
def get_completed_reports_count(self):
|
||||
"""Count reports that are finalized (their finalizing agent is completed).
|
||||
|
|
@ -1052,28 +1053,24 @@ def run_analysis():
|
|||
# Stream the analysis
|
||||
trace = []
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
# Process messages if present (skip duplicates via message ID)
|
||||
if len(chunk["messages"]) > 0:
|
||||
last_message = chunk["messages"][-1]
|
||||
msg_id = getattr(last_message, "id", None)
|
||||
# Process all messages in chunk, deduplicating by message ID
|
||||
for message in chunk.get("messages", []):
|
||||
msg_id = getattr(message, "id", None)
|
||||
if msg_id is not None:
|
||||
if msg_id in message_buffer._processed_message_ids:
|
||||
continue
|
||||
message_buffer._processed_message_ids.add(msg_id)
|
||||
|
||||
if msg_id != message_buffer._last_message_id:
|
||||
message_buffer._last_message_id = msg_id
|
||||
msg_type, content = classify_message_type(message)
|
||||
if content and content.strip():
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
# Add message to buffer
|
||||
msg_type, content = classify_message_type(last_message)
|
||||
if content and content.strip():
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
# Handle tool calls
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(
|
||||
tool_call["name"], tool_call["args"]
|
||||
)
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
|
||||
# Update analyst statuses based on report state (runs on every chunk)
|
||||
update_analyst_statuses(message_buffer, chunk)
|
||||
|
|
|
|||
92
cli/utils.py
92
cli/utils.py
|
|
@ -174,17 +174,30 @@ def select_openrouter_model() -> str:
|
|||
return choice
|
||||
|
||||
|
||||
def select_shallow_thinking_agent(provider) -> str:
|
||||
"""Select shallow thinking llm engine using an interactive selection."""
|
||||
def _prompt_custom_model_id() -> str:
    """Prompt the user to type a custom model ID.

    Returns:
        The entered model ID with surrounding whitespace removed.
    """
    answer = questionary.text(
        "Enter model ID:",
        validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
    ).ask()
    if answer is None:
        # .ask() returns None when the prompt is cancelled (e.g. Ctrl-C);
        # calling .strip() on it would raise AttributeError. Exit cleanly,
        # mirroring the other selection helpers in this module.
        console.print("\n[red]No model ID entered. Exiting...[/red]")
        exit(1)
    return answer.strip()
|
||||
|
||||
|
||||
def _select_model(provider: str, mode: str) -> str:
    """Select a model for the given provider and thinking mode.

    Args:
        provider: Provider key (e.g. "openai", "azure", "openrouter").
        mode: Either "quick" or "deep"; picks the option list and wording.

    Returns:
        The chosen model ID (or Azure deployment name).
    """
    if provider.lower() == "openrouter":
        # OpenRouter has its own dedicated picker with a searchable list.
        return select_openrouter_model()

    if provider.lower() == "azure":
        # Azure OpenAI models are addressed by deployment name, not model ID.
        deployment = questionary.text(
            f"Enter Azure deployment name ({mode}-thinking):",
            validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
        ).ask()
        if deployment is None:
            # Prompt cancelled; exit instead of crashing on None.strip().
            console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
            exit(1)
        return deployment.strip()

    choice = questionary.select(
        f"Select Your [{mode.title()}-Thinking LLM Engine]:",
        choices=[
            questionary.Choice(display, value=value)
            for display, value in get_model_options(provider, mode)
        ],
        instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
        style=questionary.Style(
            [
                ("selected", "fg:magenta noinherit"),
                ("highlighted", "fg:magenta noinherit"),
                ("pointer", "fg:magenta noinherit"),
            ]
        ),
    ).ask()

    if choice is None:
        console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
        exit(1)

    if choice == "custom":
        # Sentinel option that lets the user type any model ID.
        return _prompt_custom_model_id()

    return choice
|
||||
|
||||
|
||||
def select_shallow_thinking_agent(provider) -> str:
    """Select shallow thinking llm engine using an interactive selection."""
    # Thin wrapper over the shared picker, fixed to quick-thinking mode.
    return _select_model(provider, mode="quick")
|
||||
|
||||
|
||||
def select_deep_thinking_agent(provider) -> str:
    """Select deep thinking llm engine using an interactive selection."""
    # Thin wrapper over the shared picker, fixed to deep-thinking mode.
    return _select_model(provider, mode="deep")
|
||||
|
||||
def select_llm_provider() -> tuple[str, str | None]:
    """Select the LLM provider and its API endpoint.

    Returns:
        A ``(provider_key, base_url)`` tuple; ``base_url`` is ``None`` for
        providers whose SDK manages its own endpoint (Google, Azure).
    """
    # (display_name, provider_key, base_url)
    PROVIDERS = [
        ("OpenAI", "openai", "https://api.openai.com/v1"),
        ("Google", "google", None),
        ("Anthropic", "anthropic", "https://api.anthropic.com/"),
        ("xAI", "xai", "https://api.x.ai/v1"),
        ("DeepSeek", "deepseek", "https://api.deepseek.com"),
        ("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
        ("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
        ("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
        ("Azure OpenAI", "azure", None),
        ("Ollama", "ollama", "http://localhost:11434/v1"),
    ]

    choice = questionary.select(
        "Select your LLM Provider:",
        choices=[
            questionary.Choice(display, value=(provider_key, url))
            for display, provider_key, url in PROVIDERS
        ],
        instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
        style=questionary.Style(
            [
                ("selected", "fg:magenta noinherit"),
                ("highlighted", "fg:magenta noinherit"),
                ("pointer", "fg:magenta noinherit"),
            ]
        ),
    ).ask()

    if choice is None:
        console.print("\n[red]No LLM provider selected. Exiting...[/red]")
        exit(1)

    provider, url = choice
    return provider, url
|
||||
|
||||
|
||||
def ask_openai_reasoning_effort() -> str:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ services:
|
|||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./results:/home/appuser/app/results
|
||||
- tradingagents_data:/home/appuser/.tradingagents
|
||||
tty: true
|
||||
stdin_open: true
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ services:
|
|||
environment:
|
||||
- LLM_PROVIDER=ollama
|
||||
volumes:
|
||||
- ./results:/home/appuser/app/results
|
||||
- tradingagents_data:/home/appuser/.tradingagents
|
||||
depends_on:
|
||||
- ollama
|
||||
tty: true
|
||||
|
|
@ -31,4 +31,5 @@ services:
|
|||
- ollama
|
||||
|
||||
volumes:
|
||||
tradingagents_data:
|
||||
ollama_data:
|
||||
|
|
|
|||
|
|
@ -32,9 +32,42 @@ dependencies = [
|
|||
"yfinance>=0.2.63",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-cov>=4.1.0",
|
||||
"ruff>=0.4.0",
|
||||
"mypy>=1.9.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
tradingagents = "cli.main:app"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_functions = ["test_*"]
|
||||
markers = [
|
||||
"unit: Unit tests (fast, isolated)",
|
||||
"integration: Integration tests (may require API keys)",
|
||||
"slow: Slow running tests",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
line-length = 100
|
||||
lint.select = ["E", "F", "I", "N", "W", "UP", "B", "C4", "DTZ", "ISC", "PIE", "PT", "RET", "SIM", "TCH", "ARG"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = ["ARG"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
strict = true
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["tradingagents*", "cli*"]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
"""Tests for the TradingAgents framework."""
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
"""Pytest configuration and shared fixtures for TradingAgents tests."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm():
    """Stand-in LLM object so tests never make real API calls."""
    return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_agent_state():
    """Sample agent state for testing.

    Returns:
        Dictionary with AgentState fields for use in tests.
    """
    # Build the nested debate sub-states first, then assemble the full state.
    investment_debate_state = {
        "bull_history": "",
        "bear_history": "",
        "history": "",
        "current_response": "",
        "judge_decision": "",
        "count": 0,
    }
    risk_debate_state = {
        "aggressive_history": "",
        "conservative_history": "",
        "neutral_history": "",
        "history": "",
        "latest_speaker": "",
        "current_aggressive_response": "",
        "current_conservative_response": "",
        "current_neutral_response": "",
        "judge_decision": "",
        "count": 0,
    }
    return {
        "company_of_interest": "AAPL",
        "trade_date": "2024-01-15",
        "messages": [],
        "sender": "",
        "market_report": "",
        "sentiment_report": "",
        "news_report": "",
        "fundamentals_report": "",
        "investment_debate_state": investment_debate_state,
        "investment_plan": "",
        "trader_investment_plan": "",
        "risk_debate_state": risk_debate_state,
        "final_trade_decision": "",
    }
|
||||
|
||||
|
||||
@pytest.fixture
def sample_market_data():
    """Sample market data for testing.

    Returns:
        Dictionary with OHLCV market data for use in tests.
    """
    # One day's OHLCV bar for AAPL.
    ohlc = {"open": 185.0, "high": 187.5, "low": 184.2, "close": 186.5}
    return {"ticker": "AAPL", "date": "2024-01-15", **ohlc, "volume": 50000000}
|
||||
|
||||
|
||||
@pytest.fixture
def sample_config():
    """Sample configuration for testing.

    Returns:
        Dictionary with default config values for use in tests.
    """
    # Every data category routes through yfinance in the test config.
    data_vendors = {
        category: "yfinance"
        for category in (
            "core_stock_apis",
            "technical_indicators",
            "fundamental_data",
            "news_data",
        )
    }
    return {
        "project_dir": "/tmp/tradingagents",
        "results_dir": "/tmp/results",
        "data_cache_dir": "/tmp/data_cache",
        "llm_provider": "openai",
        "deep_think_llm": "gpt-4o",
        "quick_think_llm": "gpt-4o-mini",
        "backend_url": "https://api.openai.com/v1",
        "google_thinking_level": None,
        "openai_reasoning_effort": None,
        "max_debate_rounds": 1,
        "max_risk_discuss_rounds": 1,
        "max_recur_limit": 100,
        "data_vendors": data_vendors,
        "tool_vendors": {},
    }
|
||||
|
||||
|
||||
@pytest.fixture
def sample_situations():
    """Sample financial situations for memory testing.

    Returns:
        List of (situation, recommendation) tuples.
    """
    situations = [
        "High volatility in tech sector with increasing institutional selling",
        "Strong earnings report beating expectations with raised guidance",
        "Rising interest rates affecting growth stock valuations",
        "Market showing signs of sector rotation with rising yields",
    ]
    recommendations = [
        "Reduce exposure to high-growth tech stocks. Consider defensive positions.",
        "Consider buying on any pullbacks. Monitor for momentum continuation.",
        "Review duration of fixed-income positions. Consider value stocks.",
        "Rebalance portfolio to maintain target allocations.",
    ]
    # Pair each situation with its matching recommendation.
    return list(zip(situations, recommendations))
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Integration tests for TradingAgents framework."""
|
||||
|
|
@ -0,0 +1,169 @@
|
|||
"""Integration tests for TradingAgents graph workflow."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestFullWorkflow:
    """Integration tests for the full trading workflow."""

    @pytest.fixture
    def mock_config(self):
        """Config copy that swaps both LLMs for the cheap mini model."""
        config = DEFAULT_CONFIG.copy()
        config.update(deep_think_llm="gpt-4o-mini", quick_think_llm="gpt-4o-mini")
        return config

    @staticmethod
    def _build_graph(mock_create_client, mock_config, analysts):
        """Wire the mocked LLM client and construct a TradingAgentsGraph."""
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        mock_create_client.return_value.get_llm.return_value = MagicMock()
        return TradingAgentsGraph(
            selected_analysts=analysts,
            debug=True,
            config=mock_config,
        )

    @pytest.mark.skip(reason="Requires API keys")
    def test_propagate_returns_decision(self, mock_config):
        """Integration test requiring live API keys."""
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        ta = TradingAgentsGraph(debug=True, config=mock_config)
        state, decision = ta.propagate("AAPL", "2024-01-15")
        assert decision is not None
        assert "final_trade_decision" in state

    @patch("tradingagents.graph.trading_graph.create_llm_client")
    def test_graph_initialization(self, mock_create_client, mock_config):
        """Graph builds without errors for a single analyst."""
        ta = self._build_graph(mock_create_client, mock_config, ["market"])
        assert ta.graph is not None

    @patch("tradingagents.graph.trading_graph.create_llm_client")
    def test_graph_initialization_all_analysts(self, mock_create_client, mock_config):
        """Graph builds without errors with every analyst enabled."""
        ta = self._build_graph(
            mock_create_client,
            mock_config,
            ["market", "news", "fundamentals", "social"],
        )
        assert ta.graph is not None
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestGraphSetup:
    """Integration tests for graph setup."""

    @pytest.fixture
    def mock_config(self):
        """Create a mock configuration for testing."""
        return DEFAULT_CONFIG.copy()

    @patch("tradingagents.graph.trading_graph.create_llm_client")
    def test_setup_creates_nodes(self, mock_create_client, mock_config):
        """Test that ConditionalLogic can be built from the config values."""
        from tradingagents.graph.conditional_logic import ConditionalLogic

        mock_create_client.return_value.get_llm.return_value = MagicMock()

        # Bind the instance and assert on it: the previous version discarded
        # the constructed object and asserted nothing, so it passed vacuously.
        logic = ConditionalLogic(
            max_debate_rounds=mock_config["max_debate_rounds"],
            max_risk_discuss_rounds=mock_config["max_risk_discuss_rounds"]
        )
        assert logic.max_debate_rounds == mock_config["max_debate_rounds"]
        assert logic.max_risk_discuss_rounds == mock_config["max_risk_discuss_rounds"]

    def test_conditional_logic_instance(self, mock_config):
        """Test that ConditionalLogic is instantiable."""
        from tradingagents.graph.conditional_logic import ConditionalLogic

        logic = ConditionalLogic(
            max_debate_rounds=mock_config["max_debate_rounds"],
            max_risk_discuss_rounds=mock_config["max_risk_discuss_rounds"]
        )

        assert logic.max_debate_rounds == mock_config["max_debate_rounds"]
        assert logic.max_risk_discuss_rounds == mock_config["max_risk_discuss_rounds"]
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestAgentInitialization:
    """Integration tests for agent initialization."""

    @pytest.fixture
    def mock_llm(self):
        """Mock LLM so no factory makes a real API call."""
        return MagicMock()

    def test_market_analyst_creation(self, mock_llm):
        """Market analyst factory returns a callable node."""
        from tradingagents.agents.analysts.market_analyst import create_market_analyst

        assert callable(create_market_analyst(mock_llm))

    def test_news_analyst_creation(self, mock_llm):
        """News analyst factory returns a callable node."""
        from tradingagents.agents.analysts.news_analyst import create_news_analyst

        assert callable(create_news_analyst(mock_llm))

    def test_fundamentals_analyst_creation(self, mock_llm):
        """Fundamentals analyst factory returns a callable node."""
        from tradingagents.agents.analysts.fundamentals_analyst import (
            create_fundamentals_analyst,
        )

        assert callable(create_fundamentals_analyst(mock_llm))

    def test_bull_researcher_creation(self, mock_llm):
        """Bull researcher factory returns a callable node."""
        from tradingagents.agents.researchers.bull_researcher import create_bull_researcher
        from tradingagents.agents.utils.memory import FinancialSituationMemory

        researcher = create_bull_researcher(
            mock_llm, FinancialSituationMemory("bull_memory")
        )
        assert callable(researcher)

    def test_bear_researcher_creation(self, mock_llm):
        """Bear researcher factory returns a callable node."""
        from tradingagents.agents.researchers.bear_researcher import create_bear_researcher
        from tradingagents.agents.utils.memory import FinancialSituationMemory

        researcher = create_bear_researcher(
            mock_llm, FinancialSituationMemory("bear_memory")
        )
        assert callable(researcher)

    def test_trader_creation(self, mock_llm):
        """Trader factory returns a callable node."""
        from tradingagents.agents.trader.trader import create_trader
        from tradingagents.agents.utils.memory import FinancialSituationMemory

        trader = create_trader(mock_llm, FinancialSituationMemory("trader_memory"))
        assert callable(trader)
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestReflection:
    """Integration tests for reflection system."""

    def test_reflector_creation(self):
        """Test that Reflector can be created."""
        from tradingagents.graph.reflection import Reflector

        # MagicMock is already imported at module level; the previous local
        # `from unittest.mock import MagicMock` was redundant and removed.
        mock_llm = MagicMock()
        reflector = Reflector(mock_llm)
        assert reflector.quick_thinking_llm is not None
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Tests for agent modules."""
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Tests for dataflow modules."""
|
||||
|
|
@ -0,0 +1,198 @@
|
|||
"""Unit tests for data interface routing."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.dataflows.interface import (
|
||||
TOOLS_CATEGORIES,
|
||||
VENDOR_LIST,
|
||||
VENDOR_METHODS,
|
||||
get_category_for_method,
|
||||
get_vendor,
|
||||
route_to_vendor,
|
||||
)
|
||||
|
||||
|
||||
class TestToolsCategories:
    """Tests for TOOLS_CATEGORIES structure."""

    @pytest.mark.unit
    def test_core_stock_apis_category_exists(self):
        """core_stock_apis must exist and expose get_stock_data."""
        category = TOOLS_CATEGORIES.get("core_stock_apis")
        assert category is not None
        assert "get_stock_data" in category["tools"]

    @pytest.mark.unit
    def test_technical_indicators_category_exists(self):
        """technical_indicators must exist and expose get_indicators."""
        category = TOOLS_CATEGORIES.get("technical_indicators")
        assert category is not None
        assert "get_indicators" in category["tools"]

    @pytest.mark.unit
    def test_fundamental_data_category_exists(self):
        """fundamental_data must expose the four statement tools."""
        assert "fundamental_data" in TOOLS_CATEGORIES
        tools = TOOLS_CATEGORIES["fundamental_data"]["tools"]
        for name in (
            "get_fundamentals",
            "get_balance_sheet",
            "get_cashflow",
            "get_income_statement",
        ):
            assert name in tools

    @pytest.mark.unit
    def test_news_data_category_exists(self):
        """news_data must expose the three news tools."""
        assert "news_data" in TOOLS_CATEGORIES
        tools = TOOLS_CATEGORIES["news_data"]["tools"]
        for name in ("get_news", "get_global_news", "get_insider_transactions"):
            assert name in tools
|
||||
|
||||
|
||||
class TestVendorList:
    """Tests for VENDOR_LIST."""

    @pytest.mark.unit
    def test_yfinance_in_vendor_list(self):
        """yfinance must be a registered vendor."""
        vendor = "yfinance"
        assert vendor in VENDOR_LIST

    @pytest.mark.unit
    def test_alpha_vantage_in_vendor_list(self):
        """alpha_vantage must be a registered vendor."""
        vendor = "alpha_vantage"
        assert vendor in VENDOR_LIST

    @pytest.mark.unit
    def test_vendor_list_length(self):
        """Exactly two vendors are currently registered."""
        expected_count = 2
        assert len(VENDOR_LIST) == expected_count
|
||||
|
||||
|
||||
class TestGetCategoryForMethod:
    """Tests for get_category_for_method function."""

    @pytest.mark.unit
    def test_get_category_for_stock_data(self):
        """get_stock_data maps to core_stock_apis."""
        assert get_category_for_method("get_stock_data") == "core_stock_apis"

    @pytest.mark.unit
    def test_get_category_for_indicators(self):
        """get_indicators maps to technical_indicators."""
        assert get_category_for_method("get_indicators") == "technical_indicators"

    @pytest.mark.unit
    def test_get_category_for_fundamentals(self):
        """get_fundamentals maps to fundamental_data."""
        assert get_category_for_method("get_fundamentals") == "fundamental_data"

    @pytest.mark.unit
    def test_get_category_for_news(self):
        """get_news maps to news_data."""
        assert get_category_for_method("get_news") == "news_data"

    @pytest.mark.unit
    def test_get_category_for_invalid_method_raises(self):
        """An unknown method name raises ValueError."""
        with pytest.raises(ValueError, match="not found"):
            get_category_for_method("invalid_method")
|
||||
|
||||
|
||||
class TestGetVendor:
    """Tests for get_vendor function."""

    @staticmethod
    def _config(data_vendors, tool_vendors):
        """Build the minimal config mapping that get_vendor reads."""
        return {"data_vendors": data_vendors, "tool_vendors": tool_vendors}

    @pytest.mark.unit
    @patch("tradingagents.dataflows.interface.get_config")
    def test_get_vendor_default(self, mock_get_config):
        """Category-level vendor configuration is honoured."""
        mock_get_config.return_value = self._config(
            {"core_stock_apis": "yfinance"}, {}
        )
        assert get_vendor("core_stock_apis") == "yfinance"

    @pytest.mark.unit
    @patch("tradingagents.dataflows.interface.get_config")
    def test_get_vendor_tool_level_override(self, mock_get_config):
        """A tool-level vendor overrides the category default."""
        mock_get_config.return_value = self._config(
            {"core_stock_apis": "yfinance"},
            {"get_stock_data": "alpha_vantage"},
        )
        assert get_vendor("core_stock_apis", "get_stock_data") == "alpha_vantage"

    @pytest.mark.unit
    @patch("tradingagents.dataflows.interface.get_config")
    def test_get_vendor_missing_category_uses_default(self, mock_get_config):
        """An unconfigured category falls back to 'default'."""
        mock_get_config.return_value = self._config({}, {})
        assert get_vendor("unknown_category") == "default"
|
||||
|
||||
|
||||
class TestVendorMethods:
    """Tests for VENDOR_METHODS structure."""

    @pytest.mark.unit
    def test_get_stock_data_has_both_vendors(self):
        """get_stock_data is wired up for yfinance and alpha_vantage."""
        vendors = VENDOR_METHODS["get_stock_data"]
        for vendor in ("yfinance", "alpha_vantage"):
            assert vendor in vendors

    @pytest.mark.unit
    def test_all_methods_have_vendors(self):
        """Every routed method must name at least one vendor."""
        for method, vendors in VENDOR_METHODS.items():
            assert vendors, f"Method {method} has no vendors"
|
||||
|
||||
|
||||
class TestRouteToVendor:
    """Tests for route_to_vendor function."""

    @pytest.mark.unit
    @patch("tradingagents.dataflows.interface.get_config")
    def test_route_to_vendor_invalid_method_raises(self, mock_get_config):
        """Test that routing invalid method raises ValueError."""
        mock_get_config.return_value = {"data_vendors": {}, "tool_vendors": {}}

        with pytest.raises(ValueError, match="not found"):
            route_to_vendor("invalid_method", "AAPL")

    @pytest.mark.unit
    @pytest.mark.skip(reason="Needs mocks for the concrete vendor functions")
    def test_route_to_vendor_fallback_on_rate_limit(self):
        """Test that vendor fallback works on rate limit errors.

        The previous version contained only mock setup and no assertions,
        so it always passed vacuously. Marked as an explicit skip until the
        vendor functions can be mocked and real assertions written.
        """

    @pytest.mark.unit
    @pytest.mark.skip(reason="Needs mocks for the concrete vendor functions")
    def test_route_to_vendor_no_available_vendor_raises(self):
        """Test that no available vendor raises RuntimeError.

        Skipped explicitly for the same reason as the fallback test: the
        original body asserted nothing and reported false confidence.
        """
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Tests for graph modules."""
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
"""Unit tests for conditional logic."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
|
||||
|
||||
class TestConditionalLogic:
    """Shared fixtures for the ConditionalLogic test classes below."""

    @pytest.fixture
    def logic(self):
        """ConditionalLogic with a single round of each debate."""
        return ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)

    @pytest.fixture
    def logic_extended(self):
        """ConditionalLogic with extended debate and risk rounds."""
        return ConditionalLogic(max_debate_rounds=3, max_risk_discuss_rounds=2)

    @pytest.fixture
    def state_with_tool_call(self):
        """State whose last message carries one pending tool call."""
        message = MagicMock()
        message.tool_calls = [{"name": "get_stock_data"}]
        return {"messages": [message]}

    @pytest.fixture
    def state_without_tool_call(self):
        """State whose last message carries no tool calls."""
        message = MagicMock()
        message.tool_calls = []
        return {"messages": [message]}
|
||||
|
||||
|
||||
class TestShouldContinueMarket(TestConditionalLogic):
    """Tests for should_continue_market method."""

    @pytest.mark.unit
    def test_returns_tools_market_with_tool_call(self, logic, state_with_tool_call):
        """Pending tool calls route to the market tools node."""
        assert logic.should_continue_market(state_with_tool_call) == "tools_market"

    @pytest.mark.unit
    def test_returns_msg_clear_without_tool_call(self, logic, state_without_tool_call):
        """With no tool calls, the market messages are cleared next."""
        outcome = logic.should_continue_market(state_without_tool_call)
        assert outcome == "Msg Clear Market"
|
||||
|
||||
|
||||
class TestShouldContinueSocial(TestConditionalLogic):
    """Tests for should_continue_social method."""

    @pytest.mark.unit
    def test_returns_tools_social_with_tool_call(self, logic, state_with_tool_call):
        """Pending tool calls route to the social tools node."""
        assert logic.should_continue_social(state_with_tool_call) == "tools_social"

    @pytest.mark.unit
    def test_returns_msg_clear_without_tool_call(self, logic, state_without_tool_call):
        """With no tool calls, the social messages are cleared next."""
        outcome = logic.should_continue_social(state_without_tool_call)
        assert outcome == "Msg Clear Social"
|
||||
|
||||
|
||||
class TestShouldContinueNews(TestConditionalLogic):
    """Routing behaviour of should_continue_news."""

    @pytest.mark.unit
    def test_returns_tools_news_with_tool_call(self, logic, state_with_tool_call):
        """A pending tool call routes execution to the tools_news node."""
        assert logic.should_continue_news(state_with_tool_call) == "tools_news"

    @pytest.mark.unit
    def test_returns_msg_clear_without_tool_call(self, logic, state_without_tool_call):
        """With no tool calls the analyst hands off to Msg Clear News."""
        assert logic.should_continue_news(state_without_tool_call) == "Msg Clear News"
|
||||
|
||||
|
||||
class TestShouldContinueFundamentals(TestConditionalLogic):
    """Routing behaviour of should_continue_fundamentals."""

    @pytest.mark.unit
    def test_returns_tools_fundamentals_with_tool_call(self, logic, state_with_tool_call):
        """A pending tool call routes execution to the tools_fundamentals node."""
        assert logic.should_continue_fundamentals(state_with_tool_call) == "tools_fundamentals"

    @pytest.mark.unit
    def test_returns_msg_clear_without_tool_call(self, logic, state_without_tool_call):
        """With no tool calls the analyst hands off to Msg Clear Fundamentals."""
        assert logic.should_continue_fundamentals(state_without_tool_call) == "Msg Clear Fundamentals"
|
||||
|
||||
|
||||
class TestShouldContinueDebate(TestConditionalLogic):
    """Routing behaviour of should_continue_debate."""

    @staticmethod
    def _debate_state(count, response):
        # Minimal investment-debate sub-state consumed by the router.
        return {
            "investment_debate_state": {
                "count": count,
                "current_response": response,
            }
        }

    @pytest.mark.unit
    def test_returns_research_manager_at_max_rounds(self, logic):
        """Once count reaches 2 * max_debate_rounds the debate is closed."""
        # max_debate_rounds=1 -> limit is 2; count=4 is well past it.
        state = self._debate_state(4, "Bull Analyst: Buy signal")
        assert logic.should_continue_debate(state) == "Research Manager"

    @pytest.mark.unit
    def test_returns_bear_when_bull_speaks(self, logic):
        """After the Bull speaks, the Bear gets the floor."""
        state = self._debate_state(1, "Bull Analyst: Strong buy opportunity")
        assert logic.should_continue_debate(state) == "Bear Researcher"

    @pytest.mark.unit
    def test_returns_bull_when_not_bull(self, logic):
        """After any non-Bull speaker, the Bull gets the floor."""
        state = self._debate_state(1, "Bear Analyst: High risk warning")
        assert logic.should_continue_debate(state) == "Bull Researcher"

    @pytest.mark.unit
    def test_extended_debate_rounds(self, logic_extended):
        """With max_debate_rounds=3 the debate continues while count < 6."""
        state = self._debate_state(5, "Bull Analyst: Buy")
        assert logic_extended.should_continue_debate(state) == "Bear Researcher"

    @pytest.mark.unit
    def test_extended_debate_ends_at_max(self, logic_extended):
        """With max_debate_rounds=3 the debate ends once count hits 6."""
        state = self._debate_state(6, "Bull Analyst: Buy")
        assert logic_extended.should_continue_debate(state) == "Research Manager"
|
||||
|
||||
|
||||
class TestShouldContinueRiskAnalysis(TestConditionalLogic):
    """Routing behaviour of should_continue_risk_analysis."""

    @staticmethod
    def _risk_state(count, latest_speaker):
        # Minimal risk-debate sub-state consumed by the router.
        return {
            "risk_debate_state": {
                "count": count,
                "latest_speaker": latest_speaker,
            }
        }

    @pytest.mark.unit
    def test_returns_risk_judge_at_max_rounds(self, logic):
        """Once count reaches 3 * max_risk_discuss_rounds the debate closes."""
        # max_risk_discuss_rounds=1 -> limit is 3; count=6 is well past it.
        state = self._risk_state(6, "Aggressive Analyst")
        assert logic.should_continue_risk_analysis(state) == "Risk Judge"

    @pytest.mark.unit
    def test_returns_conservative_after_aggressive(self, logic):
        """Aggressive hands the floor to Conservative."""
        state = self._risk_state(1, "Aggressive Analyst: Go all in!")
        assert logic.should_continue_risk_analysis(state) == "Conservative Analyst"

    @pytest.mark.unit
    def test_returns_neutral_after_conservative(self, logic):
        """Conservative hands the floor to Neutral."""
        state = self._risk_state(1, "Conservative Analyst: Stay cautious")
        assert logic.should_continue_risk_analysis(state) == "Neutral Analyst"

    @pytest.mark.unit
    def test_returns_aggressive_after_neutral(self, logic):
        """Neutral hands the floor back to Aggressive."""
        state = self._risk_state(1, "Neutral Analyst: Balanced view")
        assert logic.should_continue_risk_analysis(state) == "Aggressive Analyst"

    @pytest.mark.unit
    def test_extended_risk_rounds(self, logic_extended):
        """With max_risk_discuss_rounds=2 the debate continues while count < 6."""
        state = self._risk_state(5, "Aggressive Analyst")
        assert logic_extended.should_continue_risk_analysis(state) == "Conservative Analyst"

    @pytest.mark.unit
    def test_extended_risk_ends_at_max(self, logic_extended):
        """With max_risk_discuss_rounds=2 the debate ends once count hits 6."""
        state = self._risk_state(6, "Aggressive Analyst")
        assert logic_extended.should_continue_risk_analysis(state) == "Risk Judge"
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Tests for LLM client modules."""
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
"""Unit tests for Anthropic client."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.anthropic_client import AnthropicClient
|
||||
|
||||
|
||||
class TestAnthropicClient:
    """Construction behaviour of AnthropicClient."""

    @pytest.mark.unit
    def test_init(self):
        """Model name is stored and base_url defaults to None."""
        client = AnthropicClient("claude-3-opus")
        assert client.model == "claude-3-opus"
        assert client.base_url is None

    @pytest.mark.unit
    def test_init_with_base_url(self):
        """A custom base URL is accepted (it may be ignored downstream)."""
        client = AnthropicClient("claude-3-opus", base_url="https://custom.api.com")
        assert client.base_url == "https://custom.api.com"

    @pytest.mark.unit
    def test_init_with_kwargs(self):
        """Extra keyword arguments are captured on client.kwargs."""
        client = AnthropicClient("claude-3-opus", timeout=30, max_tokens=4096)
        assert client.kwargs.get("max_tokens") == 4096
        assert client.kwargs.get("timeout") == 30
|
||||
|
||||
|
||||
class TestAnthropicClientGetLLM:
    """Behaviour of AnthropicClient.get_llm."""

    @pytest.mark.unit
    @patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key"})
    def test_get_llm_returns_chat_anthropic(self):
        """get_llm builds an LLM bound to the configured model name."""
        llm = AnthropicClient("claude-3-opus").get_llm()
        assert llm.model == "claude-3-opus"

    @pytest.mark.unit
    @patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key"})
    def test_get_llm_with_timeout(self):
        """A timeout kwarg is retained on the client for the underlying LLM."""
        client = AnthropicClient("claude-3-opus", timeout=60)
        assert "timeout" in client.kwargs

    @pytest.mark.unit
    @patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key"})
    def test_get_llm_with_max_tokens(self):
        """A max_tokens kwarg survives get_llm construction."""
        client = AnthropicClient("claude-3-opus", max_tokens=2048)
        client.get_llm()
        assert "max_tokens" in client.kwargs
|
||||
|
||||
|
||||
class TestAnthropicClientValidateModel:
    """Behaviour of AnthropicClient.validate_model."""

    @pytest.mark.unit
    def test_validate_model_returns_bool(self):
        """validate_model always yields a plain boolean."""
        outcome = AnthropicClient("claude-3-opus").validate_model()
        assert isinstance(outcome, bool)
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
"""Unit tests for LLM client factory."""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.anthropic_client import AnthropicClient
|
||||
from tradingagents.llm_clients.factory import create_llm_client
|
||||
from tradingagents.llm_clients.google_client import GoogleClient
|
||||
from tradingagents.llm_clients.openai_client import OpenAIClient
|
||||
|
||||
|
||||
class TestCreateLLMClient:
    """Behaviour of the create_llm_client factory."""

    @pytest.mark.unit
    def test_create_openai_client(self):
        """'openai' yields an OpenAIClient bound to the requested model."""
        client = create_llm_client("openai", "gpt-4")
        assert isinstance(client, OpenAIClient)
        assert client.provider == "openai"
        assert client.model == "gpt-4"

    @pytest.mark.unit
    def test_create_openai_client_case_insensitive(self):
        """Provider lookup ignores case."""
        client = create_llm_client("OpenAI", "gpt-4o")
        assert isinstance(client, OpenAIClient)
        assert client.provider == "openai"

    @pytest.mark.unit
    def test_create_anthropic_client(self):
        """'anthropic' yields an AnthropicClient."""
        client = create_llm_client("anthropic", "claude-3-opus")
        assert isinstance(client, AnthropicClient)
        assert client.model == "claude-3-opus"

    @pytest.mark.unit
    def test_create_google_client(self):
        """'google' yields a GoogleClient."""
        client = create_llm_client("google", "gemini-pro")
        assert isinstance(client, GoogleClient)
        assert client.model == "gemini-pro"

    @pytest.mark.unit
    def test_create_xai_client(self):
        """'xai' reuses the OpenAI-compatible client."""
        client = create_llm_client("xai", "grok-beta")
        assert isinstance(client, OpenAIClient)
        assert client.provider == "xai"

    @pytest.mark.unit
    def test_create_ollama_client(self):
        """'ollama' reuses the OpenAI-compatible client."""
        client = create_llm_client("ollama", "llama2")
        assert isinstance(client, OpenAIClient)
        assert client.provider == "ollama"

    @pytest.mark.unit
    def test_create_openrouter_client(self):
        """'openrouter' reuses the OpenAI-compatible client."""
        client = create_llm_client("openrouter", "gpt-4")
        assert isinstance(client, OpenAIClient)
        assert client.provider == "openrouter"

    @pytest.mark.unit
    def test_unsupported_provider_raises(self):
        """An unknown provider name raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported LLM provider"):
            create_llm_client("unknown_provider", "model-name")

    @pytest.mark.unit
    def test_create_client_with_base_url(self):
        """A custom base URL is stored on the client."""
        client = create_llm_client("openai", "gpt-4", base_url="https://custom.api.com/v1")
        assert client.base_url == "https://custom.api.com/v1"

    @pytest.mark.unit
    def test_create_client_with_kwargs(self):
        """Extra keyword arguments are forwarded into client.kwargs."""
        client = create_llm_client("openai", "gpt-4", timeout=30, max_retries=5)
        assert client.kwargs.get("max_retries") == 5
        assert client.kwargs.get("timeout") == 30
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
"""Unit tests for Google client."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.google_client import GoogleClient
|
||||
|
||||
|
||||
class TestGoogleClient:
    """Construction behaviour of GoogleClient."""

    @pytest.mark.unit
    def test_init(self):
        """Model name is stored and base_url defaults to None."""
        client = GoogleClient("gemini-pro")
        assert client.model == "gemini-pro"
        assert client.base_url is None

    @pytest.mark.unit
    def test_init_with_kwargs(self):
        """Extra keyword arguments are captured on client.kwargs."""
        client = GoogleClient("gemini-pro", timeout=30)
        assert client.kwargs.get("timeout") == 30
|
||||
|
||||
|
||||
class TestGoogleClientGetLLM:
    """Behaviour of GoogleClient.get_llm, including thinking-level mapping."""

    @pytest.mark.unit
    @patch.dict("os.environ", {"GOOGLE_API_KEY": "test-key"})
    def test_get_llm_returns_chat_google(self):
        """get_llm builds an LLM bound to the configured model name."""
        llm = GoogleClient("gemini-pro").get_llm()
        assert llm.model == "gemini-pro"

    @pytest.mark.unit
    @patch.dict("os.environ", {"GOOGLE_API_KEY": "test-key"})
    def test_get_llm_with_timeout(self):
        """A timeout kwarg is forwarded onto the built LLM."""
        llm = GoogleClient("gemini-pro", timeout=60).get_llm()
        assert llm.timeout == 60

    @pytest.mark.unit
    @patch.dict("os.environ", {"GOOGLE_API_KEY": "test-key"})
    def test_get_llm_gemini_3_pro_thinking_level(self):
        """Gemini 3 Pro models keep their thinking_level kwarg."""
        client = GoogleClient("gemini-3-pro", thinking_level="high")
        client.get_llm()
        assert "thinking_level" in client.kwargs

    @pytest.mark.unit
    @patch.dict("os.environ", {"GOOGLE_API_KEY": "test-key"})
    def test_get_llm_gemini_3_pro_minimal_to_low(self):
        """'minimal' is remapped to 'low' for Pro models, which lack it."""
        llm = GoogleClient("gemini-3-pro", thinking_level="minimal").get_llm()
        assert llm.thinking_level == "low"

    @pytest.mark.unit
    @patch.dict("os.environ", {"GOOGLE_API_KEY": "test-key"})
    def test_get_llm_gemini_3_flash_thinking_level(self):
        """Flash models pass minimal/low/medium/high straight through."""
        llm = GoogleClient("gemini-3-flash", thinking_level="medium").get_llm()
        assert llm.thinking_level == "medium"
|
||||
|
||||
|
||||
class TestNormalizedChatGoogleGenerativeAI:
    """Placeholder tests for the normalized Google Generative AI wrapper."""

    @pytest.mark.unit
    def test_normalize_string_content(self):
        """Plain string content should pass through _normalize_content unchanged.

        NOTE(review): placeholder — exercising this needs a mocked model
        response, so no assertion is made yet.
        """

    @pytest.mark.unit
    def test_normalize_list_content(self):
        """List-of-dict content (Gemini 3 responses) should collapse to a string.

        NOTE(review): placeholder — exercising this needs integration with
        the wrapper class, so no assertion is made yet.
        """
|
||||
|
||||
|
||||
class TestGoogleClientValidateModel:
    """Behaviour of GoogleClient.validate_model."""

    @pytest.mark.unit
    def test_validate_model_returns_bool(self):
        """validate_model always yields a plain boolean."""
        outcome = GoogleClient("gemini-pro").validate_model()
        assert isinstance(outcome, bool)
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
"""Unit tests for OpenAI client."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.openai_client import OpenAIClient, UnifiedChatOpenAI
|
||||
|
||||
|
||||
class TestOpenAIClient:
    """Construction behaviour of OpenAIClient."""

    @pytest.mark.unit
    def test_init_with_provider(self):
        """Model and provider are stored; base_url defaults to None."""
        client = OpenAIClient("gpt-4", provider="openai")
        assert client.model == "gpt-4"
        assert client.provider == "openai"
        assert client.base_url is None

    @pytest.mark.unit
    def test_init_with_base_url(self):
        """A custom base URL is stored on the client."""
        client = OpenAIClient("gpt-4", base_url="https://custom.api.com/v1", provider="openai")
        assert client.base_url == "https://custom.api.com/v1"

    @pytest.mark.unit
    def test_provider_lowercase(self):
        """Provider names are normalized to lowercase."""
        client = OpenAIClient("gpt-4", provider="OpenAI")
        assert client.provider == "openai"
|
||||
|
||||
|
||||
class TestUnifiedChatOpenAI:
    """Reasoning-model detection in UnifiedChatOpenAI."""

    @pytest.mark.unit
    def test_is_reasoning_model_o1(self):
        """Every o1-family name is flagged, regardless of case."""
        for name in ("o1-preview", "o1-mini", "O1-PRO"):
            assert UnifiedChatOpenAI._is_reasoning_model(name)

    @pytest.mark.unit
    def test_is_reasoning_model_o3(self):
        """o3-family names are flagged, regardless of case."""
        for name in ("o3-mini", "O3-MINI"):
            assert UnifiedChatOpenAI._is_reasoning_model(name)

    @pytest.mark.unit
    def test_is_reasoning_model_gpt5(self):
        """GPT-5-family names are flagged, regardless of case."""
        for name in ("gpt-5", "gpt-5.2", "GPT-5-MINI"):
            assert UnifiedChatOpenAI._is_reasoning_model(name)

    @pytest.mark.unit
    def test_is_not_reasoning_model(self):
        """Standard chat models are not flagged as reasoning models."""
        for name in ("gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"):
            assert not UnifiedChatOpenAI._is_reasoning_model(name)
|
||||
|
||||
|
||||
class TestOpenAIClientGetLLM:
    """Behaviour of OpenAIClient.get_llm across providers."""

    @pytest.mark.unit
    @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    def test_get_llm_openai(self):
        """The built LLM is bound to the requested model."""
        llm = OpenAIClient("gpt-4", provider="openai").get_llm()
        assert llm.model == "gpt-4"

    @pytest.mark.unit
    @patch.dict("os.environ", {"XAI_API_KEY": "test-xai-key"})
    def test_get_llm_xai_uses_correct_url(self):
        """xAI's base URL is injected inside get_llm, not stored in kwargs."""
        client = OpenAIClient("grok-beta", provider="xai")
        assert client.kwargs.get("base_url") is None

    @pytest.mark.unit
    @patch.dict("os.environ", {"OPENROUTER_API_KEY": "test-or-key"})
    def test_get_llm_openrouter_uses_correct_url(self):
        """OpenRouter's base URL is injected inside get_llm, not stored in kwargs."""
        client = OpenAIClient("gpt-4", provider="openrouter")
        assert client.kwargs.get("base_url") is None

    @pytest.mark.unit
    def test_get_llm_ollama_uses_correct_url(self):
        """Ollama clients record the provider (base URL handled in get_llm)."""
        client = OpenAIClient("llama2", provider="ollama")
        assert client.provider == "ollama"

    @pytest.mark.unit
    def test_get_llm_with_timeout(self):
        """A timeout kwarg is retained on the client for the underlying LLM."""
        client = OpenAIClient("gpt-4", provider="openai", timeout=60)
        assert client.kwargs.get("timeout") == 60

    @pytest.mark.unit
    def test_get_llm_with_max_retries(self):
        """max_retries is forwarded onto the built LLM."""
        llm = OpenAIClient("gpt-4", provider="openai", max_retries=3).get_llm()
        assert llm.max_retries == 3
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Tests for memory modules."""
|
||||
|
|
@ -0,0 +1,197 @@
|
|||
"""Unit tests for FinancialSituationMemory."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
|
||||
|
||||
class TestFinancialSituationMemory:
    """Behavioural tests for the BM25-backed FinancialSituationMemory."""

    @pytest.mark.unit
    def test_init(self):
        """A fresh memory is empty and has no BM25 index yet."""
        mem = FinancialSituationMemory("test_memory")
        assert mem.name == "test_memory"
        assert len(mem.documents) == 0
        assert len(mem.recommendations) == 0
        assert mem.bm25 is None

    @pytest.mark.unit
    def test_init_with_config(self):
        """A config dict is accepted for API compatibility (unused by BM25)."""
        mem = FinancialSituationMemory("test_memory", config={"some": "config"})
        assert mem.name == "test_memory"

    @pytest.mark.unit
    def test_add_situations_single(self):
        """Adding one (situation, recommendation) pair indexes it."""
        mem = FinancialSituationMemory("test_memory")
        mem.add_situations([("High volatility", "Reduce exposure")])

        assert len(mem.documents) == 1
        assert mem.documents[0] == "High volatility"
        assert len(mem.recommendations) == 1
        assert mem.recommendations[0] == "Reduce exposure"
        assert mem.bm25 is not None

    @pytest.mark.unit
    def test_add_situations_multiple(self):
        """A batch of situations is indexed in one call."""
        mem = FinancialSituationMemory("test_memory")
        mem.add_situations([
            ("High volatility in tech sector", "Reduce exposure"),
            ("Strong earnings report", "Consider buying"),
            ("Rising interest rates", "Review duration"),
        ])

        assert len(mem.documents) == 3
        assert len(mem.recommendations) == 3
        assert mem.bm25 is not None

    @pytest.mark.unit
    def test_add_situations_incremental(self):
        """Later add_situations calls append rather than replace."""
        mem = FinancialSituationMemory("test_memory")
        mem.add_situations([("First situation", "First recommendation")])
        mem.add_situations([("Second situation", "Second recommendation")])

        assert len(mem.documents) == 2
        assert mem.recommendations[0] == "First recommendation"
        assert mem.recommendations[1] == "Second recommendation"

    @pytest.mark.unit
    def test_get_memories_returns_matches(self):
        """Queries retrieve the most relevant stored situation."""
        mem = FinancialSituationMemory("test_memory")
        mem.add_situations([
            ("High inflation affecting tech stocks", "Consider defensive positions"),
            ("Strong dollar impacting exports", "Review international exposure"),
        ])

        hits = mem.get_memories("inflation concerns in technology sector", n_matches=1)

        assert len(hits) == 1
        top = hits[0]
        assert "similarity_score" in top
        assert "matched_situation" in top
        assert "recommendation" in top
        assert top["matched_situation"] == "High inflation affecting tech stocks"

    @pytest.mark.unit
    def test_get_memories_multiple_matches(self):
        """n_matches > 1 returns that many ranked results."""
        mem = FinancialSituationMemory("test_memory")
        mem.add_situations([
            ("High inflation affecting tech stocks", "Consider defensive positions"),
            ("Inflation concerns rising globally", "Review commodity exposure"),
            ("Strong dollar impacting exports", "Review international exposure"),
        ])

        hits = mem.get_memories("inflation worries", n_matches=2)

        assert len(hits) == 2
        # At least one inflation-related situation should rank in the top two.
        matched = {h["matched_situation"] for h in hits}
        assert matched & {
            "High inflation affecting tech stocks",
            "Inflation concerns rising globally",
        }

    @pytest.mark.unit
    def test_get_memories_empty_returns_empty(self):
        """Querying an empty memory yields an empty list."""
        mem = FinancialSituationMemory("test_memory")
        assert mem.get_memories("any query", n_matches=1) == []

    @pytest.mark.unit
    def test_get_memories_normalized_score(self):
        """Every hit carries a float similarity_score.

        Note: BM25 scores can be negative for documents with low term
        frequency; normalization divides by max_score without shifting
        negative values.
        """
        mem = FinancialSituationMemory("test_memory")
        mem.add_situations([
            ("High volatility tech sector", "Reduce exposure"),
            ("Low volatility bonds", "Stable income"),
        ])

        hits = mem.get_memories("volatility in tech", n_matches=2)

        assert len(hits) == 2
        for hit in hits:
            assert "similarity_score" in hit
            assert isinstance(hit["similarity_score"], float)

    @pytest.mark.unit
    def test_clear(self):
        """clear() drops documents, recommendations, and the BM25 index."""
        mem = FinancialSituationMemory("test_memory")
        mem.add_situations([("test", "test recommendation")])
        assert len(mem.documents) == 1
        assert mem.bm25 is not None

        mem.clear()

        assert len(mem.documents) == 0
        assert len(mem.recommendations) == 0
        assert mem.bm25 is None

    @pytest.mark.unit
    def test_get_memories_after_clear(self):
        """The memory is fully usable again after clear()."""
        mem = FinancialSituationMemory("test_memory")
        mem.add_situations([("First", "Rec1")])
        mem.clear()
        mem.add_situations([("Second", "Rec2")])

        hits = mem.get_memories("Second", n_matches=1)

        assert len(hits) == 1
        assert hits[0]["matched_situation"] == "Second"

    @pytest.mark.unit
    def test_tokenize_lowercase(self):
        """Tokenization lowercases its input."""
        mem = FinancialSituationMemory("test")
        assert all(tok.islower() for tok in mem._tokenize("HELLO World"))

    @pytest.mark.unit
    def test_tokenize_splits_on_punctuation(self):
        """Punctuation acts as a token separator and is discarded."""
        mem = FinancialSituationMemory("test")
        assert mem._tokenize("hello, world! test.") == ["hello", "world", "test"]

    @pytest.mark.unit
    def test_tokenize_handles_numbers(self):
        """Digit runs survive tokenization (split at the decimal point)."""
        mem = FinancialSituationMemory("test")
        toks = mem._tokenize("price 123.45 dollars")
        assert "123" in toks
        assert "45" in toks

    @pytest.mark.unit
    def test_unicode_handling(self):
        """Non-ASCII (CJK) situations are stored and retrievable."""
        mem = FinancialSituationMemory("test")
        mem.add_situations([
            ("欧洲市场波动加剧", "考虑减少欧洲敞口"),
            ("日本央行政策调整", "关注汇率变化"),
        ])

        hits = mem.get_memories("欧洲市场", n_matches=1)

        assert len(hits) == 1
        assert "欧洲" in hits[0]["matched_situation"]
|
||||
|
|
@ -78,7 +78,7 @@ class FinancialSituationMemory:
|
|||
|
||||
# Build results
|
||||
results = []
|
||||
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
|
||||
max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0
|
||||
|
||||
for idx in top_indices:
|
||||
# Normalize score to 0-1 range for consistency
|
||||
|
|
|
|||
|
|
@ -0,0 +1,186 @@
|
|||
"""Configuration validation for the TradingAgents framework.
|
||||
|
||||
This module provides validation functions to ensure configuration
|
||||
settings are correct before runtime.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
# LLM providers the framework knows how to construct clients for.
VALID_PROVIDERS = ["openai", "anthropic", "google", "xai", "ollama", "openrouter"]

# Market-data vendors supported by the data layer.
VALID_DATA_VENDORS = ["yfinance", "alpha_vantage"]

# Environment variable holding the API key for each hosted provider.
# NOTE: "ollama" is intentionally absent — local models need no key.
PROVIDER_API_KEYS = {
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "google": "GOOGLE_API_KEY",
    "xai": "XAI_API_KEY",
    "openrouter": "OPENROUTER_API_KEY",
}
|
||||
|
||||
|
||||
def validate_config(config: dict[str, Any]) -> list[str]:
    """Validate configuration dictionary.

    Checks the LLM provider, required model names, data/tool vendor
    choices, and the numeric round/recursion limits.

    Args:
        config: Configuration dictionary to validate.

    Returns:
        List of validation error messages (empty if valid).

    Example:
        >>> errors = validate_config(config)
        >>> if errors:
        ...     print("Configuration errors:", errors)
    """
    errors: list[str] = []

    # Validate LLM provider. `or ""` guards against llm_provider being
    # explicitly None, which would crash the original
    # config.get("llm_provider", "").lower() with AttributeError.
    provider = (config.get("llm_provider") or "").lower()
    if provider not in VALID_PROVIDERS:
        errors.append(
            f"Invalid llm_provider: '{provider}'. Must be one of {VALID_PROVIDERS}"
        )

    # Both model slots are mandatory.
    if not config.get("deep_think_llm"):
        errors.append("deep_think_llm is required")
    if not config.get("quick_think_llm"):
        errors.append("quick_think_llm is required")

    # Data and tool vendor maps share the same validity rule.
    _check_vendors(config.get("data_vendors", {}), "data vendor", errors)
    _check_vendors(config.get("tool_vendors", {}), "tool vendor", errors)

    # Numeric limits must be positive integers. bool is excluded explicitly
    # because it is an int subclass, so True would otherwise slip through.
    for key, default in (
        ("max_debate_rounds", 1),
        ("max_risk_discuss_rounds", 1),
        ("max_recur_limit", 100),
    ):
        value = config.get(key, default)
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            errors.append(f"{key} must be a positive integer")

    return errors


def _check_vendors(vendors: dict[str, Any], label: str, errors: list[str]) -> None:
    """Append an error for every vendor entry not in VALID_DATA_VENDORS."""
    for name, vendor in vendors.items():
        if vendor not in VALID_DATA_VENDORS:
            errors.append(
                f"Invalid {label} for {name}: '{vendor}'. "
                f"Must be one of {VALID_DATA_VENDORS}"
            )
|
||||
|
||||
|
||||
def validate_api_keys(config: dict[str, Any]) -> list[str]:
    """Check that the API keys this configuration needs are present.

    Resolves the environment variable associated with the configured
    ``llm_provider`` and, additionally, requires ``ALPHA_VANTAGE_API_KEY``
    whenever any data or tool vendor is set to ``alpha_vantage``.

    Args:
        config: Configuration dictionary containing llm_provider.

    Returns:
        List of validation error messages (empty if valid).

    Example:
        >>> errors = validate_api_keys(config)
        >>> if errors:
        ...     print("Missing API keys:", errors)
    """
    problems: list[str] = []

    provider = config.get("llm_provider", "").lower()
    required_key = PROVIDER_API_KEYS.get(provider)
    if required_key and not os.environ.get(required_key):
        problems.append(f"{required_key} not set for {provider} provider")

    # Alpha Vantage is selected per-vendor rather than per-provider, so
    # scan both vendor maps for it.
    vendor_values = list(config.get("data_vendors", {}).values())
    vendor_values += list(config.get("tool_vendors", {}).values())

    if "alpha_vantage" in vendor_values and not os.environ.get("ALPHA_VANTAGE_API_KEY"):
        problems.append("ALPHA_VANTAGE_API_KEY not set but alpha_vantage vendor is configured")

    return problems
|
||||
|
||||
|
||||
def validate_config_full(config: dict[str, Any]) -> list[str]:
    """Perform full configuration validation including API keys.

    Combines structural validation (:func:`validate_config`) with
    environment-variable checks (:func:`validate_api_keys`).

    Args:
        config: Configuration dictionary to validate.

    Returns:
        List of all validation error messages (empty if valid).

    Example:
        >>> errors = validate_config_full(config)
        >>> if errors:
        ...     for error in errors:
        ...         print(f"Error: {error}")
        ...     sys.exit(1)
    """
    return [*validate_config(config), *validate_api_keys(config)]
|
||||
|
||||
|
||||
def get_validation_report(config: dict[str, Any]) -> str:
    """Build a human-readable validation report.

    Summarizes the key configuration values, then lists either the
    validation errors found or a success marker.

    Args:
        config: Configuration dictionary to validate.

    Returns:
        Formatted string with validation results.

    Example:
        >>> report = get_validation_report(config)
        >>> print(report)
    """
    issues = validate_config_full(config)

    report = ["Configuration Validation Report", "=" * 40]

    # Configuration summary (first line carries the blank separator).
    report.append(f"\nLLM Provider: {config.get('llm_provider', 'not set')}")
    for label, key in (
        ("Deep Think LLM", "deep_think_llm"),
        ("Quick Think LLM", "quick_think_llm"),
        ("Max Debate Rounds", "max_debate_rounds"),
        ("Max Risk Discuss Rounds", "max_risk_discuss_rounds"),
    ):
        report.append(f"{label}: {config.get(key, 'not set')}")

    vendors = config.get("data_vendors", {})
    if vendors:
        report.append("\nData Vendors:")
        report.extend(f"  {category}: {vendor}" for category, vendor in vendors.items())

    if issues:
        report.append("\nValidation Errors:")
        report.extend(f"  - {issue}" for issue in issues)
    else:
        report.append("\n✓ Configuration is valid")

    return "\n".join(report)
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
import os
|
||||
|
||||
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||
"data_cache_dir": os.path.join(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"dataflows/data_cache",
|
||||
),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
|
||||
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
|
||||
# LLM settings
|
||||
"llm_provider": "openai",
|
||||
"deep_think_llm": "gpt-5.4",
|
||||
|
|
|
|||
|
|
@ -66,10 +66,8 @@ class TradingAgentsGraph:
|
|||
set_config(self.config)
|
||||
|
||||
# Create necessary directories
|
||||
os.makedirs(
|
||||
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
|
||||
exist_ok=True,
|
||||
)
|
||||
os.makedirs(self.config["data_cache_dir"], exist_ok=True)
|
||||
os.makedirs(self.config["results_dir"], exist_ok=True)
|
||||
|
||||
# Initialize LLMs with provider-specific thinking configuration
|
||||
llm_kwargs = self._get_provider_kwargs()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
from .base_client import BaseLLMClient, normalize_content
|
||||
from .validators import validate_model
|
||||
|
||||
_PASSTHROUGH_KWARGS = (
|
||||
"timeout", "max_retries", "api_key", "reasoning_effort",
|
||||
"callbacks", "http_client", "http_async_client",
|
||||
)
|
||||
|
||||
|
||||
class NormalizedAzureChatOpenAI(AzureChatOpenAI):
    """AzureChatOpenAI variant that normalizes message content on invoke."""

    def invoke(self, input, config=None, **kwargs):
        # Delegate to the parent implementation, then flatten the content.
        response = super().invoke(input, config, **kwargs)
        return normalize_content(response)
|
||||
|
||||
|
||||
class AzureOpenAIClient(BaseLLMClient):
    """Client for Azure OpenAI deployments.

    Requires environment variables:
        AZURE_OPENAI_API_KEY: API key
        AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://<resource>.openai.azure.com/)
        AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name
        OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
    """

    # NOTE: no __init__ override — the previous one only delegated to
    # super() with an identical signature, so BaseLLMClient's constructor
    # is inherited unchanged.

    def get_llm(self) -> Any:
        """Return a configured AzureChatOpenAI (normalized) instance.

        AZURE_OPENAI_DEPLOYMENT_NAME takes precedence as the deployment id;
        the model name is used as the fallback when it is unset.
        """
        self.warn_if_unknown_model()

        llm_kwargs = {
            "model": self.model,
            "azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", self.model),
        }

        # Forward only the recognized optional kwargs (timeout, retries,
        # api_key, callbacks, custom HTTP clients, ...).
        for key in _PASSTHROUGH_KWARGS:
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]

        return NormalizedAzureChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        """Azure accepts any deployed model name."""
        return True
|
||||
|
|
@ -4,6 +4,12 @@ from .base_client import BaseLLMClient
|
|||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .azure_client import AzureOpenAIClient
|
||||
|
||||
# Providers that use the OpenAI-compatible chat completions API
|
||||
_OPENAI_COMPATIBLE = (
|
||||
"openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
|
||||
)
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
|
|
@ -15,16 +21,10 @@ def create_llm_client(
|
|||
"""Create an LLM client for the specified provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
||||
provider: LLM provider name
|
||||
model: Model name/identifier
|
||||
base_url: Optional base URL for API endpoint
|
||||
**kwargs: Additional provider-specific arguments
|
||||
- http_client: Custom httpx.Client for SSL proxy or certificate customization
|
||||
- http_async_client: Custom httpx.AsyncClient for async operations
|
||||
- timeout: Request timeout in seconds
|
||||
- max_retries: Maximum retry attempts
|
||||
- api_key: API key for the provider
|
||||
- callbacks: LangChain callbacks
|
||||
|
||||
Returns:
|
||||
Configured BaseLLMClient instance
|
||||
|
|
@ -34,16 +34,16 @@ def create_llm_client(
|
|||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("openai", "ollama", "openrouter"):
|
||||
if provider_lower in _OPENAI_COMPATIBLE:
|
||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
|
||||
if provider_lower == "xai":
|
||||
return OpenAIClient(model, base_url, provider="xai", **kwargs)
|
||||
|
||||
if provider_lower == "anthropic":
|
||||
return AnthropicClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "google":
|
||||
return GoogleClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "azure":
|
||||
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
|
|
|||
|
|
@ -63,8 +63,43 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
|||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
],
|
||||
},
|
||||
# OpenRouter models are fetched dynamically at CLI runtime.
|
||||
# No static entries needed; any model ID is accepted by the validator.
|
||||
"deepseek": {
|
||||
"quick": [
|
||||
("DeepSeek V3.2", "deepseek-chat"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
|
||||
("DeepSeek V3.2", "deepseek-chat"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
},
|
||||
"qwen": {
|
||||
"quick": [
|
||||
("Qwen 3.5 Flash", "qwen3.5-flash"),
|
||||
("Qwen Plus", "qwen-plus"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("Qwen 3.6 Plus", "qwen3.6-plus"),
|
||||
("Qwen 3.5 Plus", "qwen3.5-plus"),
|
||||
("Qwen 3 Max", "qwen3-max"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
},
|
||||
"glm": {
|
||||
"quick": [
|
||||
("GLM-4.7", "glm-4.7"),
|
||||
("GLM-5", "glm-5"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("GLM-5.1", "glm-5.1"),
|
||||
("GLM-5", "glm-5"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
},
|
||||
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
||||
"ollama": {
|
||||
"quick": [
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ _PASSTHROUGH_KWARGS = (
|
|||
# Provider base URLs and API key env vars
# Maps provider name -> (default base URL, API-key env var or None).
_PROVIDER_CONFIG = {
    "xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
    "deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
    "qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
    "glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
    "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
    # Local server — no API key required, hence None.
    "ollama": ("http://localhost:11434/v1", None),
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,132 @@
|
|||
"""Logging configuration for the TradingAgents framework.
|
||||
|
||||
This module provides structured logging setup for consistent log formatting
|
||||
across all framework components.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def setup_logging(
|
||||
level: str = "INFO",
|
||||
log_file: Path | None = None,
|
||||
name: str = "tradingagents",
|
||||
) -> logging.Logger:
|
||||
"""Configure logging for the trading agents framework.
|
||||
|
||||
Args:
|
||||
level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL).
|
||||
log_file: Optional file path for log output.
|
||||
name: Logger name, defaults to 'tradingagents'.
|
||||
|
||||
Returns:
|
||||
Configured logger instance.
|
||||
|
||||
Example:
|
||||
>>> logger = setup_logging("DEBUG", Path("logs/trading.log"))
|
||||
>>> logger.info("Starting trading analysis")
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(getattr(logging, level.upper()))
|
||||
|
||||
# Remove existing handlers to avoid duplicates
|
||||
logger.handlers = []
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# File handler (optional)
|
||||
if log_file:
|
||||
log_file = Path(log_file)
|
||||
log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
    """Return the logger registered under *name*.

    Thin wrapper around :func:`logging.getLogger`, kept so framework code
    has a single import point for logger access.

    Args:
        name: Logger name (typically __name__ of the module).

    Returns:
        Logger instance.

    Example:
        >>> logger = get_logger(__name__)
        >>> logger.info("Module loaded")
    """
    return logging.getLogger(name)
|
||||
|
||||
|
||||
class TradingAgentsLogger:
    """Context manager that logs the start, outcome, and duration of an operation.

    Provides structured logging with timing information for operations.

    Example:
        >>> with TradingAgentsLogger("market_analysis", "AAPL") as log:
        ...     # Perform analysis
        ...     log.info("Fetching market data")
    """

    def __init__(self, operation: str, symbol: str | None = None):
        """Initialize the logger context.

        Args:
            operation: Name of the operation being logged.
            symbol: Optional trading symbol being analyzed.
        """
        self.operation = operation
        self.symbol = symbol
        self.logger = get_logger(f"tradingagents.{operation}")
        self.start_time: datetime | None = None

    def __enter__(self) -> "TradingAgentsLogger":
        """Record the start time and announce the operation."""
        self.start_time = datetime.now(timezone.utc)
        suffix = f" for {self.symbol}" if self.symbol else ""
        self.logger.info(f"Starting {self.operation}{suffix}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Log completion (or failure) together with the elapsed seconds."""
        if self.start_time is None:
            return
        elapsed = (datetime.now(timezone.utc) - self.start_time).total_seconds()
        if exc_type is None:
            self.logger.info(f"{self.operation} completed in {elapsed:.2f}s")
        else:
            self.logger.error(
                f"{self.operation} failed after {elapsed:.2f}s: {exc_val}"
            )

    def info(self, message: str):
        """Log an info message."""
        self.logger.info(message)

    def warning(self, message: str):
        """Log a warning message."""
        self.logger.warning(message)

    def error(self, message: str):
        """Log an error message."""
        self.logger.error(message)

    def debug(self, message: str):
        """Log a debug message."""
        self.logger.debug(message)
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""Shared type definitions for the TradingAgents framework.
|
||||
|
||||
This module provides TypedDict classes and type aliases used across
|
||||
the framework for consistent type checking and documentation.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class MarketData(TypedDict):
    """Market data structure for stock price information.

    Attributes:
        ticker: Stock ticker symbol.
        date: Trading date string.
        open: Opening price.
        high: Highest price of the day.
        low: Lowest price of the day.
        close: Closing price.
        volume: Trading volume.
    """

    # All keys are required (TypedDict default total=True).
    ticker: str
    date: str
    # "open" shadows the builtin only as a dict key name — safe in a TypedDict.
    open: float
    high: float
    low: float
    close: float
    volume: int
|
||||
|
||||
|
||||
class AgentResponse(TypedDict):
    """Response from an agent node execution.

    Attributes:
        messages: List of messages to add to the conversation.
        report: Generated report content (if applicable).
        sender: Name of the sending agent.
    """

    # All keys are required (total=True).
    # NOTE(review): the docstring says report applies "if applicable", but
    # the key is required here — confirm all agents always populate it.
    messages: list[Any]
    report: str
    sender: str
|
||||
|
||||
|
||||
class ConfigDict(TypedDict, total=False):
    """Configuration dictionary structure.

    Attributes:
        project_dir: Project root directory.
        results_dir: Directory for storing results.
        data_cache_dir: Directory for caching data.
        llm_provider: LLM provider name.
        deep_think_llm: Model for complex reasoning.
        quick_think_llm: Model for fast responses.
        backend_url: API endpoint URL.
        google_thinking_level: Thinking level for Google models.
        openai_reasoning_effort: Reasoning effort for OpenAI models.
        max_debate_rounds: Maximum debate rounds between researchers.
        max_risk_discuss_rounds: Maximum risk discussion rounds.
        max_recur_limit: Maximum recursion limit for graph.
        data_vendors: Category-level vendor configuration.
        tool_vendors: Tool-level vendor configuration.
    """

    # total=False: every key below is optional; consumers must use
    # .get() or check membership before reading.
    project_dir: str
    results_dir: str
    data_cache_dir: str
    llm_provider: str
    deep_think_llm: str
    quick_think_llm: str
    backend_url: str
    google_thinking_level: str | None
    openai_reasoning_effort: str | None
    max_debate_rounds: int
    max_risk_discuss_rounds: int
    max_recur_limit: int
    data_vendors: dict[str, str]
    tool_vendors: dict[str, str]
|
||||
|
||||
|
||||
class MemoryMatch(TypedDict):
    """Result from memory similarity matching.

    Attributes:
        matched_situation: The stored situation that matched.
        recommendation: The associated recommendation.
        similarity_score: BM25 similarity score (0-1 normalized).
    """

    # All keys are required (total=True).
    matched_situation: str
    recommendation: str
    similarity_score: float
|
||||
Loading…
Reference in New Issue