Add pre-commit hooks and ruff code quality configuration
- Add .pre-commit-config.yaml with trailing whitespace, ruff linter/formatter - Configure ruff in pyproject.toml with selected rules (E, F, W, I, UP, B, C4, SIM) - Add F401 to unfixable to preserve re-exported imports in __init__.py files - Fix BacktestMetrics import in backtesting/engine.py - Update todos.md with enhanced trade discovery and database implementation tasks 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
85992fc05b
commit
c39f9aab36
|
|
@ -1,4 +1,4 @@
|
|||
ALPHA_VANTAGE_API_KEY=alpha_vantage_api_key_placeholder
|
||||
OPENAI_API_KEY=openai_api_key_placeholder
|
||||
BRAVE_API_KEY=brave_api_key_placeholder
|
||||
TAVILY_API_KEY=tavily_api_key_placeholder
|
||||
TAVILY_API_KEY=tavily_api_key_placeholder
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.6.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=1000']
|
||||
- id: check-merge-conflict
|
||||
- id: detect-private-key
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.8.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
|
|
@ -76,9 +76,11 @@ source .venv/bin/activate
|
|||
The framework requires an OpenAI API key for powering the agents and at least one news data provider API key.
|
||||
|
||||
**Required:**
|
||||
|
||||
- `OPENAI_API_KEY` - Powers the LLM agents
|
||||
|
||||
**News Data Providers (at least one required):**
|
||||
|
||||
- `TAVILY_API_KEY` - Tavily search API (preferred for news discovery)
|
||||
- `BRAVE_API_KEY` - Brave Search API (fallback option)
|
||||
- `ALPHA_VANTAGE_API_KEY` - Alpha Vantage API (for fundamentals and news)
|
||||
|
|
|
|||
154
cli/analysis.py
154
cli/analysis.py
|
|
@ -1,35 +1,33 @@
|
|||
import datetime
|
||||
from pathlib import Path
|
||||
from functools import wraps
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from rich.panel import Panel
|
||||
from rich.live import Live
|
||||
from rich.align import Align
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
from cli.state import message_buffer
|
||||
from cli.models import AnalystType, AgentStatus
|
||||
from cli.display import (
|
||||
create_layout,
|
||||
update_display,
|
||||
display_complete_report,
|
||||
update_research_team_status,
|
||||
extract_content_string,
|
||||
create_question_box,
|
||||
console,
|
||||
create_layout,
|
||||
create_question_box,
|
||||
display_complete_report,
|
||||
extract_content_string,
|
||||
update_display,
|
||||
update_research_team_status,
|
||||
)
|
||||
from cli.models import AgentStatus, AnalystType
|
||||
from cli.state import message_buffer
|
||||
from cli.utils import (
|
||||
loading,
|
||||
select_analysts,
|
||||
select_research_depth,
|
||||
select_shallow_thinking_agent,
|
||||
select_deep_thinking_agent,
|
||||
select_llm_provider,
|
||||
loading,
|
||||
select_research_depth,
|
||||
select_shallow_thinking_agent,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
|
||||
def get_ticker() -> str:
|
||||
|
|
@ -54,14 +52,16 @@ def get_analysis_date() -> str:
|
|||
|
||||
|
||||
def get_user_selections() -> dict:
|
||||
with open("./cli/static/welcome.txt", "r") as f:
|
||||
with open("./cli/static/welcome.txt") as f:
|
||||
welcome_ascii = f.read()
|
||||
|
||||
welcome_content = f"{welcome_ascii}\n"
|
||||
welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
|
||||
welcome_content += "[bold]Workflow Steps:[/bold]\n"
|
||||
welcome_content += "I. Analyst Team -> II. Research Team -> III. Trader -> IV. Risk Management -> V. Portfolio Management\n\n"
|
||||
welcome_content += "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]"
|
||||
welcome_content += (
|
||||
"[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]"
|
||||
)
|
||||
|
||||
welcome_box = Panel(
|
||||
welcome_content,
|
||||
|
|
@ -108,9 +108,7 @@ def get_user_selections() -> dict:
|
|||
selected_research_depth = select_research_depth()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 5: OpenAI backend", "Select which service to talk to"
|
||||
)
|
||||
create_question_box("Step 5: OpenAI backend", "Select which service to talk to")
|
||||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
|
||||
|
|
@ -134,15 +132,21 @@ def get_user_selections() -> dict:
|
|||
}
|
||||
|
||||
|
||||
def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) -> None:
|
||||
def process_chunk_for_display(
|
||||
chunk: dict, selected_analysts: list[AnalystType]
|
||||
) -> None:
|
||||
if "market_report" in chunk and chunk["market_report"]:
|
||||
message_buffer.update_report_section("market_report", chunk["market_report"])
|
||||
message_buffer.update_agent_status("Market Analyst", AgentStatus.COMPLETED)
|
||||
if AnalystType.SOCIAL in selected_analysts:
|
||||
message_buffer.update_agent_status("Social Analyst", AgentStatus.IN_PROGRESS)
|
||||
message_buffer.update_agent_status(
|
||||
"Social Analyst", AgentStatus.IN_PROGRESS
|
||||
)
|
||||
|
||||
if "sentiment_report" in chunk and chunk["sentiment_report"]:
|
||||
message_buffer.update_report_section("sentiment_report", chunk["sentiment_report"])
|
||||
message_buffer.update_report_section(
|
||||
"sentiment_report", chunk["sentiment_report"]
|
||||
)
|
||||
message_buffer.update_agent_status("Social Analyst", AgentStatus.COMPLETED)
|
||||
if AnalystType.NEWS in selected_analysts:
|
||||
message_buffer.update_agent_status("News Analyst", AgentStatus.IN_PROGRESS)
|
||||
|
|
@ -151,11 +155,17 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
message_buffer.update_report_section("news_report", chunk["news_report"])
|
||||
message_buffer.update_agent_status("News Analyst", AgentStatus.COMPLETED)
|
||||
if AnalystType.FUNDAMENTALS in selected_analysts:
|
||||
message_buffer.update_agent_status("Fundamentals Analyst", AgentStatus.IN_PROGRESS)
|
||||
message_buffer.update_agent_status(
|
||||
"Fundamentals Analyst", AgentStatus.IN_PROGRESS
|
||||
)
|
||||
|
||||
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
|
||||
message_buffer.update_report_section("fundamentals_report", chunk["fundamentals_report"])
|
||||
message_buffer.update_agent_status("Fundamentals Analyst", AgentStatus.COMPLETED)
|
||||
message_buffer.update_report_section(
|
||||
"fundamentals_report", chunk["fundamentals_report"]
|
||||
)
|
||||
message_buffer.update_agent_status(
|
||||
"Fundamentals Analyst", AgentStatus.COMPLETED
|
||||
)
|
||||
update_research_team_status(AgentStatus.IN_PROGRESS)
|
||||
|
||||
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
|
||||
|
|
@ -197,13 +207,18 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
|
||||
|
||||
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
|
||||
message_buffer.update_report_section("trader_investment_plan", chunk["trader_investment_plan"])
|
||||
message_buffer.update_report_section(
|
||||
"trader_investment_plan", chunk["trader_investment_plan"]
|
||||
)
|
||||
message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
|
||||
|
||||
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
|
||||
risk_state = chunk["risk_debate_state"]
|
||||
|
||||
if "current_risky_response" in risk_state and risk_state["current_risky_response"]:
|
||||
if (
|
||||
"current_risky_response" in risk_state
|
||||
and risk_state["current_risky_response"]
|
||||
):
|
||||
message_buffer.update_agent_status("Risky Analyst", AgentStatus.IN_PROGRESS)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
|
|
@ -214,7 +229,10 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}",
|
||||
)
|
||||
|
||||
if "current_safe_response" in risk_state and risk_state["current_safe_response"]:
|
||||
if (
|
||||
"current_safe_response" in risk_state
|
||||
and risk_state["current_safe_response"]
|
||||
):
|
||||
message_buffer.update_agent_status("Safe Analyst", AgentStatus.IN_PROGRESS)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
|
|
@ -225,8 +243,13 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}",
|
||||
)
|
||||
|
||||
if "current_neutral_response" in risk_state and risk_state["current_neutral_response"]:
|
||||
message_buffer.update_agent_status("Neutral Analyst", AgentStatus.IN_PROGRESS)
|
||||
if (
|
||||
"current_neutral_response" in risk_state
|
||||
and risk_state["current_neutral_response"]
|
||||
):
|
||||
message_buffer.update_agent_status(
|
||||
"Neutral Analyst", AgentStatus.IN_PROGRESS
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Neutral Analyst: {risk_state['current_neutral_response']}",
|
||||
|
|
@ -237,7 +260,9 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
)
|
||||
|
||||
if "judge_decision" in risk_state and risk_state["judge_decision"]:
|
||||
message_buffer.update_agent_status("Portfolio Manager", AgentStatus.IN_PROGRESS)
|
||||
message_buffer.update_agent_status(
|
||||
"Portfolio Manager", AgentStatus.IN_PROGRESS
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Portfolio Manager: {risk_state['judge_decision']}",
|
||||
|
|
@ -249,12 +274,15 @@ def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType])
|
|||
message_buffer.update_agent_status("Risky Analyst", AgentStatus.COMPLETED)
|
||||
message_buffer.update_agent_status("Safe Analyst", AgentStatus.COMPLETED)
|
||||
message_buffer.update_agent_status("Neutral Analyst", AgentStatus.COMPLETED)
|
||||
message_buffer.update_agent_status("Portfolio Manager", AgentStatus.COMPLETED)
|
||||
message_buffer.update_agent_status(
|
||||
"Portfolio Manager", AgentStatus.COMPLETED
|
||||
)
|
||||
|
||||
|
||||
def setup_logging_decorators(report_dir, log_file) -> tuple:
|
||||
def save_message_decorator(obj, func_name):
|
||||
func = getattr(obj, func_name)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
func(*args, **kwargs)
|
||||
|
|
@ -262,10 +290,12 @@ def setup_logging_decorators(report_dir, log_file) -> tuple:
|
|||
content = content.replace("\n", " ")
|
||||
with open(log_file, "a") as f:
|
||||
f.write(f"{timestamp} [{message_type}] {content}\n")
|
||||
|
||||
return wrapper
|
||||
|
||||
def save_tool_call_decorator(obj, func_name):
|
||||
func = getattr(obj, func_name)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
func(*args, **kwargs)
|
||||
|
|
@ -273,22 +303,32 @@ def setup_logging_decorators(report_dir, log_file) -> tuple:
|
|||
args_str = ", ".join(f"{k}={v}" for k, v in tool_args.items())
|
||||
with open(log_file, "a") as f:
|
||||
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
||||
|
||||
return wrapper
|
||||
|
||||
def save_report_section_decorator(obj, func_name):
|
||||
func = getattr(obj, func_name)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(section_name, content):
|
||||
func(section_name, content)
|
||||
if section_name in obj.report_sections and obj.report_sections[section_name] is not None:
|
||||
if (
|
||||
section_name in obj.report_sections
|
||||
and obj.report_sections[section_name] is not None
|
||||
):
|
||||
section_content = obj.report_sections[section_name]
|
||||
if section_content:
|
||||
file_name = f"{section_name}.md"
|
||||
with open(report_dir / file_name, "w") as f:
|
||||
f.write(section_content)
|
||||
|
||||
return wrapper
|
||||
|
||||
return save_message_decorator, save_tool_call_decorator, save_report_section_decorator
|
||||
return (
|
||||
save_message_decorator,
|
||||
save_tool_call_decorator,
|
||||
save_report_section_decorator,
|
||||
)
|
||||
|
||||
|
||||
def run_analysis_for_ticker(ticker: str, config: dict) -> None:
|
||||
|
|
@ -296,8 +336,7 @@ def run_analysis_for_ticker(ticker: str, config: dict) -> None:
|
|||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Analysts Team",
|
||||
"Select your LLM analyst agents for the analysis"
|
||||
"Analysts Team", "Select your LLM analyst agents for the analysis"
|
||||
)
|
||||
)
|
||||
selected_analysts = select_analysts()
|
||||
|
|
@ -306,18 +345,12 @@ def run_analysis_for_ticker(ticker: str, config: dict) -> None:
|
|||
)
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Research Depth",
|
||||
"Select your research depth level"
|
||||
)
|
||||
create_question_box("Research Depth", "Select your research depth level")
|
||||
)
|
||||
selected_research_depth = select_research_depth()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Deep-Thinking Model",
|
||||
"Select the model for deep analysis"
|
||||
)
|
||||
create_question_box("Deep-Thinking Model", "Select the model for deep analysis")
|
||||
)
|
||||
llm_provider = config.get("llm_provider", "openai")
|
||||
selected_deep_thinker = select_deep_thinking_agent(llm_provider.capitalize())
|
||||
|
|
@ -344,11 +377,13 @@ def run_analysis() -> None:
|
|||
selections["ticker"],
|
||||
selections["analysis_date"],
|
||||
selections["analysts"],
|
||||
config
|
||||
config,
|
||||
)
|
||||
|
||||
|
||||
def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts: List[AnalystType], config: dict) -> None:
|
||||
def _run_analysis_with_config(
|
||||
ticker: str, analysis_date: str, selected_analysts: list[AnalystType], config: dict
|
||||
) -> None:
|
||||
with loading("Initializing trading agents...", show_elapsed=True):
|
||||
graph = TradingAgentsGraph(
|
||||
[analyst.value for analyst in selected_analysts], config=config, debug=True
|
||||
|
|
@ -361,12 +396,17 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
|
|||
log_file = results_dir / "message_tool.log"
|
||||
log_file.touch(exist_ok=True)
|
||||
|
||||
save_message_decorator, save_tool_call_decorator, save_report_section_decorator = \
|
||||
save_message_decorator, save_tool_call_decorator, save_report_section_decorator = (
|
||||
setup_logging_decorators(report_dir, log_file)
|
||||
)
|
||||
|
||||
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
|
||||
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
|
||||
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
|
||||
message_buffer.add_tool_call = save_tool_call_decorator(
|
||||
message_buffer, "add_tool_call"
|
||||
)
|
||||
message_buffer.update_report_section = save_report_section_decorator(
|
||||
message_buffer, "update_report_section"
|
||||
)
|
||||
|
||||
layout = create_layout()
|
||||
|
||||
|
|
@ -416,7 +456,9 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
|
|||
if hasattr(last_message, "tool_calls"):
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||
message_buffer.add_tool_call(
|
||||
tool_call["name"], tool_call["args"]
|
||||
)
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
|
||||
|
|
@ -431,7 +473,9 @@ def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts
|
|||
for agent in message_buffer.agent_status:
|
||||
message_buffer.update_agent_status(agent, AgentStatus.COMPLETED)
|
||||
|
||||
message_buffer.add_message("Analysis", f"Completed analysis for {analysis_date}")
|
||||
message_buffer.add_message(
|
||||
"Analysis", f"Completed analysis for {analysis_date}"
|
||||
)
|
||||
|
||||
for section in message_buffer.report_sections.keys():
|
||||
if section in final_state:
|
||||
|
|
|
|||
|
|
@ -1,19 +1,18 @@
|
|||
import datetime
|
||||
from decimal import Decimal
|
||||
from datetime import date as date_type
|
||||
from decimal import Decimal
|
||||
|
||||
import typer
|
||||
from rich import box
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich import box
|
||||
|
||||
from tradingagents.backtesting import SimpleBacktestEngine
|
||||
from tradingagents.models.backtest import BacktestConfig, BacktestStatus
|
||||
from tradingagents.models.portfolio import PortfolioConfig
|
||||
|
||||
from cli.display import create_question_box
|
||||
from cli.utils import loading
|
||||
from tradingagents.backtesting import SimpleBacktestEngine
|
||||
from tradingagents.models.backtest import BacktestConfig, BacktestStatus
|
||||
from tradingagents.models.portfolio import PortfolioConfig
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -47,7 +46,7 @@ def rsi_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool:
|
|||
return False
|
||||
changes = []
|
||||
for i in range(1, min(15, len(ohlcv.bars))):
|
||||
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i-1].close))
|
||||
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i - 1].close))
|
||||
gains = [c for c in changes if c > 0]
|
||||
losses = [-c for c in changes if c < 0]
|
||||
avg_gain = sum(gains) / 14 if gains else 0.001
|
||||
|
|
@ -64,7 +63,7 @@ def rsi_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool:
|
|||
return False
|
||||
changes = []
|
||||
for i in range(1, min(15, len(ohlcv.bars))):
|
||||
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i-1].close))
|
||||
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i - 1].close))
|
||||
gains = [c for c in changes if c > 0]
|
||||
losses = [-c for c in changes if c < 0]
|
||||
avg_gain = sum(gains) / 14 if gains else 0.001
|
||||
|
|
@ -97,17 +96,31 @@ def run_backtest(
|
|||
strategy: str = "sma",
|
||||
) -> None:
|
||||
if not ticker:
|
||||
console.print(create_question_box("Ticker Symbol", "Enter the ticker symbol to backtest", "AAPL"))
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Ticker Symbol", "Enter the ticker symbol to backtest", "AAPL"
|
||||
)
|
||||
)
|
||||
ticker = typer.prompt("", default="AAPL")
|
||||
|
||||
if not start_date:
|
||||
default_start = (datetime.datetime.now() - datetime.timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
console.print(create_question_box("Start Date", "Enter backtest start date (YYYY-MM-DD)", default_start))
|
||||
default_start = (
|
||||
datetime.datetime.now() - datetime.timedelta(days=365)
|
||||
).strftime("%Y-%m-%d")
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Start Date", "Enter backtest start date (YYYY-MM-DD)", default_start
|
||||
)
|
||||
)
|
||||
start_date = typer.prompt("", default=default_start)
|
||||
|
||||
if not end_date:
|
||||
default_end = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
console.print(create_question_box("End Date", "Enter backtest end date (YYYY-MM-DD)", default_end))
|
||||
console.print(
|
||||
create_question_box(
|
||||
"End Date", "Enter backtest end date (YYYY-MM-DD)", default_end
|
||||
)
|
||||
)
|
||||
end_date = typer.prompt("", default=default_end)
|
||||
|
||||
try:
|
||||
|
|
@ -122,19 +135,23 @@ def run_backtest(
|
|||
return
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
f"[bold]Backtest Configuration[/bold]\n\n"
|
||||
f"Ticker: [cyan]{ticker.upper()}[/cyan]\n"
|
||||
f"Period: [cyan]{start_date}[/cyan] to [cyan]{end_date}[/cyan]\n"
|
||||
f"Initial Cash: [cyan]${initial_cash:,.2f}[/cyan]\n"
|
||||
f"Strategy: [cyan]{strategy}[/cyan]",
|
||||
title="Configuration",
|
||||
border_style="blue",
|
||||
))
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold]Backtest Configuration[/bold]\n\n"
|
||||
f"Ticker: [cyan]{ticker.upper()}[/cyan]\n"
|
||||
f"Period: [cyan]{start_date}[/cyan] to [cyan]{end_date}[/cyan]\n"
|
||||
f"Initial Cash: [cyan]${initial_cash:,.2f}[/cyan]\n"
|
||||
f"Strategy: [cyan]{strategy}[/cyan]",
|
||||
title="Configuration",
|
||||
border_style="blue",
|
||||
)
|
||||
)
|
||||
console.print()
|
||||
|
||||
if strategy not in STRATEGIES:
|
||||
console.print(f"[red]Unknown strategy: {strategy}. Use: sma, rsi, or hold[/red]")
|
||||
console.print(
|
||||
f"[red]Unknown strategy: {strategy}. Use: sma, rsi, or hold[/red]"
|
||||
)
|
||||
return
|
||||
|
||||
buy_fn, sell_fn = STRATEGIES[strategy]
|
||||
|
|
@ -170,12 +187,26 @@ def run_backtest(
|
|||
performance_table.add_column("Value", style="green")
|
||||
|
||||
performance_table.add_row("Total Return", f"${float(metrics.total_return):,.2f}")
|
||||
performance_table.add_row("Total Return %", f"{float(metrics.total_return_percent):.2f}%")
|
||||
performance_table.add_row("Annualized Return", f"{float(metrics.annualized_return):.2f}%")
|
||||
performance_table.add_row("Sharpe Ratio", f"{float(metrics.sharpe_ratio):.2f}" if metrics.sharpe_ratio else "N/A")
|
||||
performance_table.add_row("Sortino Ratio", f"{float(metrics.sortino_ratio):.2f}" if metrics.sortino_ratio else "N/A")
|
||||
performance_table.add_row("Max Drawdown", f"{float(metrics.max_drawdown_percent):.2f}%")
|
||||
performance_table.add_row("Volatility (Ann.)", f"{float(metrics.annualized_volatility):.2f}%")
|
||||
performance_table.add_row(
|
||||
"Total Return %", f"{float(metrics.total_return_percent):.2f}%"
|
||||
)
|
||||
performance_table.add_row(
|
||||
"Annualized Return", f"{float(metrics.annualized_return):.2f}%"
|
||||
)
|
||||
performance_table.add_row(
|
||||
"Sharpe Ratio",
|
||||
f"{float(metrics.sharpe_ratio):.2f}" if metrics.sharpe_ratio else "N/A",
|
||||
)
|
||||
performance_table.add_row(
|
||||
"Sortino Ratio",
|
||||
f"{float(metrics.sortino_ratio):.2f}" if metrics.sortino_ratio else "N/A",
|
||||
)
|
||||
performance_table.add_row(
|
||||
"Max Drawdown", f"{float(metrics.max_drawdown_percent):.2f}%"
|
||||
)
|
||||
performance_table.add_row(
|
||||
"Volatility (Ann.)", f"{float(metrics.annualized_volatility):.2f}%"
|
||||
)
|
||||
|
||||
console.print(performance_table)
|
||||
console.print()
|
||||
|
|
@ -187,10 +218,20 @@ def run_backtest(
|
|||
trading_table.add_row("Total Trades", str(trade_log.total_trades))
|
||||
trading_table.add_row("Winning Trades", str(trade_log.winning_trades))
|
||||
trading_table.add_row("Losing Trades", str(trade_log.losing_trades))
|
||||
trading_table.add_row("Win Rate", f"{float(trade_log.win_rate):.1f}%" if trade_log.win_rate else "N/A")
|
||||
trading_table.add_row("Profit Factor", f"{float(trade_log.profit_factor):.2f}" if trade_log.profit_factor else "N/A")
|
||||
trading_table.add_row("Avg Win", f"${float(trade_log.avg_win):,.2f}" if trade_log.avg_win else "N/A")
|
||||
trading_table.add_row("Avg Loss", f"${float(trade_log.avg_loss):,.2f}" if trade_log.avg_loss else "N/A")
|
||||
trading_table.add_row(
|
||||
"Win Rate", f"{float(trade_log.win_rate):.1f}%" if trade_log.win_rate else "N/A"
|
||||
)
|
||||
trading_table.add_row(
|
||||
"Profit Factor",
|
||||
f"{float(trade_log.profit_factor):.2f}" if trade_log.profit_factor else "N/A",
|
||||
)
|
||||
trading_table.add_row(
|
||||
"Avg Win", f"${float(trade_log.avg_win):,.2f}" if trade_log.avg_win else "N/A"
|
||||
)
|
||||
trading_table.add_row(
|
||||
"Avg Loss",
|
||||
f"${float(trade_log.avg_loss):,.2f}" if trade_log.avg_loss else "N/A",
|
||||
)
|
||||
|
||||
console.print(trading_table)
|
||||
console.print()
|
||||
|
|
@ -207,4 +248,4 @@ def run_backtest(
|
|||
console.print(summary_table)
|
||||
console.print()
|
||||
|
||||
console.print(f"[green]Backtest completed successfully![/green]")
|
||||
console.print("[green]Backtest completed successfully![/green]")
|
||||
|
|
|
|||
101
cli/discovery.py
101
cli/discovery.py
|
|
@ -1,31 +1,29 @@
|
|||
import time
|
||||
from typing import Optional, List
|
||||
|
||||
import questionary
|
||||
from rich import box
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.rule import Rule
|
||||
from rich import box
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.dataflows.config import get_config
|
||||
from tradingagents.agents.discovery.models import (
|
||||
DiscoveryRequest,
|
||||
DiscoveryStatus,
|
||||
TrendingStock,
|
||||
Sector,
|
||||
EventCategory,
|
||||
)
|
||||
from tradingagents.agents.discovery.persistence import save_discovery_result
|
||||
from rich.table import Table
|
||||
|
||||
from cli.display import create_question_box
|
||||
from cli.utils import (
|
||||
MultiStageLoader,
|
||||
loading,
|
||||
select_llm_provider,
|
||||
select_shallow_thinking_agent,
|
||||
loading,
|
||||
MultiStageLoader,
|
||||
)
|
||||
from tradingagents.agents.discovery.models import (
|
||||
DiscoveryRequest,
|
||||
DiscoveryStatus,
|
||||
EventCategory,
|
||||
Sector,
|
||||
TrendingStock,
|
||||
)
|
||||
from tradingagents.agents.discovery.persistence import save_discovery_result
|
||||
from tradingagents.dataflows.config import get_config
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -60,7 +58,8 @@ def select_lookback_period() -> str:
|
|||
choice = questionary.select(
|
||||
"Select lookback period:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value) for display, value in LOOKBACK_OPTIONS
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in LOOKBACK_OPTIONS
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
|
|
@ -79,7 +78,7 @@ def select_lookback_period() -> str:
|
|||
return choice
|
||||
|
||||
|
||||
def select_sector_filter() -> Optional[List[Sector]]:
|
||||
def select_sector_filter() -> list[Sector] | None:
|
||||
use_filter = questionary.confirm(
|
||||
"Filter by sector?",
|
||||
default=False,
|
||||
|
|
@ -97,7 +96,8 @@ def select_sector_filter() -> Optional[List[Sector]]:
|
|||
choices = questionary.checkbox(
|
||||
"Select sectors to include:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value) for display, value in SECTOR_OPTIONS
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in SECTOR_OPTIONS
|
||||
],
|
||||
instruction="\n- Press Space to select/unselect\n- Press 'a' to select all\n- Press Enter when done",
|
||||
style=questionary.Style(
|
||||
|
|
@ -116,7 +116,7 @@ def select_sector_filter() -> Optional[List[Sector]]:
|
|||
return choices
|
||||
|
||||
|
||||
def select_event_filter() -> Optional[List[EventCategory]]:
|
||||
def select_event_filter() -> list[EventCategory] | None:
|
||||
use_filter = questionary.confirm(
|
||||
"Filter by event type?",
|
||||
default=False,
|
||||
|
|
@ -153,7 +153,7 @@ def select_event_filter() -> Optional[List[EventCategory]]:
|
|||
return choices
|
||||
|
||||
|
||||
def create_discovery_results_table(trending_stocks: List[TrendingStock]) -> Table:
|
||||
def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Table:
|
||||
table = Table(
|
||||
show_header=True,
|
||||
header_style="bold magenta",
|
||||
|
|
@ -181,7 +181,9 @@ def create_discovery_results_table(trending_stocks: List[TrendingStock]) -> Tabl
|
|||
table.add_row(
|
||||
rank_display,
|
||||
ticker_display,
|
||||
stock.company_name[:25] if len(stock.company_name) > 25 else stock.company_name,
|
||||
stock.company_name[:25]
|
||||
if len(stock.company_name) > 25
|
||||
else stock.company_name,
|
||||
f"{stock.score:.2f}",
|
||||
str(stock.mention_count),
|
||||
stock.event_type.value.replace("_", " ").title(),
|
||||
|
|
@ -191,8 +193,20 @@ def create_discovery_results_table(trending_stocks: List[TrendingStock]) -> Tabl
|
|||
|
||||
|
||||
def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel:
|
||||
sentiment_label = "positive" if stock.sentiment > 0.3 else "negative" if stock.sentiment < -0.3 else "neutral"
|
||||
sentiment_color = "green" if stock.sentiment > 0.3 else "red" if stock.sentiment < -0.3 else "yellow"
|
||||
sentiment_label = (
|
||||
"positive"
|
||||
if stock.sentiment > 0.3
|
||||
else "negative"
|
||||
if stock.sentiment < -0.3
|
||||
else "neutral"
|
||||
)
|
||||
sentiment_color = (
|
||||
"green"
|
||||
if stock.sentiment > 0.3
|
||||
else "red"
|
||||
if stock.sentiment < -0.3
|
||||
else "yellow"
|
||||
)
|
||||
|
||||
content = f"""[bold]Rank #{rank}: {stock.ticker} - {stock.company_name}[/bold]
|
||||
|
||||
|
|
@ -218,14 +232,16 @@ def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel:
|
|||
)
|
||||
|
||||
|
||||
def select_stock_for_detail(trending_stocks: List[TrendingStock]) -> Optional[TrendingStock]:
|
||||
def select_stock_for_detail(
|
||||
trending_stocks: list[TrendingStock],
|
||||
) -> TrendingStock | None:
|
||||
if not trending_stocks:
|
||||
return None
|
||||
|
||||
choices = [
|
||||
questionary.Choice(
|
||||
f"{i+1}. {stock.ticker} - {stock.company_name} (Score: {stock.score:.2f})",
|
||||
value=stock
|
||||
value=stock,
|
||||
)
|
||||
for i, stock in enumerate(trending_stocks)
|
||||
]
|
||||
|
|
@ -254,7 +270,7 @@ def discover_trending_flow(run_analysis_callback=None) -> None:
|
|||
console.print(
|
||||
create_question_box(
|
||||
"Step 1: Lookback Period",
|
||||
"Select how far back to search for trending stocks"
|
||||
"Select how far back to search for trending stocks",
|
||||
)
|
||||
)
|
||||
lookback_period = select_lookback_period()
|
||||
|
|
@ -263,34 +279,35 @@ def discover_trending_flow(run_analysis_callback=None) -> None:
|
|||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 2: Sector Filter (Optional)",
|
||||
"Optionally filter results by sector"
|
||||
"Step 2: Sector Filter (Optional)", "Optionally filter results by sector"
|
||||
)
|
||||
)
|
||||
sector_filter = select_sector_filter()
|
||||
if sector_filter:
|
||||
console.print(f"[green]Selected sectors:[/green] {', '.join(s.value for s in sector_filter)}")
|
||||
console.print(
|
||||
f"[green]Selected sectors:[/green] {', '.join(s.value for s in sector_filter)}"
|
||||
)
|
||||
else:
|
||||
console.print("[dim]No sector filter applied[/dim]")
|
||||
console.print()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 3: Event Filter (Optional)",
|
||||
"Optionally filter results by event type"
|
||||
"Step 3: Event Filter (Optional)", "Optionally filter results by event type"
|
||||
)
|
||||
)
|
||||
event_filter = select_event_filter()
|
||||
if event_filter:
|
||||
console.print(f"[green]Selected events:[/green] {', '.join(e.value for e in event_filter)}")
|
||||
console.print(
|
||||
f"[green]Selected events:[/green] {', '.join(e.value for e in event_filter)}"
|
||||
)
|
||||
else:
|
||||
console.print("[dim]No event filter applied[/dim]")
|
||||
console.print()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 4: LLM Provider",
|
||||
"Select your LLM provider for entity extraction"
|
||||
"Step 4: LLM Provider", "Select your LLM provider for entity extraction"
|
||||
)
|
||||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
|
|
@ -298,8 +315,7 @@ def discover_trending_flow(run_analysis_callback=None) -> None:
|
|||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 5: Quick-Thinking Model",
|
||||
"Select the model for entity extraction"
|
||||
"Step 5: Quick-Thinking Model", "Select the model for entity extraction"
|
||||
)
|
||||
)
|
||||
selected_model = select_shallow_thinking_agent(selected_llm_provider)
|
||||
|
|
@ -359,13 +375,15 @@ def discover_trending_flow(run_analysis_callback=None) -> None:
|
|||
with loading("Saving discovery results..."):
|
||||
save_path = save_discovery_result(result)
|
||||
console.print(f"\n[dim]Results saved to: {save_path}[/dim]")
|
||||
except (IOError, OSError, ValueError) as e:
|
||||
except (OSError, ValueError) as e:
|
||||
console.print(f"\n[yellow]Warning: Could not save results: {e}[/yellow]")
|
||||
|
||||
console.print()
|
||||
|
||||
if not result.trending_stocks:
|
||||
console.print("[yellow]No trending stocks found matching your criteria.[/yellow]")
|
||||
console.print(
|
||||
"[yellow]No trending stocks found matching your criteria.[/yellow]"
|
||||
)
|
||||
return
|
||||
|
||||
console.print(f"[green]Found {len(result.trending_stocks)} trending stocks[/green]")
|
||||
|
|
@ -400,7 +418,10 @@ def discover_trending_flow(run_analysis_callback=None) -> None:
|
|||
|
||||
if analyze_choice and run_analysis_callback:
|
||||
console.print()
|
||||
with loading(f"Preparing analysis for {selected_stock.ticker}...", spinner_style="loading"):
|
||||
with loading(
|
||||
f"Preparing analysis for {selected_stock.ticker}...",
|
||||
spinner_style="loading",
|
||||
):
|
||||
time.sleep(0.5)
|
||||
run_analysis_callback(selected_stock.ticker, config)
|
||||
break
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from rich import box
|
||||
from rich.columns import Columns
|
||||
from rich.console import Console
|
||||
from cli.models import AgentStatus
|
||||
from rich.layout import Layout
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.spinner import Spinner
|
||||
from rich.markdown import Markdown
|
||||
from rich.layout import Layout
|
||||
from rich.text import Text
|
||||
from rich.table import Table
|
||||
from rich.columns import Columns
|
||||
from rich import box
|
||||
from rich.text import Text
|
||||
|
||||
from cli.models import AgentStatus
|
||||
from cli.state import message_buffer
|
||||
|
||||
console = Console()
|
||||
|
|
@ -32,7 +32,7 @@ def create_layout() -> Layout:
|
|||
return layout
|
||||
|
||||
|
||||
def update_display(layout: Layout, spinner_text: Optional[str] = None) -> None:
|
||||
def update_display(layout: Layout, spinner_text: str | None = None) -> None:
|
||||
layout["header"].update(
|
||||
Panel(
|
||||
"[bold green]Welcome to TradingAgents CLI[/bold green]\n"
|
||||
|
|
@ -135,13 +135,13 @@ def update_display(layout: Layout, spinner_text: Optional[str] = None) -> None:
|
|||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get('type') == 'text':
|
||||
text_parts.append(item.get('text', ''))
|
||||
elif item.get('type') == 'tool_use':
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item.get("type") == "tool_use":
|
||||
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content_str = ' '.join(text_parts)
|
||||
content_str = " ".join(text_parts)
|
||||
elif not isinstance(content_str, str):
|
||||
content_str = str(content)
|
||||
|
||||
|
|
@ -210,7 +210,7 @@ def update_display(layout: Layout, spinner_text: Optional[str] = None) -> None:
|
|||
layout["footer"].update(Panel(stats_table, border_style="grey50"))
|
||||
|
||||
|
||||
def display_complete_report(final_state: Dict[str, Any]) -> None:
|
||||
def display_complete_report(final_state: dict[str, Any]) -> None:
|
||||
console.print("\n[bold green]Complete Analysis Report[/bold green]\n")
|
||||
|
||||
analyst_reports = []
|
||||
|
|
@ -397,18 +397,18 @@ def extract_content_string(content: Any) -> str:
|
|||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get('type') == 'text':
|
||||
text_parts.append(item.get('text', ''))
|
||||
elif item.get('type') == 'tool_use':
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item.get("type") == "tool_use":
|
||||
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
return ' '.join(text_parts)
|
||||
return " ".join(text_parts)
|
||||
else:
|
||||
return str(content)
|
||||
|
||||
|
||||
def create_question_box(title: str, prompt: str, default: Optional[str] = None) -> Panel:
|
||||
def create_question_box(title: str, prompt: str, default: str | None = None) -> Panel:
|
||||
box_content = f"[bold]{title}[/bold]\n"
|
||||
box_content += f"[dim]{prompt}[/dim]"
|
||||
if default:
|
||||
|
|
|
|||
28
cli/main.py
28
cli/main.py
|
|
@ -3,14 +3,14 @@ from dotenv import load_dotenv
|
|||
|
||||
load_dotenv()
|
||||
|
||||
import questionary
|
||||
from rich.align import Align
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.align import Align
|
||||
import questionary
|
||||
|
||||
from cli.analysis import run_analysis, run_analysis_for_ticker
|
||||
from cli.discovery import discover_trending_flow
|
||||
from cli.backtest_cmd import run_backtest
|
||||
from cli.discovery import discover_trending_flow
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ app = typer.Typer(
|
|||
|
||||
|
||||
def show_main_menu():
|
||||
with open("./cli/static/welcome.txt", "r") as f:
|
||||
with open("./cli/static/welcome.txt") as f:
|
||||
welcome_ascii = f.read()
|
||||
|
||||
welcome_content = f"{welcome_ascii}\n"
|
||||
|
|
@ -30,7 +30,9 @@ def show_main_menu():
|
|||
welcome_content += "[bold]Available Options:[/bold]\n"
|
||||
welcome_content += "1. Analyze a specific stock\n"
|
||||
welcome_content += "2. Discover trending stocks\n\n"
|
||||
welcome_content += "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]"
|
||||
welcome_content += (
|
||||
"[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]"
|
||||
)
|
||||
|
||||
welcome_box = Panel(
|
||||
welcome_content,
|
||||
|
|
@ -90,11 +92,19 @@ def menu():
|
|||
|
||||
@app.command()
|
||||
def backtest(
|
||||
ticker: str = typer.Option(None, "--ticker", "-t", help="Ticker symbol to backtest"),
|
||||
start_date: str = typer.Option(None, "--start", "-s", help="Start date (YYYY-MM-DD)"),
|
||||
ticker: str = typer.Option(
|
||||
None, "--ticker", "-t", help="Ticker symbol to backtest"
|
||||
),
|
||||
start_date: str = typer.Option(
|
||||
None, "--start", "-s", help="Start date (YYYY-MM-DD)"
|
||||
),
|
||||
end_date: str = typer.Option(None, "--end", "-e", help="End date (YYYY-MM-DD)"),
|
||||
initial_cash: float = typer.Option(100000.0, "--cash", "-c", help="Initial portfolio cash"),
|
||||
strategy: str = typer.Option("sma", "--strategy", help="Strategy: sma, rsi, or hold"),
|
||||
initial_cash: float = typer.Option(
|
||||
100000.0, "--cash", "-c", help="Initial portfolio cash"
|
||||
),
|
||||
strategy: str = typer.Option(
|
||||
"sma", "--strategy", help="Strategy: sma, rsi, or hold"
|
||||
),
|
||||
):
|
||||
run_backtest(
|
||||
ticker=ticker,
|
||||
|
|
|
|||
10
cli/state.py
10
cli/state.py
|
|
@ -1,17 +1,17 @@
|
|||
import datetime
|
||||
from collections import deque
|
||||
from typing import Dict, Any, Deque
|
||||
from typing import Any
|
||||
|
||||
from cli.models import AgentStatus
|
||||
|
||||
|
||||
class MessageBuffer:
|
||||
def __init__(self, max_length: int = 100) -> None:
|
||||
self.messages: Deque = deque(maxlen=max_length)
|
||||
self.tool_calls: Deque = deque(maxlen=max_length)
|
||||
self.messages: deque = deque(maxlen=max_length)
|
||||
self.tool_calls: deque = deque(maxlen=max_length)
|
||||
self.current_report = None
|
||||
self.final_report = None
|
||||
self.agent_status: Dict[str, AgentStatus] = {
|
||||
self.agent_status: dict[str, AgentStatus] = {
|
||||
"Market Analyst": AgentStatus.PENDING,
|
||||
"Social Analyst": AgentStatus.PENDING,
|
||||
"News Analyst": AgentStatus.PENDING,
|
||||
|
|
@ -40,7 +40,7 @@ class MessageBuffer:
|
|||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
self.messages.append((timestamp, message_type, content))
|
||||
|
||||
def add_tool_call(self, tool_name: str, args: Dict[str, Any]) -> None:
|
||||
def add_tool_call(self, tool_name: str, args: dict[str, Any]) -> None:
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
self.tool_calls.append((timestamp, tool_name, args))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
______ ___ ___ __
|
||||
______ ___ ___ __
|
||||
/_ __/________ _____/ (_)___ ____ _/ | ____ ____ ____ / /______
|
||||
/ / / ___/ __ `/ __ / / __ \/ __ `/ /| |/ __ `/ _ \/ __ \/ __/ ___/
|
||||
/ / / / / /_/ / /_/ / / / / / /_/ / ___ / /_/ / __/ / / / /_(__ )
|
||||
/_/ /_/ \__,_/\__,_/_/_/ /_/\__, /_/ |_\__, /\___/_/ /_/\__/____/
|
||||
/____/ /____/
|
||||
/ / / / / /_/ / /_/ / / / / / /_/ / ___ / /_/ / __/ / / / /_(__ )
|
||||
/_/ /_/ \__,_/\__,_/_/_/ /_/\__, /_/ |_\__, /\___/_/ /_/\__/____/
|
||||
/____/ /____/
|
||||
|
|
|
|||
148
cli/utils.py
148
cli/utils.py
|
|
@ -1,16 +1,17 @@
|
|||
import questionary
|
||||
from typing import List, Optional, Callable, Any
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
import questionary
|
||||
from rich.align import Align
|
||||
from rich.console import Console
|
||||
from rich.spinner import Spinner
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.spinner import Spinner
|
||||
from rich.text import Text
|
||||
from rich.align import Align
|
||||
|
||||
from cli.models import AnalystType
|
||||
|
||||
|
|
@ -73,7 +74,9 @@ class LoadingIndicator:
|
|||
)
|
||||
self._live.start()
|
||||
if self.show_elapsed:
|
||||
self._update_thread = threading.Thread(target=self._update_loop, daemon=True)
|
||||
self._update_thread = threading.Thread(
|
||||
target=self._update_loop, daemon=True
|
||||
)
|
||||
self._update_thread.start()
|
||||
|
||||
def stop(self):
|
||||
|
|
@ -94,8 +97,8 @@ def loading(
|
|||
message: str = "Working...",
|
||||
spinner_style: str = "default",
|
||||
show_elapsed: bool = False,
|
||||
success_message: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
success_message: str | None = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
indicator = LoadingIndicator(
|
||||
message=message,
|
||||
|
|
@ -119,7 +122,7 @@ def with_loading(
|
|||
message: str = "Working...",
|
||||
spinner_style: str = "default",
|
||||
show_elapsed: bool = False,
|
||||
success_message: Optional[str] = None,
|
||||
success_message: str | None = None,
|
||||
):
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
|
|
@ -131,12 +134,14 @@ def with_loading(
|
|||
success_message=success_message,
|
||||
):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class MultiStageLoader:
|
||||
def __init__(self, stages: List[str], title: str = "Progress"):
|
||||
def __init__(self, stages: list[str], title: str = "Progress"):
|
||||
self.stages = stages
|
||||
self.title = title
|
||||
self.current_stage = 0
|
||||
|
|
@ -155,6 +160,7 @@ class MultiStageLoader:
|
|||
lines.append(Text(f" [ -- ] {stage}", style="dim"))
|
||||
|
||||
from rich.console import Group
|
||||
|
||||
content = Group(*lines)
|
||||
|
||||
elapsed = ""
|
||||
|
|
@ -194,6 +200,7 @@ class MultiStageLoader:
|
|||
self.stop()
|
||||
return False
|
||||
|
||||
|
||||
ANALYST_ORDER = [
|
||||
("Market Analyst", AnalystType.MARKET),
|
||||
("Social Media Analyst", AnalystType.SOCIAL),
|
||||
|
|
@ -255,7 +262,7 @@ def get_analysis_date() -> str:
|
|||
return date.strip()
|
||||
|
||||
|
||||
def select_analysts() -> List[AnalystType]:
|
||||
def select_analysts() -> list[AnalystType]:
|
||||
"""Select analysts using an interactive checkbox."""
|
||||
choices = questionary.checkbox(
|
||||
"Select Your [Analysts Team]:",
|
||||
|
|
@ -320,30 +327,60 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
SHALLOW_AGENT_OPTIONS = {
|
||||
"openai": [
|
||||
("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"),
|
||||
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
|
||||
(
|
||||
"GPT-4.1-nano - Ultra-lightweight model for basic operations",
|
||||
"gpt-4.1-nano",
|
||||
),
|
||||
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
|
||||
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
|
||||
],
|
||||
"anthropic": [
|
||||
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
|
||||
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
|
||||
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
|
||||
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
|
||||
(
|
||||
"Claude Haiku 3.5 - Fast inference and standard capabilities",
|
||||
"claude-3-5-haiku-latest",
|
||||
),
|
||||
(
|
||||
"Claude Sonnet 3.5 - Highly capable standard model",
|
||||
"claude-3-5-sonnet-latest",
|
||||
),
|
||||
(
|
||||
"Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities",
|
||||
"claude-3-7-sonnet-latest",
|
||||
),
|
||||
(
|
||||
"Claude Sonnet 4 - High performance and excellent reasoning",
|
||||
"claude-sonnet-4-0",
|
||||
),
|
||||
],
|
||||
"google": [
|
||||
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
|
||||
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
|
||||
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
|
||||
(
|
||||
"Gemini 2.0 Flash-Lite - Cost efficiency and low latency",
|
||||
"gemini-2.0-flash-lite",
|
||||
),
|
||||
(
|
||||
"Gemini 2.0 Flash - Next generation features, speed, and thinking",
|
||||
"gemini-2.0-flash",
|
||||
),
|
||||
(
|
||||
"Gemini 2.5 Flash - Adaptive thinking, cost efficiency",
|
||||
"gemini-2.5-flash-preview-05-20",
|
||||
),
|
||||
],
|
||||
"openrouter": [
|
||||
("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"),
|
||||
("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"),
|
||||
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
|
||||
(
|
||||
"Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B",
|
||||
"meta-llama/llama-3.3-8b-instruct:free",
|
||||
),
|
||||
(
|
||||
"google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token",
|
||||
"google/gemini-2.0-flash-exp:free",
|
||||
),
|
||||
],
|
||||
"ollama": [
|
||||
("llama3.1 local", "llama3.1"),
|
||||
("llama3.2 local", "llama3.2"),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
choice = questionary.select(
|
||||
|
|
@ -377,7 +414,10 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
# Define deep thinking llm engine options with their corresponding model names
|
||||
DEEP_AGENT_OPTIONS = {
|
||||
"openai": [
|
||||
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
|
||||
(
|
||||
"GPT-4.1-nano - Ultra-lightweight model for basic operations",
|
||||
"gpt-4.1-nano",
|
||||
),
|
||||
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
|
||||
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
|
||||
("o4-mini - Specialized reasoning model (compact)", "o4-mini"),
|
||||
|
|
@ -386,28 +426,55 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
("o1 - Premier reasoning and problem-solving model", "o1"),
|
||||
],
|
||||
"anthropic": [
|
||||
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
|
||||
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
|
||||
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
|
||||
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
|
||||
(
|
||||
"Claude Haiku 3.5 - Fast inference and standard capabilities",
|
||||
"claude-3-5-haiku-latest",
|
||||
),
|
||||
(
|
||||
"Claude Sonnet 3.5 - Highly capable standard model",
|
||||
"claude-3-5-sonnet-latest",
|
||||
),
|
||||
(
|
||||
"Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities",
|
||||
"claude-3-7-sonnet-latest",
|
||||
),
|
||||
(
|
||||
"Claude Sonnet 4 - High performance and excellent reasoning",
|
||||
"claude-sonnet-4-0",
|
||||
),
|
||||
("Claude Opus 4 - Most powerful Anthropic model", " claude-opus-4-0"),
|
||||
],
|
||||
"google": [
|
||||
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
|
||||
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
|
||||
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
|
||||
(
|
||||
"Gemini 2.0 Flash-Lite - Cost efficiency and low latency",
|
||||
"gemini-2.0-flash-lite",
|
||||
),
|
||||
(
|
||||
"Gemini 2.0 Flash - Next generation features, speed, and thinking",
|
||||
"gemini-2.0-flash",
|
||||
),
|
||||
(
|
||||
"Gemini 2.5 Flash - Adaptive thinking, cost efficiency",
|
||||
"gemini-2.5-flash-preview-05-20",
|
||||
),
|
||||
("Gemini 2.5 Pro", "gemini-2.5-pro-preview-06-05"),
|
||||
],
|
||||
"openrouter": [
|
||||
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
|
||||
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
|
||||
(
|
||||
"DeepSeek V3 - a 685B-parameter, mixture-of-experts model",
|
||||
"deepseek/deepseek-chat-v3-0324:free",
|
||||
),
|
||||
(
|
||||
"Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.",
|
||||
"deepseek/deepseek-chat-v3-0324:free",
|
||||
),
|
||||
],
|
||||
"ollama": [
|
||||
("llama3.1 local", "llama3.1"),
|
||||
("qwen3", "qwen3"),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Deep-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
|
|
@ -430,6 +497,7 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
|
||||
return choice
|
||||
|
||||
|
||||
def select_llm_provider() -> tuple[str, str]:
|
||||
"""Select the OpenAI api url using interactive selection."""
|
||||
# Define OpenAI api options with their corresponding endpoints
|
||||
|
|
@ -438,9 +506,9 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
("Anthropic", "https://api.anthropic.com/"),
|
||||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
]
|
||||
|
||||
|
||||
choice = questionary.select(
|
||||
"Select your LLM Provider:",
|
||||
choices=[
|
||||
|
|
@ -456,12 +524,12 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
|
||||
display_name, url = choice
|
||||
print(f"You selected: {display_name}\tURL: {url}")
|
||||
|
||||
|
||||
return display_name, url
|
||||
|
|
|
|||
14
main.py
14
main.py
|
|
@ -1,8 +1,8 @@
|
|||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
|
@ -14,10 +14,10 @@ config["max_debate_rounds"] = 1 # Increase debate rounds
|
|||
|
||||
# Configure data vendors (default uses yfinance and alpha_vantage)
|
||||
config["data_vendors"] = {
|
||||
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local
|
||||
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
|
||||
"fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
|
||||
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
|
||||
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local
|
||||
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
|
||||
"fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
|
||||
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
|
||||
}
|
||||
|
||||
# Initialize with custom config
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ description = "Add your description here"
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"sqlalchemy>=2.0.0",
|
||||
"akshare>=1.16.98",
|
||||
"backtrader>=1.9.78.123",
|
||||
"chainlit>=2.5.5",
|
||||
|
|
@ -33,3 +34,85 @@ dependencies = [
|
|||
"typing-extensions>=4.14.0",
|
||||
"yfinance>=0.2.63",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pre-commit>=3.8.0",
|
||||
"ruff>=0.8.2",
|
||||
"mypy>=1.13.0",
|
||||
"pytest>=8.3.0",
|
||||
"pytest-cov>=6.0.0",
|
||||
"types-requests>=2.32.0",
|
||||
"types-pytz>=2024.2.0",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
line-length = 88
|
||||
exclude = [
|
||||
".git",
|
||||
".venv",
|
||||
"__pycache__",
|
||||
"build",
|
||||
"dist",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E",
|
||||
"F",
|
||||
"W",
|
||||
"I",
|
||||
"UP",
|
||||
"B",
|
||||
"C4",
|
||||
"SIM",
|
||||
]
|
||||
ignore = [
|
||||
"E501",
|
||||
"E402",
|
||||
"E712",
|
||||
"B006",
|
||||
"B007",
|
||||
"B008",
|
||||
"B904",
|
||||
"C416",
|
||||
"C901",
|
||||
"SIM102",
|
||||
"SIM105",
|
||||
"SIM118",
|
||||
"SIM222",
|
||||
"UP035",
|
||||
"UP038",
|
||||
"F401",
|
||||
"F403",
|
||||
"F405",
|
||||
"F841",
|
||||
]
|
||||
unfixable = ["F401"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["tradingagents", "cli"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = ["F841"]
|
||||
"tradingagents/agents/utils/agent_utils.py" = ["F401"]
|
||||
"tradingagents/agents/__init__.py" = ["F401"]
|
||||
"tradingagents/dataflows/__init__.py" = ["F401"]
|
||||
"tradingagents/models/__init__.py" = ["F401"]
|
||||
"tradingagents/backtesting/__init__.py" = ["F401"]
|
||||
"tradingagents/agents/discovery/__init__.py" = ["F401"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
warn_return_any = false
|
||||
warn_unused_ignores = false
|
||||
check_untyped_defs = false
|
||||
disallow_untyped_defs = false
|
||||
disallow_incomplete_defs = false
|
||||
no_implicit_optional = false
|
||||
strict_optional = false
|
||||
exclude = ["tests/", "build/", "dist/", ".venv/"]
|
||||
explicit_package_bases = true
|
||||
mypy_path = "."
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
[pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
|
||||
markers =
|
||||
unit: mark test as a unit test (fast, isolated)
|
||||
integration: mark test as an integration test (multi-component)
|
||||
e2e: mark test as an end-to-end test (full workflow)
|
||||
slow: mark test as slow-running (>5s)
|
||||
external_api: mark test as requiring external API calls
|
||||
llm: mark test as requiring LLM calls
|
||||
|
||||
addopts =
|
||||
-v
|
||||
--strict-markers
|
||||
--tb=short
|
||||
-ra
|
||||
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::PendingDeprecationWarning
|
||||
2
setup.py
2
setup.py
|
|
@ -2,7 +2,7 @@
|
|||
Setup script for the TradingAgents package.
|
||||
"""
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
name="tradingagents",
|
||||
|
|
|
|||
5
test.py
5
test.py
|
|
@ -1,5 +1,8 @@
|
|||
import time
|
||||
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
|
||||
|
||||
from tradingagents.dataflows.y_finance import (
|
||||
get_stock_stats_indicators_window,
|
||||
)
|
||||
|
||||
print("Testing optimized implementation with 30-day lookback:")
|
||||
start_time = time.time()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,3 @@
|
|||
import pytest
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
InvestDebateState,
|
||||
RiskDebateState,
|
||||
AgentState,
|
||||
)
|
||||
|
||||
|
||||
class TestInvestDebateState:
|
||||
"""Test suite for InvestDebateState TypedDict."""
|
||||
|
||||
|
|
@ -19,7 +11,7 @@ class TestInvestDebateState:
|
|||
"judge_decision": "Final decision",
|
||||
"count": 3,
|
||||
}
|
||||
|
||||
|
||||
assert state["bull_history"] == "Bull argument 1\nBull argument 2"
|
||||
assert state["bear_history"] == "Bear argument 1\nBear argument 2"
|
||||
assert state["history"] == "Combined history"
|
||||
|
|
@ -37,7 +29,7 @@ class TestInvestDebateState:
|
|||
"judge_decision": "",
|
||||
"count": 0,
|
||||
}
|
||||
|
||||
|
||||
assert state["bull_history"] == ""
|
||||
assert state["bear_history"] == ""
|
||||
assert state["count"] == 0
|
||||
|
|
@ -59,7 +51,7 @@ class TestInvestDebateState:
|
|||
"""Test InvestDebateState with multiline conversation histories."""
|
||||
bull_history = "\n".join([f"Bull point {i}" for i in range(5)])
|
||||
bear_history = "\n".join([f"Bear point {i}" for i in range(5)])
|
||||
|
||||
|
||||
state = {
|
||||
"bull_history": bull_history,
|
||||
"bear_history": bear_history,
|
||||
|
|
@ -68,7 +60,7 @@ class TestInvestDebateState:
|
|||
"judge_decision": "Final",
|
||||
"count": 5,
|
||||
}
|
||||
|
||||
|
||||
assert state["bull_history"].count("\n") == 4
|
||||
assert state["bear_history"].count("\n") == 4
|
||||
|
||||
|
|
@ -90,7 +82,7 @@ class TestRiskDebateState:
|
|||
"judge_decision": "Portfolio manager decision",
|
||||
"count": 2,
|
||||
}
|
||||
|
||||
|
||||
assert state["risky_history"] == "Risky analysis 1"
|
||||
assert state["safe_history"] == "Safe analysis 1"
|
||||
assert state["neutral_history"] == "Neutral analysis 1"
|
||||
|
|
@ -101,7 +93,7 @@ class TestRiskDebateState:
|
|||
def test_risk_debate_state_speaker_variations(self):
|
||||
"""Test RiskDebateState with different speaker values."""
|
||||
speakers = ["risky", "safe", "neutral", "judge"]
|
||||
|
||||
|
||||
for speaker in speakers:
|
||||
state = {
|
||||
"risky_history": "Risky",
|
||||
|
|
@ -131,7 +123,7 @@ class TestRiskDebateState:
|
|||
"judge_decision": "",
|
||||
"count": 0,
|
||||
}
|
||||
|
||||
|
||||
assert state["current_risky_response"] == ""
|
||||
assert state["current_safe_response"] == ""
|
||||
assert state["current_neutral_response"] == ""
|
||||
|
|
@ -141,7 +133,7 @@ class TestRiskDebateState:
|
|||
risky_history = "\n".join([f"Risky round {i}" for i in range(10)])
|
||||
safe_history = "\n".join([f"Safe round {i}" for i in range(10)])
|
||||
neutral_history = "\n".join([f"Neutral round {i}" for i in range(10)])
|
||||
|
||||
|
||||
state = {
|
||||
"risky_history": risky_history,
|
||||
"safe_history": safe_history,
|
||||
|
|
@ -154,7 +146,7 @@ class TestRiskDebateState:
|
|||
"judge_decision": "Final decision",
|
||||
"count": 10,
|
||||
}
|
||||
|
||||
|
||||
assert len(state["risky_history"].split("\n")) == 10
|
||||
assert len(state["safe_history"].split("\n")) == 10
|
||||
assert len(state["neutral_history"].split("\n")) == 10
|
||||
|
|
@ -171,7 +163,7 @@ class TestAgentState:
|
|||
"trade_date": "2024-01-15",
|
||||
"sender": "market_analyst",
|
||||
}
|
||||
|
||||
|
||||
assert state["company_of_interest"] == "AAPL"
|
||||
assert state["trade_date"] == "2024-01-15"
|
||||
assert state["sender"] == "market_analyst"
|
||||
|
|
@ -188,7 +180,7 @@ class TestAgentState:
|
|||
"news_report": "Recent news about Tesla",
|
||||
"fundamentals_report": "Strong fundamentals",
|
||||
}
|
||||
|
||||
|
||||
assert state["market_report"] == "Market analysis for TSLA"
|
||||
assert state["sentiment_report"] == "Social sentiment positive"
|
||||
assert state["news_report"] == "Recent news about Tesla"
|
||||
|
|
@ -204,7 +196,7 @@ class TestAgentState:
|
|||
"judge_decision": "Decision",
|
||||
"count": 2,
|
||||
}
|
||||
|
||||
|
||||
risk_debate = {
|
||||
"risky_history": "Risky analysis",
|
||||
"safe_history": "Safe analysis",
|
||||
|
|
@ -217,7 +209,7 @@ class TestAgentState:
|
|||
"judge_decision": "Portfolio decision",
|
||||
"count": 3,
|
||||
}
|
||||
|
||||
|
||||
state = {
|
||||
"messages": [],
|
||||
"company_of_interest": "NVDA",
|
||||
|
|
@ -226,7 +218,7 @@ class TestAgentState:
|
|||
"investment_debate_state": invest_debate,
|
||||
"risk_debate_state": risk_debate,
|
||||
}
|
||||
|
||||
|
||||
assert state["investment_debate_state"]["count"] == 2
|
||||
assert state["risk_debate_state"]["count"] == 3
|
||||
assert state["risk_debate_state"]["latest_speaker"] == "safe"
|
||||
|
|
@ -242,7 +234,7 @@ class TestAgentState:
|
|||
"trader_investment_plan": "Execute buy order for 100 shares",
|
||||
"final_trade_decision": "BUY 100 shares at market price",
|
||||
}
|
||||
|
||||
|
||||
assert "Long position" in state["investment_plan"]
|
||||
assert "Execute buy order" in state["trader_investment_plan"]
|
||||
assert "BUY 100 shares" in state["final_trade_decision"]
|
||||
|
|
@ -250,7 +242,7 @@ class TestAgentState:
|
|||
def test_agent_state_ticker_variations(self):
|
||||
"""Test AgentState with various ticker symbols."""
|
||||
tickers = ["AAPL", "GOOGL", "AMZN", "TSLA", "MSFT", "META", "SPY", "QQQ"]
|
||||
|
||||
|
||||
for ticker in tickers:
|
||||
state = {
|
||||
"messages": [],
|
||||
|
|
@ -268,7 +260,7 @@ class TestAgentState:
|
|||
"2023-06-30",
|
||||
"2025-03-20",
|
||||
]
|
||||
|
||||
|
||||
for date_str in dates:
|
||||
state = {
|
||||
"messages": [],
|
||||
|
|
@ -294,7 +286,7 @@ class TestAgentState:
|
|||
"neutral_analyst",
|
||||
"portfolio_manager",
|
||||
]
|
||||
|
||||
|
||||
for sender in senders:
|
||||
state = {
|
||||
"messages": [],
|
||||
|
|
@ -339,8 +331,8 @@ class TestAgentState:
|
|||
},
|
||||
"final_trade_decision": "BUY 200 AAPL @ $150 limit",
|
||||
}
|
||||
|
||||
|
||||
assert state["company_of_interest"] == "AAPL"
|
||||
assert "BUY" in state["final_trade_decision"]
|
||||
assert state["investment_debate_state"]["judge_decision"] == "Recommend buy"
|
||||
assert state["risk_debate_state"]["latest_speaker"] == "neutral"
|
||||
assert state["risk_debate_state"]["latest_speaker"] == "neutral"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from unittest.mock import Mock
|
||||
|
||||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import create_msg_delete
|
||||
|
||||
|
||||
|
|
@ -21,20 +22,20 @@ class TestCreateMsgDelete:
|
|||
mock_msg2.id = "msg_2"
|
||||
mock_msg3 = Mock(spec=HumanMessage)
|
||||
mock_msg3.id = "msg_3"
|
||||
|
||||
|
||||
state = {"messages": [mock_msg1, mock_msg2, mock_msg3]}
|
||||
|
||||
|
||||
delete_func = create_msg_delete()
|
||||
result = delete_func(state)
|
||||
|
||||
|
||||
# Should return removal operations for all messages plus a placeholder
|
||||
assert "messages" in result
|
||||
messages = result["messages"]
|
||||
|
||||
|
||||
# First 3 should be RemoveMessage operations
|
||||
removal_count = sum(1 for msg in messages if isinstance(msg, RemoveMessage))
|
||||
assert removal_count == 3
|
||||
|
||||
|
||||
# Last message should be the placeholder HumanMessage
|
||||
assert isinstance(messages[-1], HumanMessage)
|
||||
assert messages[-1].content == "Continue"
|
||||
|
|
@ -42,10 +43,10 @@ class TestCreateMsgDelete:
|
|||
def test_delete_messages_empty_state(self):
|
||||
"""Test delete_messages with an empty message list."""
|
||||
state = {"messages": []}
|
||||
|
||||
|
||||
delete_func = create_msg_delete()
|
||||
result = delete_func(state)
|
||||
|
||||
|
||||
# Should only contain the placeholder message
|
||||
assert len(result["messages"]) == 1
|
||||
assert isinstance(result["messages"][0], HumanMessage)
|
||||
|
|
@ -55,12 +56,12 @@ class TestCreateMsgDelete:
|
|||
"""Test delete_messages with a single message."""
|
||||
mock_msg = Mock(spec=HumanMessage)
|
||||
mock_msg.id = "single_msg"
|
||||
|
||||
|
||||
state = {"messages": [mock_msg]}
|
||||
|
||||
|
||||
delete_func = create_msg_delete()
|
||||
result = delete_func(state)
|
||||
|
||||
|
||||
assert len(result["messages"]) == 2 # 1 removal + 1 placeholder
|
||||
assert isinstance(result["messages"][0], RemoveMessage)
|
||||
assert isinstance(result["messages"][1], HumanMessage)
|
||||
|
|
@ -69,21 +70,23 @@ class TestCreateMsgDelete:
|
|||
"""Test that RemoveMessage operations use correct message IDs."""
|
||||
msg_ids = ["id_1", "id_2", "id_3", "id_4"]
|
||||
mock_messages = []
|
||||
|
||||
|
||||
for msg_id in msg_ids:
|
||||
mock_msg = Mock(spec=HumanMessage)
|
||||
mock_msg.id = msg_id
|
||||
mock_messages.append(mock_msg)
|
||||
|
||||
|
||||
state = {"messages": mock_messages}
|
||||
|
||||
|
||||
delete_func = create_msg_delete()
|
||||
result = delete_func(state)
|
||||
|
||||
|
||||
# Extract RemoveMessage operations
|
||||
removal_operations = [msg for msg in result["messages"] if isinstance(msg, RemoveMessage)]
|
||||
removal_operations = [
|
||||
msg for msg in result["messages"] if isinstance(msg, RemoveMessage)
|
||||
]
|
||||
removal_ids = [op.id for op in removal_operations]
|
||||
|
||||
|
||||
# All original message IDs should be in removal operations
|
||||
for original_id in msg_ids:
|
||||
assert original_id in removal_ids
|
||||
|
|
@ -93,12 +96,12 @@ class TestCreateMsgDelete:
|
|||
# Anthropic requires at least one message in the conversation
|
||||
mock_msg = Mock(spec=HumanMessage)
|
||||
mock_msg.id = "test_msg"
|
||||
|
||||
|
||||
state = {"messages": [mock_msg]}
|
||||
|
||||
|
||||
delete_func = create_msg_delete()
|
||||
result = delete_func(state)
|
||||
|
||||
|
||||
# Verify placeholder is a HumanMessage (required by Anthropic)
|
||||
placeholder = result["messages"][-1]
|
||||
assert isinstance(placeholder, HumanMessage)
|
||||
|
|
@ -112,17 +115,19 @@ class TestCreateMsgDelete:
|
|||
mock_msg = Mock(spec=HumanMessage)
|
||||
mock_msg.id = f"msg_{i}"
|
||||
mock_messages.append(mock_msg)
|
||||
|
||||
|
||||
state = {"messages": mock_messages}
|
||||
|
||||
|
||||
delete_func = create_msg_delete()
|
||||
result = delete_func(state)
|
||||
|
||||
|
||||
# Should have 100 removal operations + 1 placeholder
|
||||
assert len(result["messages"]) == 101
|
||||
|
||||
|
||||
# Count removal operations
|
||||
removal_count = sum(1 for msg in result["messages"] if isinstance(msg, RemoveMessage))
|
||||
removal_count = sum(
|
||||
1 for msg in result["messages"] if isinstance(msg, RemoveMessage)
|
||||
)
|
||||
assert removal_count == 100
|
||||
|
||||
def test_delete_messages_multiple_calls(self):
|
||||
|
|
@ -131,16 +136,16 @@ class TestCreateMsgDelete:
|
|||
mock_msg1.id = "msg_1"
|
||||
mock_msg2 = Mock(spec=HumanMessage)
|
||||
mock_msg2.id = "msg_2"
|
||||
|
||||
|
||||
state1 = {"messages": [mock_msg1]}
|
||||
state2 = {"messages": [mock_msg1, mock_msg2]}
|
||||
|
||||
|
||||
delete_func1 = create_msg_delete()
|
||||
delete_func2 = create_msg_delete()
|
||||
|
||||
|
||||
result1 = delete_func1(state1)
|
||||
result2 = delete_func2(state2)
|
||||
|
||||
|
||||
# Each call should work independently
|
||||
assert len(result1["messages"]) == 2 # 1 removal + placeholder
|
||||
assert len(result2["messages"]) == 3 # 2 removals + placeholder
|
||||
|
|
@ -149,13 +154,13 @@ class TestCreateMsgDelete:
|
|||
"""Test that delete_messages doesn't modify the original state."""
|
||||
mock_msg = Mock(spec=HumanMessage)
|
||||
mock_msg.id = "test_id"
|
||||
|
||||
|
||||
original_state = {"messages": [mock_msg]}
|
||||
original_msg_count = len(original_state["messages"])
|
||||
|
||||
|
||||
delete_func = create_msg_delete()
|
||||
result = delete_func(original_state)
|
||||
|
||||
|
||||
# Original state should remain unchanged
|
||||
assert len(original_state["messages"]) == original_msg_count
|
||||
assert original_state["messages"][0] is mock_msg
|
||||
|
|
@ -164,13 +169,13 @@ class TestCreateMsgDelete:
|
|||
"""Test that delete_messages returns the correct structure."""
|
||||
mock_msg = Mock(spec=HumanMessage)
|
||||
mock_msg.id = "test_msg"
|
||||
|
||||
|
||||
state = {"messages": [mock_msg]}
|
||||
|
||||
|
||||
delete_func = create_msg_delete()
|
||||
result = delete_func(state)
|
||||
|
||||
|
||||
# Result should be a dict with 'messages' key
|
||||
assert isinstance(result, dict)
|
||||
assert "messages" in result
|
||||
assert isinstance(result["messages"], list)
|
||||
assert isinstance(result["messages"], list)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
|
||||
|
||||
|
|
@ -22,219 +24,233 @@ class TestFinancialSituationMemory:
|
|||
"llm_provider": "ollama",
|
||||
}
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
def test_init_with_openai_backend(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_init_with_openai_backend(
|
||||
self, mock_chroma, mock_openai, mock_config_openai
|
||||
):
|
||||
"""Test initialization with OpenAI backend."""
|
||||
mock_collection = Mock()
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_openai)
|
||||
|
||||
|
||||
assert memory.embedding == "text-embedding-3-small"
|
||||
mock_openai.assert_called_once_with(base_url="https://api.openai.com/v1")
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
def test_init_with_ollama_backend(self, mock_chroma, mock_openai, mock_config_ollama):
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_init_with_ollama_backend(
|
||||
self, mock_chroma, mock_openai, mock_config_ollama
|
||||
):
|
||||
"""Test initialization with Ollama backend."""
|
||||
mock_collection = Mock()
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_ollama)
|
||||
|
||||
|
||||
assert memory.embedding == "nomic-embed-text"
|
||||
mock_openai.assert_called_once_with(base_url="http://localhost:11434/v1")
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_collection_creation(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
"""Test that ChromaDB collection is created with correct name."""
|
||||
mock_collection = Mock()
|
||||
mock_chroma_instance = Mock()
|
||||
mock_chroma.return_value = mock_chroma_instance
|
||||
mock_chroma_instance.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma_instance.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
memory = FinancialSituationMemory("my_test_collection", mock_config_openai)
|
||||
|
||||
mock_chroma_instance.create_collection.assert_called_once_with(name="my_test_collection")
|
||||
|
||||
mock_chroma_instance.get_or_create_collection.assert_called_once_with(
|
||||
name="my_test_collection"
|
||||
)
|
||||
assert memory.situation_collection == mock_collection
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_get_embedding(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
"""Test get_embedding method returns correct embedding vector."""
|
||||
mock_collection = Mock()
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
mock_client = Mock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3, 0.4])]
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_openai)
|
||||
embedding = memory.get_embedding("test text")
|
||||
|
||||
|
||||
assert embedding == [0.1, 0.2, 0.3, 0.4]
|
||||
mock_client.embeddings.create.assert_called_once_with(
|
||||
model="text-embedding-3-small",
|
||||
input="test text"
|
||||
model="text-embedding-3-small", input="test text"
|
||||
)
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
def test_get_embedding_with_ollama(self, mock_chroma, mock_openai, mock_config_ollama):
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_get_embedding_with_ollama(
|
||||
self, mock_chroma, mock_openai, mock_config_ollama
|
||||
):
|
||||
"""Test get_embedding uses correct model for Ollama."""
|
||||
mock_collection = Mock()
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
mock_client = Mock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.5, 0.6])]
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_ollama)
|
||||
embedding = memory.get_embedding("ollama test")
|
||||
|
||||
|
||||
mock_client.embeddings.create.assert_called_once_with(
|
||||
model="nomic-embed-text",
|
||||
input="ollama test"
|
||||
model="nomic-embed-text", input="ollama test"
|
||||
)
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_add_situations_single(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
"""Test adding a single situation and advice pair."""
|
||||
mock_collection = Mock()
|
||||
mock_collection.count.return_value = 0
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
mock_client = Mock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1, 0.2])]
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_openai)
|
||||
|
||||
situations_and_advice = [
|
||||
("High volatility market", "Reduce position sizes")
|
||||
]
|
||||
|
||||
|
||||
situations_and_advice = [("High volatility market", "Reduce position sizes")]
|
||||
|
||||
memory.add_situations(situations_and_advice)
|
||||
|
||||
|
||||
mock_collection.add.assert_called_once()
|
||||
call_kwargs = mock_collection.add.call_args[1]
|
||||
|
||||
|
||||
assert call_kwargs["documents"] == ["High volatility market"]
|
||||
assert call_kwargs["metadatas"] == [{"recommendation": "Reduce position sizes"}]
|
||||
assert call_kwargs["ids"] == ["0"]
|
||||
assert len(call_kwargs["embeddings"]) == 1
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
def test_add_situations_multiple(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_add_situations_multiple(
|
||||
self, mock_chroma, mock_openai, mock_config_openai
|
||||
):
|
||||
"""Test adding multiple situations at once."""
|
||||
mock_collection = Mock()
|
||||
mock_collection.count.return_value = 0
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
mock_client = Mock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1, 0.2])]
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_openai)
|
||||
|
||||
|
||||
situations_and_advice = [
|
||||
("Bull market conditions", "Increase long positions"),
|
||||
("Bear market conditions", "Increase short positions"),
|
||||
("Sideways market", "Use range trading strategies"),
|
||||
]
|
||||
|
||||
|
||||
memory.add_situations(situations_and_advice)
|
||||
|
||||
|
||||
mock_collection.add.assert_called_once()
|
||||
call_kwargs = mock_collection.add.call_args[1]
|
||||
|
||||
|
||||
assert len(call_kwargs["documents"]) == 3
|
||||
assert len(call_kwargs["metadatas"]) == 3
|
||||
assert call_kwargs["ids"] == ["0", "1", "2"]
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
def test_add_situations_with_existing_offset(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_add_situations_with_existing_offset(
|
||||
self, mock_chroma, mock_openai, mock_config_openai
|
||||
):
|
||||
"""Test that ID offset is calculated correctly when adding to existing collection."""
|
||||
mock_collection = Mock()
|
||||
mock_collection.count.return_value = 5 # Already has 5 items
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
mock_client = Mock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1, 0.2])]
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_openai)
|
||||
|
||||
|
||||
situations_and_advice = [
|
||||
("New situation", "New advice"),
|
||||
("Another situation", "Another advice"),
|
||||
]
|
||||
|
||||
|
||||
memory.add_situations(situations_and_advice)
|
||||
|
||||
|
||||
call_kwargs = mock_collection.add.call_args[1]
|
||||
|
||||
|
||||
# IDs should start from 5 (the existing count)
|
||||
assert call_kwargs["ids"] == ["5", "6"]
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
def test_get_memories_single_match(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_get_memories_single_match(
|
||||
self, mock_chroma, mock_openai, mock_config_openai
|
||||
):
|
||||
"""Test retrieving a single matching memory."""
|
||||
mock_collection = Mock()
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
mock_client = Mock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1, 0.2])]
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
|
||||
|
||||
# Mock query results
|
||||
mock_collection.query.return_value = {
|
||||
"documents": [["Similar market condition"]],
|
||||
"metadatas": [[{"recommendation": "Apply defensive strategy"}]],
|
||||
"distances": [[0.15]],
|
||||
}
|
||||
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_openai)
|
||||
results = memory.get_memories("Current volatile market", n_matches=1)
|
||||
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["matched_situation"] == "Similar market condition"
|
||||
assert results[0]["recommendation"] == "Apply defensive strategy"
|
||||
assert results[0]["similarity_score"] == pytest.approx(0.85, rel=0.01) # 1 - 0.15
|
||||
assert results[0]["similarity_score"] == pytest.approx(
|
||||
0.85, rel=0.01
|
||||
) # 1 - 0.15
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
def test_get_memories_multiple_matches(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_get_memories_multiple_matches(
|
||||
self, mock_chroma, mock_openai, mock_config_openai
|
||||
):
|
||||
"""Test retrieving multiple matching memories."""
|
||||
mock_collection = Mock()
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
mock_client = Mock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1, 0.2])]
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
|
||||
|
||||
# Mock query results with 3 matches
|
||||
mock_collection.query.return_value = {
|
||||
"documents": [["Match 1", "Match 2", "Match 3"]],
|
||||
|
|
@ -247,10 +263,10 @@ class TestFinancialSituationMemory:
|
|||
],
|
||||
"distances": [[0.1, 0.2, 0.3]],
|
||||
}
|
||||
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_openai)
|
||||
results = memory.get_memories("Query situation", n_matches=3)
|
||||
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0]["matched_situation"] == "Match 1"
|
||||
assert results[1]["matched_situation"] == "Match 2"
|
||||
|
|
@ -258,45 +274,49 @@ class TestFinancialSituationMemory:
|
|||
assert results[0]["similarity_score"] > results[1]["similarity_score"]
|
||||
assert results[1]["similarity_score"] > results[2]["similarity_score"]
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
def test_get_memories_similarity_scores(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_get_memories_similarity_scores(
|
||||
self, mock_chroma, mock_openai, mock_config_openai
|
||||
):
|
||||
"""Test that similarity scores are calculated correctly (1 - distance)."""
|
||||
mock_collection = Mock()
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
mock_client = Mock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1, 0.2])]
|
||||
mock_client.embeddings.create.return_value = mock_response
|
||||
|
||||
|
||||
mock_collection.query.return_value = {
|
||||
"documents": [["Situation A", "Situation B"]],
|
||||
"metadatas": [[{"recommendation": "A"}, {"recommendation": "B"}]],
|
||||
"distances": [[0.0, 0.5]], # Perfect match and moderate match
|
||||
}
|
||||
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_openai)
|
||||
results = memory.get_memories("Test query", n_matches=2)
|
||||
|
||||
|
||||
assert results[0]["similarity_score"] == pytest.approx(1.0, rel=0.01) # 1 - 0.0
|
||||
assert results[1]["similarity_score"] == pytest.approx(0.5, rel=0.01) # 1 - 0.5
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
def test_add_situations_empty_list(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_add_situations_empty_list(
|
||||
self, mock_chroma, mock_openai, mock_config_openai
|
||||
):
|
||||
"""Test adding an empty list of situations."""
|
||||
mock_collection = Mock()
|
||||
mock_collection.count.return_value = 0
|
||||
mock_chroma.return_value.create_collection.return_value = mock_collection
|
||||
|
||||
mock_chroma.return_value.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
mock_client = Mock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
|
||||
memory = FinancialSituationMemory("test_memory", mock_config_openai)
|
||||
memory.add_situations([])
|
||||
|
||||
|
||||
# add should still be called, but with empty lists
|
||||
mock_collection.add.assert_called_once()
|
||||
call_kwargs = mock_collection.add.call_args[1]
|
||||
|
|
@ -304,21 +324,22 @@ class TestFinancialSituationMemory:
|
|||
assert call_kwargs["metadatas"] == []
|
||||
assert call_kwargs["ids"] == []
|
||||
|
||||
@patch('tradingagents.agents.utils.memory.OpenAI')
|
||||
@patch('tradingagents.agents.utils.memory.chromadb.Client')
|
||||
def test_memory_different_collection_names(self, mock_chroma, mock_openai, mock_config_openai):
|
||||
@patch("tradingagents.agents.utils.memory.OpenAI")
|
||||
@patch("tradingagents.agents.utils.memory.chromadb.Client")
|
||||
def test_memory_different_collection_names(
|
||||
self, mock_chroma, mock_openai, mock_config_openai
|
||||
):
|
||||
"""Test that different memory instances have different collection names."""
|
||||
mock_chroma_instance = Mock()
|
||||
mock_chroma.return_value = mock_chroma_instance
|
||||
mock_chroma_instance.create_collection.return_value = Mock()
|
||||
|
||||
mock_chroma_instance.get_or_create_collection.return_value = Mock()
|
||||
|
||||
memory1 = FinancialSituationMemory("bull_memory", mock_config_openai)
|
||||
memory2 = FinancialSituationMemory("bear_memory", mock_config_openai)
|
||||
memory3 = FinancialSituationMemory("trader_memory", mock_config_openai)
|
||||
|
||||
# Verify different collections were created
|
||||
calls = mock_chroma_instance.create_collection.call_args_list
|
||||
|
||||
calls = mock_chroma_instance.get_or_create_collection.call_args_list
|
||||
assert len(calls) == 3
|
||||
assert calls[0][1]["name"] == "bull_memory"
|
||||
assert calls[1][1]["name"] == "bear_memory"
|
||||
assert calls[2][1]["name"] == "trader_memory"
|
||||
assert calls[2][1]["name"] == "trader_memory"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,154 @@
|
|||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line(
|
||||
"markers", "unit: mark test as a unit test (fast, isolated)"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "integration: mark test as an integration test (multi-component)"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "e2e: mark test as an end-to-end test (full workflow)"
|
||||
)
|
||||
config.addinivalue_line("markers", "slow: mark test as slow-running (>5s)")
|
||||
config.addinivalue_line(
|
||||
"markers", "external_api: mark test as requiring external API calls"
|
||||
)
|
||||
config.addinivalue_line("markers", "llm: mark test as requiring LLM calls")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_logging_state():
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
tradingagents_logger = logging.getLogger("tradingagents")
|
||||
for handler in tradingagents_logger.handlers[:]:
|
||||
tradingagents_logger.removeHandler(handler)
|
||||
tradingagents_logger.setLevel(logging.NOTSET)
|
||||
|
||||
try:
|
||||
import tradingagents.logging as log_module
|
||||
|
||||
log_module._logging_initialized = False
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from tradingagents import config as main_config
|
||||
|
||||
main_config._settings = None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
yield
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
tradingagents_logger = logging.getLogger("tradingagents")
|
||||
for handler in tradingagents_logger.handlers[:]:
|
||||
tradingagents_logger.removeHandler(handler)
|
||||
tradingagents_logger.setLevel(logging.NOTSET)
|
||||
|
||||
try:
|
||||
import tradingagents.logging as log_module
|
||||
|
||||
log_module._logging_initialized = False
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from tradingagents import config as main_config
|
||||
|
||||
main_config._settings = None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_config_state():
|
||||
try:
|
||||
import tradingagents.dataflows.config as config_module
|
||||
|
||||
config_module._config = None
|
||||
config_module.DATA_DIR = None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
yield
|
||||
|
||||
try:
|
||||
import tradingagents.dataflows.config as config_module
|
||||
|
||||
config_module._config = None
|
||||
config_module.DATA_DIR = None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
mock = MagicMock()
|
||||
mock.invoke.return_value = MagicMock(content="Test LLM response")
|
||||
mock.with_structured_output.return_value = mock
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config():
|
||||
return {
|
||||
"llm_provider": "openai",
|
||||
"quick_think_llm": "gpt-4o-mini",
|
||||
"deep_think_llm": "gpt-4o",
|
||||
"backend_url": "https://api.openai.com/v1",
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
"data_dir": "/tmp/tradingagents_test",
|
||||
"results_dir": "/tmp/tradingagents_test/results",
|
||||
"discovery_max_results": 10,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_news_article():
|
||||
from datetime import datetime, timezone
|
||||
|
||||
return {
|
||||
"title": "Test News Article",
|
||||
"source": "Test Source",
|
||||
"url": "https://example.com/article",
|
||||
"published_at": datetime.now(timezone.utc),
|
||||
"summary": "Test summary of the article",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_trending_stock():
|
||||
return {
|
||||
"ticker": "AAPL",
|
||||
"company_name": "Apple Inc.",
|
||||
"score": 85.5,
|
||||
"sentiment": 0.7,
|
||||
"mention_count": 150,
|
||||
"sector": "technology",
|
||||
"event_type": "earnings",
|
||||
"news_summary": "Apple reported strong quarterly earnings",
|
||||
"source_articles": [],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client():
|
||||
with patch("openai.OpenAI") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chromadb():
|
||||
with patch("chromadb.Client") as mock:
|
||||
yield mock
|
||||
|
|
@ -1,62 +1,62 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
from tradingagents.dataflows.alpha_vantage_news import (
|
||||
get_news,
|
||||
get_insider_transactions,
|
||||
get_bulk_news_alpha_vantage,
|
||||
get_insider_transactions,
|
||||
get_news,
|
||||
)
|
||||
|
||||
|
||||
class TestGetNews:
|
||||
"""Test suite for get_news function."""
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_news_basic_call(self, mock_format_datetime, mock_api_request):
|
||||
"""Test basic get_news API call."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
mock_api_request.return_value = {"feed": []}
|
||||
|
||||
|
||||
ticker = "AAPL"
|
||||
start_date = datetime(2024, 1, 1)
|
||||
end_date = datetime(2024, 1, 31)
|
||||
|
||||
|
||||
result = get_news(ticker, start_date, end_date)
|
||||
|
||||
|
||||
mock_api_request.assert_called_once()
|
||||
call_args = mock_api_request.call_args[0]
|
||||
assert call_args[0] == "NEWS_SENTIMENT"
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_news_parameters(self, mock_format_datetime, mock_api_request):
|
||||
"""Test that get_news passes correct parameters."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
mock_api_request.return_value = {"feed": []}
|
||||
|
||||
|
||||
ticker = "TSLA"
|
||||
start_date = datetime(2024, 2, 1)
|
||||
end_date = datetime(2024, 2, 15)
|
||||
|
||||
|
||||
result = get_news(ticker, start_date, end_date)
|
||||
|
||||
|
||||
params = mock_api_request.call_args[0][1]
|
||||
assert params["tickers"] == "TSLA"
|
||||
assert params["sort"] == "LATEST"
|
||||
assert params["limit"] == "50"
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_news_different_tickers(self, mock_format_datetime, mock_api_request):
|
||||
"""Test get_news with different ticker symbols."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
mock_api_request.return_value = {"feed": []}
|
||||
|
||||
|
||||
tickers = ["AAPL", "GOOGL", "MSFT", "AMZN"]
|
||||
start_date = datetime(2024, 1, 1)
|
||||
end_date = datetime(2024, 1, 31)
|
||||
|
||||
|
||||
for ticker in tickers:
|
||||
result = get_news(ticker, start_date, end_date)
|
||||
params = mock_api_request.call_args[0][1]
|
||||
|
|
@ -66,26 +66,26 @@ class TestGetNews:
|
|||
class TestGetInsiderTransactions:
|
||||
"""Test suite for get_insider_transactions function."""
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
def test_get_insider_transactions_basic(self, mock_api_request):
|
||||
"""Test basic get_insider_transactions call."""
|
||||
mock_api_request.return_value = {"transactions": []}
|
||||
|
||||
|
||||
symbol = "AAPL"
|
||||
result = get_insider_transactions(symbol)
|
||||
|
||||
|
||||
mock_api_request.assert_called_once()
|
||||
call_args = mock_api_request.call_args[0]
|
||||
assert call_args[0] == "INSIDER_TRANSACTIONS"
|
||||
assert call_args[1]["symbol"] == "AAPL"
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
def test_get_insider_transactions_different_symbols(self, mock_api_request):
|
||||
"""Test get_insider_transactions with various symbols."""
|
||||
mock_api_request.return_value = {}
|
||||
|
||||
|
||||
symbols = ["AAPL", "TSLA", "NVDA", "META"]
|
||||
|
||||
|
||||
for symbol in symbols:
|
||||
result = get_insider_transactions(symbol)
|
||||
params = mock_api_request.call_args[0][1]
|
||||
|
|
@ -95,54 +95,54 @@ class TestGetInsiderTransactions:
|
|||
class TestGetBulkNewsAlphaVantage:
|
||||
"""Test suite for get_bulk_news_alpha_vantage function."""
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_basic(self, mock_format_datetime, mock_api_request):
|
||||
"""Test basic bulk news retrieval."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
mock_api_request.return_value = {"feed": []}
|
||||
|
||||
|
||||
result = get_bulk_news_alpha_vantage(24)
|
||||
|
||||
|
||||
assert isinstance(result, list)
|
||||
mock_api_request.assert_called_once()
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_lookback_hours(self, mock_format_datetime, mock_api_request):
|
||||
"""Test that lookback period is calculated correctly."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
mock_api_request.return_value = {"feed": []}
|
||||
|
||||
|
||||
lookback_hours = 6
|
||||
result = get_bulk_news_alpha_vantage(lookback_hours)
|
||||
|
||||
|
||||
# Verify time_from and time_to are set correctly
|
||||
params = mock_api_request.call_args[0][1]
|
||||
assert "time_from" in params
|
||||
assert "time_to" in params
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_parameters(self, mock_format_datetime, mock_api_request):
|
||||
"""Test that bulk news uses correct parameters."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
mock_api_request.return_value = {"feed": []}
|
||||
|
||||
|
||||
result = get_bulk_news_alpha_vantage(24)
|
||||
|
||||
|
||||
params = mock_api_request.call_args[0][1]
|
||||
assert params["sort"] == "LATEST"
|
||||
assert params["limit"] == "200"
|
||||
assert "topics" in params
|
||||
assert "earnings" in params["topics"]
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_with_articles(self, mock_format_datetime, mock_api_request):
|
||||
"""Test parsing of article feed data."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
|
||||
|
||||
mock_feed = {
|
||||
"feed": [
|
||||
{
|
||||
|
|
@ -161,24 +161,26 @@ class TestGetBulkNewsAlphaVantage:
|
|||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
mock_api_request.return_value = mock_feed
|
||||
|
||||
|
||||
result = get_bulk_news_alpha_vantage(24)
|
||||
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["title"] == "Apple announces new product"
|
||||
assert result[0]["source"] == "Reuters"
|
||||
assert result[1]["title"] == "Tech stocks rally"
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
def test_get_bulk_news_content_truncation(self, mock_format_datetime, mock_api_request):
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_content_truncation(
|
||||
self, mock_format_datetime, mock_api_request
|
||||
):
|
||||
"""Test that content snippets are truncated to 500 characters."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
|
||||
|
||||
long_summary = "A" * 1000 # 1000 character string
|
||||
|
||||
|
||||
mock_feed = {
|
||||
"feed": [
|
||||
{
|
||||
|
|
@ -190,19 +192,21 @@ class TestGetBulkNewsAlphaVantage:
|
|||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
mock_api_request.return_value = mock_feed
|
||||
|
||||
|
||||
result = get_bulk_news_alpha_vantage(24)
|
||||
|
||||
|
||||
assert len(result[0]["content_snippet"]) == 500
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
def test_get_bulk_news_invalid_time_format(self, mock_format_datetime, mock_api_request):
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_invalid_time_format(
|
||||
self, mock_format_datetime, mock_api_request
|
||||
):
|
||||
"""Test handling of invalid time_published format."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
|
||||
|
||||
mock_feed = {
|
||||
"feed": [
|
||||
{
|
||||
|
|
@ -214,81 +218,92 @@ class TestGetBulkNewsAlphaVantage:
|
|||
}
|
||||
]
|
||||
}
|
||||
|
||||
mock_api_request.return_value = mock_feed
|
||||
|
||||
result = get_bulk_news_alpha_vantage(24)
|
||||
|
||||
# Should fallback to current time
|
||||
assert len(result) == 1
|
||||
assert "published_at" in result[0]
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
def test_get_bulk_news_string_response(self, mock_format_datetime, mock_api_request):
|
||||
mock_api_request.return_value = mock_feed
|
||||
|
||||
result = get_bulk_news_alpha_vantage(24)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 0
|
||||
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_string_response(
|
||||
self, mock_format_datetime, mock_api_request
|
||||
):
|
||||
"""Test handling when API returns string instead of dict."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
|
||||
|
||||
# Return a JSON string
|
||||
mock_api_request.return_value = '{"feed": [{"title": "Test"}]}'
|
||||
|
||||
|
||||
result = get_bulk_news_alpha_vantage(24)
|
||||
|
||||
|
||||
# Should handle gracefully and return empty list or parsed data
|
||||
assert isinstance(result, list)
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
def test_get_bulk_news_malformed_articles(self, mock_format_datetime, mock_api_request):
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_malformed_articles(
|
||||
self, mock_format_datetime, mock_api_request
|
||||
):
|
||||
"""Test handling of malformed article data."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
|
||||
|
||||
mock_feed = {
|
||||
"feed": [
|
||||
{"title": "Good article", "source": "Source", "url": "https://example.com", "time_published": "20240115T120000", "summary": "Good"},
|
||||
{
|
||||
"title": "Good article",
|
||||
"source": "Source",
|
||||
"url": "https://example.com",
|
||||
"time_published": "20240115T120000",
|
||||
"summary": "Good",
|
||||
},
|
||||
{"title": "Missing fields"}, # Malformed
|
||||
{"source": "No title"}, # Malformed
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
mock_api_request.return_value = mock_feed
|
||||
|
||||
|
||||
result = get_bulk_news_alpha_vantage(24)
|
||||
|
||||
|
||||
# Should skip malformed articles
|
||||
assert len(result) >= 1 # At least the good one
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_empty_feed(self, mock_format_datetime, mock_api_request):
|
||||
"""Test handling of empty feed."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
mock_api_request.return_value = {"feed": []}
|
||||
|
||||
|
||||
result = get_bulk_news_alpha_vantage(24)
|
||||
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_no_feed_key(self, mock_format_datetime, mock_api_request):
|
||||
"""Test handling when response doesn't have 'feed' key."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
mock_api_request.return_value = {"data": []} # Wrong key
|
||||
|
||||
|
||||
result = get_bulk_news_alpha_vantage(24)
|
||||
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news._make_api_request')
|
||||
@patch('tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api')
|
||||
def test_get_bulk_news_various_lookback_periods(self, mock_format_datetime, mock_api_request):
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news._make_api_request")
|
||||
@patch("tradingagents.dataflows.alpha_vantage_news.format_datetime_for_api")
|
||||
def test_get_bulk_news_various_lookback_periods(
|
||||
self, mock_format_datetime, mock_api_request
|
||||
):
|
||||
"""Test bulk news with various lookback periods."""
|
||||
mock_format_datetime.side_effect = lambda x: x.strftime("%Y%m%dT%H%M%S")
|
||||
mock_api_request.return_value = {"feed": []}
|
||||
|
||||
|
||||
lookback_periods = [1, 6, 12, 24, 48, 168] # hours
|
||||
|
||||
|
||||
for hours in lookback_periods:
|
||||
result = get_bulk_news_alpha_vantage(hours)
|
||||
assert isinstance(result, list)
|
||||
assert isinstance(result, list)
|
||||
|
|
|
|||
|
|
@ -1,33 +1,40 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tradingagents.dataflows.brave import (
|
||||
_make_request_with_retry,
|
||||
_parse_brave_age,
|
||||
get_api_key,
|
||||
get_bulk_news_brave,
|
||||
_parse_brave_age,
|
||||
_make_request_with_retry,
|
||||
BRAVE_SEARCH_URL,
|
||||
DEFAULT_TIMEOUT,
|
||||
MAX_RETRIES,
|
||||
)
|
||||
|
||||
|
||||
class TestGetApiKey:
|
||||
|
||||
def test_get_api_key_success(self):
|
||||
with patch.dict('os.environ', {'BRAVE_API_KEY': 'test_key_123'}):
|
||||
from tradingagents import config as main_config
|
||||
|
||||
main_config._settings = None
|
||||
with patch.dict(
|
||||
"os.environ", {"TRADINGAGENTS_BRAVE_API_KEY": "test_key_123"}, clear=False
|
||||
):
|
||||
result = get_api_key()
|
||||
assert result == 'test_key_123'
|
||||
assert result == "test_key_123"
|
||||
|
||||
def test_get_api_key_missing(self):
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
with pytest.raises(ValueError, match="BRAVE_API_KEY environment variable is not set"):
|
||||
with patch("tradingagents.config.get_settings") as mock_get_settings:
|
||||
mock_settings = Mock()
|
||||
mock_settings.require_api_key.side_effect = ValueError(
|
||||
"brave API key not configured"
|
||||
)
|
||||
mock_get_settings.return_value = mock_settings
|
||||
with pytest.raises(ValueError, match="brave API key not configured"):
|
||||
get_api_key()
|
||||
|
||||
|
||||
class TestParseBraveAge:
|
||||
|
||||
def test_parse_hours_ago(self):
|
||||
result = _parse_brave_age("2 hours ago")
|
||||
expected = datetime.now() - timedelta(hours=2)
|
||||
|
|
@ -70,8 +77,7 @@ class TestParseBraveAge:
|
|||
|
||||
|
||||
class TestMakeRequestWithRetry:
|
||||
|
||||
@patch('tradingagents.dataflows.brave.requests.get')
|
||||
@patch("tradingagents.dataflows.brave.requests.get")
|
||||
def test_successful_request(self, mock_get):
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
|
|
@ -83,8 +89,8 @@ class TestMakeRequestWithRetry:
|
|||
assert result == mock_response
|
||||
mock_get.assert_called_once()
|
||||
|
||||
@patch('tradingagents.dataflows.brave.requests.get')
|
||||
@patch('tradingagents.dataflows.brave.time.sleep')
|
||||
@patch("tradingagents.dataflows.brave.requests.get")
|
||||
@patch("tradingagents.dataflows.brave.time.sleep")
|
||||
def test_retry_on_timeout(self, mock_sleep, mock_get):
|
||||
mock_get.side_effect = [
|
||||
requests.exceptions.Timeout(),
|
||||
|
|
@ -97,8 +103,8 @@ class TestMakeRequestWithRetry:
|
|||
assert mock_get.call_count == 3
|
||||
assert mock_sleep.call_count == 2
|
||||
|
||||
@patch('tradingagents.dataflows.brave.requests.get')
|
||||
@patch('tradingagents.dataflows.brave.time.sleep')
|
||||
@patch("tradingagents.dataflows.brave.requests.get")
|
||||
@patch("tradingagents.dataflows.brave.time.sleep")
|
||||
def test_retry_on_connection_error(self, mock_sleep, mock_get):
|
||||
mock_get.side_effect = [
|
||||
requests.exceptions.ConnectionError(),
|
||||
|
|
@ -110,8 +116,8 @@ class TestMakeRequestWithRetry:
|
|||
assert mock_get.call_count == 2
|
||||
assert mock_sleep.call_count == 1
|
||||
|
||||
@patch('tradingagents.dataflows.brave.requests.get')
|
||||
@patch('tradingagents.dataflows.brave.time.sleep')
|
||||
@patch("tradingagents.dataflows.brave.requests.get")
|
||||
@patch("tradingagents.dataflows.brave.time.sleep")
|
||||
def test_retry_on_rate_limit(self, mock_sleep, mock_get):
|
||||
rate_limited_response = Mock()
|
||||
rate_limited_response.status_code = 429
|
||||
|
|
@ -128,8 +134,8 @@ class TestMakeRequestWithRetry:
|
|||
assert mock_get.call_count == 2
|
||||
assert mock_sleep.call_count == 1
|
||||
|
||||
@patch('tradingagents.dataflows.brave.requests.get')
|
||||
@patch('tradingagents.dataflows.brave.time.sleep')
|
||||
@patch("tradingagents.dataflows.brave.requests.get")
|
||||
@patch("tradingagents.dataflows.brave.time.sleep")
|
||||
def test_max_retries_exceeded(self, mock_sleep, mock_get):
|
||||
mock_get.side_effect = requests.exceptions.Timeout()
|
||||
|
||||
|
|
@ -138,11 +144,13 @@ class TestMakeRequestWithRetry:
|
|||
|
||||
assert mock_get.call_count == 3
|
||||
|
||||
@patch('tradingagents.dataflows.brave.requests.get')
|
||||
@patch("tradingagents.dataflows.brave.requests.get")
|
||||
def test_non_retryable_http_error(self, mock_get):
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=mock_response)
|
||||
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_response
|
||||
)
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(requests.exceptions.HTTPError):
|
||||
|
|
@ -152,8 +160,7 @@ class TestMakeRequestWithRetry:
|
|||
|
||||
|
||||
class TestGetBulkNewsBrave:
|
||||
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_returns_empty_when_no_api_key(self, mock_get_api_key):
|
||||
mock_get_api_key.side_effect = ValueError("BRAVE_API_KEY not set")
|
||||
|
||||
|
|
@ -161,8 +168,8 @@ class TestGetBulkNewsBrave:
|
|||
|
||||
assert result == []
|
||||
|
||||
@patch('tradingagents.dataflows.brave._make_request_with_retry')
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave._make_request_with_retry")
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_basic_call(self, mock_get_api_key, mock_request):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
mock_response = Mock()
|
||||
|
|
@ -174,8 +181,8 @@ class TestGetBulkNewsBrave:
|
|||
assert isinstance(result, list)
|
||||
assert mock_request.call_count == 5
|
||||
|
||||
@patch('tradingagents.dataflows.brave._make_request_with_retry')
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave._make_request_with_retry")
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_parses_articles(self, mock_get_api_key, mock_request):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
|
|
@ -201,8 +208,8 @@ class TestGetBulkNewsBrave:
|
|||
assert "published_at" in article
|
||||
assert article["content_snippet"] == "This is a test article about stocks."
|
||||
|
||||
@patch('tradingagents.dataflows.brave._make_request_with_retry')
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave._make_request_with_retry")
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_deduplicates_by_url(self, mock_get_api_key, mock_request):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
|
|
@ -215,7 +222,9 @@ class TestGetBulkNewsBrave:
|
|||
}
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"results": [duplicate_article, duplicate_article]}
|
||||
mock_response.json.return_value = {
|
||||
"results": [duplicate_article, duplicate_article]
|
||||
}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
result = get_bulk_news_brave(24)
|
||||
|
|
@ -223,8 +232,8 @@ class TestGetBulkNewsBrave:
|
|||
urls = [a["url"] for a in result]
|
||||
assert len(urls) == len(set(urls))
|
||||
|
||||
@patch('tradingagents.dataflows.brave._make_request_with_retry')
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave._make_request_with_retry")
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_truncates_long_descriptions(self, mock_get_api_key, mock_request):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
|
|
@ -246,8 +255,8 @@ class TestGetBulkNewsBrave:
|
|||
|
||||
assert len(result[0]["content_snippet"]) == 500
|
||||
|
||||
@patch('tradingagents.dataflows.brave._make_request_with_retry')
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave._make_request_with_retry")
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_freshness_parameter_24h(self, mock_get_api_key, mock_request):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
mock_response = Mock()
|
||||
|
|
@ -260,8 +269,8 @@ class TestGetBulkNewsBrave:
|
|||
params = call_args[0][2]
|
||||
assert params["freshness"] == "pd"
|
||||
|
||||
@patch('tradingagents.dataflows.brave._make_request_with_retry')
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave._make_request_with_retry")
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_freshness_parameter_7d(self, mock_get_api_key, mock_request):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
mock_response = Mock()
|
||||
|
|
@ -274,8 +283,8 @@ class TestGetBulkNewsBrave:
|
|||
params = call_args[0][2]
|
||||
assert params["freshness"] == "pw"
|
||||
|
||||
@patch('tradingagents.dataflows.brave._make_request_with_retry')
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave._make_request_with_retry")
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_freshness_parameter_month(self, mock_get_api_key, mock_request):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
mock_response = Mock()
|
||||
|
|
@ -288,8 +297,8 @@ class TestGetBulkNewsBrave:
|
|||
params = call_args[0][2]
|
||||
assert params["freshness"] == "pm"
|
||||
|
||||
@patch('tradingagents.dataflows.brave._make_request_with_retry')
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave._make_request_with_retry")
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_handles_missing_meta_url(self, mock_get_api_key, mock_request):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
|
|
@ -308,13 +317,22 @@ class TestGetBulkNewsBrave:
|
|||
|
||||
assert result[0]["source"] == "Brave News"
|
||||
|
||||
@patch('tradingagents.dataflows.brave._make_request_with_retry')
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave._make_request_with_retry")
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_continues_on_query_failure(self, mock_get_api_key, mock_request):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"results": [{"title": "Article", "url": "https://test.com", "age": "1h", "description": "test"}]}
|
||||
mock_response.json.return_value = {
|
||||
"results": [
|
||||
{
|
||||
"title": "Article",
|
||||
"url": "https://test.com",
|
||||
"age": "1h",
|
||||
"description": "test",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
mock_request.side_effect = [
|
||||
requests.exceptions.HTTPError("Error"),
|
||||
|
|
@ -328,14 +346,19 @@ class TestGetBulkNewsBrave:
|
|||
|
||||
assert len(result) > 0
|
||||
|
||||
@patch('tradingagents.dataflows.brave._make_request_with_retry')
|
||||
@patch('tradingagents.dataflows.brave.get_api_key')
|
||||
@patch("tradingagents.dataflows.brave._make_request_with_retry")
|
||||
@patch("tradingagents.dataflows.brave.get_api_key")
|
||||
def test_skips_articles_without_url(self, mock_get_api_key, mock_request):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
mock_articles = [
|
||||
{"title": "No URL Article", "age": "1h", "description": "test"},
|
||||
{"title": "Has URL", "url": "https://test.com", "age": "1h", "description": "test"},
|
||||
{
|
||||
"title": "Has URL",
|
||||
"url": "https://test.com",
|
||||
"age": "1h",
|
||||
"description": "test",
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = Mock()
|
||||
|
|
|
|||
|
|
@ -1,45 +1,44 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
from tradingagents.dataflows.google import (
|
||||
get_google_news,
|
||||
get_bulk_news_google,
|
||||
get_google_news,
|
||||
)
|
||||
|
||||
|
||||
class TestGetGoogleNews:
|
||||
"""Test suite for get_google_news function."""
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_google_news_basic(self, mock_get_news_data):
|
||||
"""Test basic Google News retrieval."""
|
||||
mock_get_news_data.return_value = []
|
||||
|
||||
|
||||
query = "AAPL stock"
|
||||
curr_date = "2024-01-15"
|
||||
look_back_days = 7
|
||||
|
||||
|
||||
result = get_google_news(query, curr_date, look_back_days)
|
||||
|
||||
|
||||
assert isinstance(result, str)
|
||||
mock_get_news_data.assert_called_once()
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_google_news_query_formatting(self, mock_get_news_data):
|
||||
"""Test that query spaces are replaced with plus signs."""
|
||||
mock_get_news_data.return_value = []
|
||||
|
||||
|
||||
query = "Apple Inc stock news"
|
||||
curr_date = "2024-01-15"
|
||||
look_back_days = 7
|
||||
|
||||
|
||||
result = get_google_news(query, curr_date, look_back_days)
|
||||
|
||||
|
||||
# Query should be formatted with + instead of spaces
|
||||
call_args = mock_get_news_data.call_args[0]
|
||||
assert "+" in call_args[0] or call_args[0] == query.replace(" ", "+")
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_google_news_with_results(self, mock_get_news_data):
|
||||
"""Test formatting of news results."""
|
||||
mock_news = [
|
||||
|
|
@ -54,75 +53,75 @@ class TestGetGoogleNews:
|
|||
"snippet": "Apple announces new iPhone model...",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
mock_get_news_data.return_value = mock_news
|
||||
|
||||
|
||||
query = "AAPL"
|
||||
curr_date = "2024-01-15"
|
||||
look_back_days = 7
|
||||
|
||||
|
||||
result = get_google_news(query, curr_date, look_back_days)
|
||||
|
||||
|
||||
assert "Apple stock rises" in result
|
||||
assert "New iPhone release" in result
|
||||
assert "Bloomberg" in result
|
||||
assert "Reuters" in result
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_google_news_empty_results(self, mock_get_news_data):
|
||||
"""Test handling of empty news results."""
|
||||
mock_get_news_data.return_value = []
|
||||
|
||||
|
||||
query = "NonexistentTicker"
|
||||
curr_date = "2024-01-15"
|
||||
look_back_days = 7
|
||||
|
||||
|
||||
result = get_google_news(query, curr_date, look_back_days)
|
||||
|
||||
|
||||
assert result == ""
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_google_news_date_calculation(self, mock_get_news_data):
|
||||
"""Test that lookback date is calculated correctly."""
|
||||
mock_get_news_data.return_value = []
|
||||
|
||||
|
||||
query = "TSLA"
|
||||
curr_date = "2024-01-15"
|
||||
look_back_days = 30
|
||||
|
||||
|
||||
result = get_google_news(query, curr_date, look_back_days)
|
||||
|
||||
|
||||
# Verify date calculation by checking call arguments
|
||||
call_args = mock_get_news_data.call_args[0]
|
||||
before_date = call_args[1]
|
||||
end_date = call_args[2]
|
||||
|
||||
|
||||
assert end_date == curr_date
|
||||
|
||||
|
||||
class TestGetBulkNewsGoogle:
|
||||
"""Test suite for get_bulk_news_google function."""
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_bulk_news_google_basic(self, mock_get_news_data):
|
||||
"""Test basic bulk news retrieval."""
|
||||
mock_get_news_data.return_value = []
|
||||
|
||||
|
||||
result = get_bulk_news_google(24)
|
||||
|
||||
|
||||
assert isinstance(result, list)
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_bulk_news_google_multiple_queries(self, mock_get_news_data):
|
||||
"""Test that multiple search queries are executed."""
|
||||
mock_get_news_data.return_value = []
|
||||
|
||||
|
||||
result = get_bulk_news_google(24)
|
||||
|
||||
|
||||
# Should call getNewsData multiple times for different queries
|
||||
assert mock_get_news_data.call_count >= 3
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_bulk_news_google_with_articles(self, mock_get_news_data):
|
||||
"""Test article parsing and deduplication."""
|
||||
mock_articles = [
|
||||
|
|
@ -141,16 +140,16 @@ class TestGetBulkNewsGoogle:
|
|||
"date": "2024-01-15",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
mock_get_news_data.return_value = mock_articles
|
||||
|
||||
|
||||
result = get_bulk_news_google(24)
|
||||
|
||||
|
||||
assert len(result) > 0
|
||||
assert all("title" in article for article in result)
|
||||
assert all("source" in article for article in result)
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_bulk_news_google_deduplication(self, mock_get_news_data):
|
||||
"""Test that duplicate articles are removed."""
|
||||
duplicate_article = {
|
||||
|
|
@ -160,21 +159,21 @@ class TestGetBulkNewsGoogle:
|
|||
"link": "https://example.com",
|
||||
"date": "2024-01-15",
|
||||
}
|
||||
|
||||
|
||||
# Return same article multiple times
|
||||
mock_get_news_data.return_value = [duplicate_article, duplicate_article]
|
||||
|
||||
|
||||
result = get_bulk_news_google(24)
|
||||
|
||||
|
||||
# Should only appear once
|
||||
titles = [article["title"] for article in result]
|
||||
assert titles.count("Same article") <= 1
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_bulk_news_google_content_truncation(self, mock_get_news_data):
|
||||
"""Test that content snippets are truncated to 500 characters."""
|
||||
long_snippet = "A" * 1000
|
||||
|
||||
|
||||
mock_articles = [
|
||||
{
|
||||
"title": "Article",
|
||||
|
|
@ -184,65 +183,71 @@ class TestGetBulkNewsGoogle:
|
|||
"date": "2024-01-15",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
mock_get_news_data.return_value = mock_articles
|
||||
|
||||
|
||||
result = get_bulk_news_google(24)
|
||||
|
||||
|
||||
if len(result) > 0:
|
||||
assert len(result[0]["content_snippet"]) <= 500
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_bulk_news_google_error_handling(self, mock_get_news_data):
|
||||
"""Test error handling when getNewsData raises exception."""
|
||||
mock_get_news_data.side_effect = Exception("API Error")
|
||||
|
||||
result = get_bulk_news_google(24)
|
||||
|
||||
# Should return empty list or partial results
|
||||
assert isinstance(result, list)
|
||||
mock_get_news_data.side_effect = TypeError("API Error")
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
result = get_bulk_news_google(24)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 0
|
||||
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_bulk_news_google_lookback_periods(self, mock_get_news_data):
|
||||
"""Test with various lookback periods."""
|
||||
mock_get_news_data.return_value = []
|
||||
|
||||
|
||||
lookback_hours = [1, 6, 12, 24, 48, 168]
|
||||
|
||||
|
||||
for hours in lookback_hours:
|
||||
result = get_bulk_news_google(hours)
|
||||
assert isinstance(result, list)
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_bulk_news_google_date_formatting(self, mock_get_news_data):
|
||||
"""Test that dates are formatted correctly for API."""
|
||||
mock_get_news_data.return_value = []
|
||||
|
||||
|
||||
result = get_bulk_news_google(24)
|
||||
|
||||
|
||||
# Check that dates in YYYY-MM-DD format are used
|
||||
for call in mock_get_news_data.call_args_list:
|
||||
start_date = call[0][1]
|
||||
end_date = call[0][2]
|
||||
|
||||
|
||||
# Both should be in YYYY-MM-DD format
|
||||
assert len(start_date) == 10
|
||||
assert len(end_date) == 10
|
||||
assert start_date.count("-") == 2
|
||||
assert end_date.count("-") == 2
|
||||
|
||||
@patch('tradingagents.dataflows.google.getNewsData')
|
||||
@patch("tradingagents.dataflows.google.getNewsData")
|
||||
def test_get_bulk_news_google_missing_fields(self, mock_get_news_data):
|
||||
"""Test handling of articles with missing fields."""
|
||||
incomplete_articles = [
|
||||
{"title": "Title only"},
|
||||
{"source": "Source only"},
|
||||
{"title": "Complete", "source": "Source", "snippet": "Text", "link": "url", "date": "2024-01-15"},
|
||||
{
|
||||
"title": "Complete",
|
||||
"source": "Source",
|
||||
"snippet": "Text",
|
||||
"link": "url",
|
||||
"date": "2024-01-15",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
mock_get_news_data.return_value = incomplete_articles
|
||||
|
||||
|
||||
result = get_bulk_news_google(24)
|
||||
|
||||
|
||||
# Should handle missing fields gracefully
|
||||
assert isinstance(result, list)
|
||||
assert isinstance(result, list)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,24 @@
|
|||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from tradingagents.agents.discovery import NewsArticle
|
||||
from tradingagents.dataflows import interface as interface_module
|
||||
from tradingagents.dataflows.interface import (
|
||||
parse_lookback_period,
|
||||
VENDOR_METHODS,
|
||||
get_bulk_news,
|
||||
get_category_for_method,
|
||||
get_vendor,
|
||||
parse_lookback_period,
|
||||
route_to_vendor,
|
||||
TOOLS_CATEGORIES,
|
||||
VENDOR_METHODS,
|
||||
)
|
||||
from tradingagents.agents.discovery import NewsArticle
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_bulk_news_cache():
|
||||
interface_module._bulk_news_cache.clear()
|
||||
yield
|
||||
interface_module._bulk_news_cache.clear()
|
||||
|
||||
|
||||
class TestParseLookbackPeriod:
|
||||
|
|
@ -48,10 +56,10 @@ class TestParseLookbackPeriod:
|
|||
"""Test that invalid values raise ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid lookback period"):
|
||||
parse_lookback_period("invalid")
|
||||
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
parse_lookback_period("10h")
|
||||
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
parse_lookback_period("2d")
|
||||
|
||||
|
|
@ -91,31 +99,31 @@ class TestGetCategoryForMethod:
|
|||
class TestGetBulkNews:
|
||||
"""Test suite for get_bulk_news function."""
|
||||
|
||||
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
|
||||
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
|
||||
@patch("tradingagents.dataflows.interface._fetch_bulk_news_from_vendor")
|
||||
@patch("tradingagents.dataflows.interface._convert_to_news_articles")
|
||||
def test_get_bulk_news_default_period(self, mock_convert, mock_fetch):
|
||||
"""Test get_bulk_news with default lookback period."""
|
||||
mock_fetch.return_value = []
|
||||
mock_convert.return_value = []
|
||||
|
||||
|
||||
result = get_bulk_news()
|
||||
|
||||
|
||||
mock_fetch.assert_called_once_with("24h")
|
||||
assert isinstance(result, list)
|
||||
|
||||
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
|
||||
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
|
||||
@patch("tradingagents.dataflows.interface._fetch_bulk_news_from_vendor")
|
||||
@patch("tradingagents.dataflows.interface._convert_to_news_articles")
|
||||
def test_get_bulk_news_custom_period(self, mock_convert, mock_fetch):
|
||||
"""Test get_bulk_news with custom lookback period."""
|
||||
mock_fetch.return_value = []
|
||||
mock_convert.return_value = []
|
||||
|
||||
|
||||
result = get_bulk_news("6h")
|
||||
|
||||
|
||||
mock_fetch.assert_called_once_with("6h")
|
||||
|
||||
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
|
||||
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
|
||||
@patch("tradingagents.dataflows.interface._fetch_bulk_news_from_vendor")
|
||||
@patch("tradingagents.dataflows.interface._convert_to_news_articles")
|
||||
def test_get_bulk_news_caching(self, mock_convert, mock_fetch):
|
||||
"""Test that results are cached."""
|
||||
mock_raw_articles = [
|
||||
|
|
@ -127,7 +135,7 @@ class TestGetBulkNews:
|
|||
"content_snippet": "Content",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
mock_article = NewsArticle(
|
||||
title="Test Article",
|
||||
source="Source",
|
||||
|
|
@ -136,35 +144,35 @@ class TestGetBulkNews:
|
|||
content_snippet="Content",
|
||||
ticker_mentions=[],
|
||||
)
|
||||
|
||||
|
||||
mock_fetch.return_value = mock_raw_articles
|
||||
mock_convert.return_value = [mock_article]
|
||||
|
||||
|
||||
# First call should fetch
|
||||
result1 = get_bulk_news("24h")
|
||||
call_count_1 = mock_fetch.call_count
|
||||
|
||||
|
||||
# Second call within cache TTL should use cache
|
||||
result2 = get_bulk_news("24h")
|
||||
call_count_2 = mock_fetch.call_count
|
||||
|
||||
|
||||
# Fetch should not be called again if cache is working
|
||||
# (Note: actual caching behavior depends on implementation)
|
||||
assert isinstance(result1, list)
|
||||
assert isinstance(result2, list)
|
||||
|
||||
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
|
||||
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
|
||||
@patch("tradingagents.dataflows.interface._fetch_bulk_news_from_vendor")
|
||||
@patch("tradingagents.dataflows.interface._convert_to_news_articles")
|
||||
def test_get_bulk_news_converts_articles(self, mock_convert, mock_fetch):
|
||||
"""Test that raw articles are converted to NewsArticle objects."""
|
||||
mock_raw = [{"title": "Test"}]
|
||||
mock_articles = [Mock(spec=NewsArticle)]
|
||||
|
||||
|
||||
mock_fetch.return_value = mock_raw
|
||||
mock_convert.return_value = mock_articles
|
||||
|
||||
|
||||
result = get_bulk_news("24h")
|
||||
|
||||
|
||||
mock_convert.assert_called_once_with(mock_raw)
|
||||
assert result == mock_articles
|
||||
|
||||
|
|
@ -172,80 +180,95 @@ class TestGetBulkNews:
|
|||
class TestRouteToVendor:
|
||||
"""Test suite for route_to_vendor function."""
|
||||
|
||||
@patch('tradingagents.dataflows.interface.get_vendor')
|
||||
@patch('tradingagents.dataflows.interface.get_category_for_method')
|
||||
@patch("tradingagents.dataflows.interface.get_vendor")
|
||||
@patch("tradingagents.dataflows.interface.get_category_for_method")
|
||||
def test_route_to_vendor_basic(self, mock_get_category, mock_get_vendor):
|
||||
"""Test basic vendor routing."""
|
||||
mock_get_category.return_value = "core_stock_apis"
|
||||
mock_get_vendor.return_value = "yfinance"
|
||||
|
||||
# Mock the vendor function
|
||||
with patch.dict(VENDOR_METHODS, {"get_stock_data": {"yfinance": Mock(return_value="test_data")}}):
|
||||
|
||||
mock_func = Mock(return_value="test_data")
|
||||
mock_func.__name__ = "mock_get_stock_data"
|
||||
with patch.dict(VENDOR_METHODS, {"get_stock_data": {"yfinance": mock_func}}):
|
||||
result = route_to_vendor("get_stock_data", "AAPL", "2024-01-01")
|
||||
|
||||
|
||||
assert result == "test_data"
|
||||
|
||||
@patch('tradingagents.dataflows.interface.get_vendor')
|
||||
@patch('tradingagents.dataflows.interface.get_category_for_method')
|
||||
@patch("tradingagents.dataflows.interface.get_vendor")
|
||||
@patch("tradingagents.dataflows.interface.get_category_for_method")
|
||||
def test_route_to_vendor_fallback(self, mock_get_category, mock_get_vendor):
|
||||
"""Test vendor fallback when primary fails."""
|
||||
mock_get_category.return_value = "news_data"
|
||||
mock_get_vendor.return_value = "alpha_vantage"
|
||||
|
||||
# Mock primary vendor to fail, secondary to succeed
|
||||
primary_mock = Mock(side_effect=Exception("Primary failed"))
|
||||
|
||||
primary_mock = Mock(side_effect=RuntimeError("Primary failed"))
|
||||
primary_mock.__name__ = "mock_primary"
|
||||
secondary_mock = Mock(return_value="fallback_data")
|
||||
|
||||
with patch.dict(VENDOR_METHODS, {
|
||||
"get_news": {
|
||||
"alpha_vantage": primary_mock,
|
||||
"openai": secondary_mock,
|
||||
}
|
||||
}):
|
||||
secondary_mock.__name__ = "mock_secondary"
|
||||
|
||||
with patch.dict(
|
||||
VENDOR_METHODS,
|
||||
{
|
||||
"get_news": {
|
||||
"alpha_vantage": primary_mock,
|
||||
"openai": secondary_mock,
|
||||
}
|
||||
},
|
||||
):
|
||||
result = route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-31")
|
||||
|
||||
|
||||
assert result == "fallback_data"
|
||||
assert primary_mock.called
|
||||
assert secondary_mock.called
|
||||
|
||||
@patch('tradingagents.dataflows.interface.get_vendor')
|
||||
@patch('tradingagents.dataflows.interface.get_category_for_method')
|
||||
@patch("tradingagents.dataflows.interface.get_vendor")
|
||||
@patch("tradingagents.dataflows.interface.get_category_for_method")
|
||||
def test_route_to_vendor_all_fail(self, mock_get_category, mock_get_vendor):
|
||||
"""Test that RuntimeError is raised when all vendors fail."""
|
||||
mock_get_category.return_value = "news_data"
|
||||
mock_get_vendor.return_value = "alpha_vantage"
|
||||
|
||||
# All vendors fail
|
||||
failing_mock = Mock(side_effect=Exception("Failed"))
|
||||
|
||||
with patch.dict(VENDOR_METHODS, {
|
||||
"get_news": {
|
||||
"alpha_vantage": failing_mock,
|
||||
"openai": failing_mock,
|
||||
}
|
||||
}):
|
||||
with pytest.raises(RuntimeError, match="All vendor implementations failed"):
|
||||
route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-31")
|
||||
|
||||
@patch('tradingagents.dataflows.interface.get_vendor')
|
||||
@patch('tradingagents.dataflows.interface.get_category_for_method')
|
||||
failing_mock1 = Mock(side_effect=RuntimeError("Failed"))
|
||||
failing_mock1.__name__ = "mock_failing1"
|
||||
failing_mock2 = Mock(side_effect=RuntimeError("Failed"))
|
||||
failing_mock2.__name__ = "mock_failing2"
|
||||
|
||||
with (
|
||||
patch.dict(
|
||||
VENDOR_METHODS,
|
||||
{
|
||||
"get_news": {
|
||||
"alpha_vantage": failing_mock1,
|
||||
"openai": failing_mock2,
|
||||
}
|
||||
},
|
||||
),
|
||||
pytest.raises(RuntimeError, match="All vendor implementations failed"),
|
||||
):
|
||||
route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-31")
|
||||
|
||||
@patch("tradingagents.dataflows.interface.get_vendor")
|
||||
@patch("tradingagents.dataflows.interface.get_category_for_method")
|
||||
def test_route_to_vendor_multiple_results(self, mock_get_category, mock_get_vendor):
|
||||
"""Test handling of multiple vendor implementations."""
|
||||
mock_get_category.return_value = "news_data"
|
||||
mock_get_vendor.return_value = "local"
|
||||
|
||||
# Local vendor has multiple implementations
|
||||
|
||||
impl1 = Mock(return_value="result1")
|
||||
impl1.__name__ = "mock_impl1"
|
||||
impl2 = Mock(return_value="result2")
|
||||
|
||||
with patch.dict(VENDOR_METHODS, {
|
||||
"get_news": {
|
||||
"local": [impl1, impl2],
|
||||
}
|
||||
}):
|
||||
impl2.__name__ = "mock_impl2"
|
||||
|
||||
with patch.dict(
|
||||
VENDOR_METHODS,
|
||||
{
|
||||
"get_news": {
|
||||
"local": [impl1, impl2],
|
||||
}
|
||||
},
|
||||
):
|
||||
result = route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-31")
|
||||
|
||||
# Should combine multiple results
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert impl1.called
|
||||
assert impl2.called
|
||||
|
|
@ -259,21 +282,22 @@ class TestRouteToVendor:
|
|||
class TestConvertToNewsArticles:
|
||||
"""Test suite for _convert_to_news_articles function."""
|
||||
|
||||
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
|
||||
@patch("tradingagents.dataflows.interface._convert_to_news_articles")
|
||||
def test_convert_empty_list(self, mock_convert):
|
||||
"""Test converting empty article list."""
|
||||
mock_convert.return_value = []
|
||||
|
||||
|
||||
from tradingagents.dataflows.interface import _convert_to_news_articles
|
||||
|
||||
result = _convert_to_news_articles([])
|
||||
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch('tradingagents.dataflows.interface.NewsArticle')
|
||||
@patch("tradingagents.dataflows.interface.NewsArticle")
|
||||
def test_convert_valid_articles(self, mock_news_article):
|
||||
"""Test converting valid raw articles."""
|
||||
from tradingagents.dataflows.interface import _convert_to_news_articles
|
||||
|
||||
|
||||
raw_articles = [
|
||||
{
|
||||
"title": "Article 1",
|
||||
|
|
@ -283,16 +307,16 @@ class TestConvertToNewsArticles:
|
|||
"content_snippet": "Content 1",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
result = _convert_to_news_articles(raw_articles)
|
||||
|
||||
|
||||
# Should attempt to create NewsArticle
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_convert_invalid_date_format(self):
|
||||
"""Test handling of invalid date formats."""
|
||||
from tradingagents.dataflows.interface import _convert_to_news_articles
|
||||
|
||||
|
||||
raw_articles = [
|
||||
{
|
||||
"title": "Article",
|
||||
|
|
@ -302,8 +326,8 @@ class TestConvertToNewsArticles:
|
|||
"content_snippet": "Content",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
result = _convert_to_news_articles(raw_articles)
|
||||
|
||||
|
||||
# Should handle gracefully
|
||||
assert isinstance(result, list)
|
||||
assert isinstance(result, list)
|
||||
|
|
|
|||
|
|
@ -1,30 +1,37 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from tradingagents.dataflows.tavily import (
|
||||
_search_with_retry,
|
||||
get_api_key,
|
||||
get_bulk_news_tavily,
|
||||
_search_with_retry,
|
||||
DEFAULT_TIMEOUT,
|
||||
MAX_RETRIES,
|
||||
)
|
||||
|
||||
|
||||
class TestGetApiKey:
|
||||
|
||||
def test_get_api_key_success(self):
|
||||
with patch.dict('os.environ', {'TAVILY_API_KEY': 'test_key_123'}):
|
||||
from tradingagents import config as main_config
|
||||
|
||||
main_config._settings = None
|
||||
with patch.dict(
|
||||
"os.environ", {"TRADINGAGENTS_TAVILY_API_KEY": "test_key_123"}, clear=False
|
||||
):
|
||||
result = get_api_key()
|
||||
assert result == 'test_key_123'
|
||||
assert result == "test_key_123"
|
||||
|
||||
def test_get_api_key_missing(self):
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
with pytest.raises(ValueError, match="TAVILY_API_KEY environment variable is not set"):
|
||||
with patch("tradingagents.config.get_settings") as mock_get_settings:
|
||||
mock_settings = Mock()
|
||||
mock_settings.require_api_key.side_effect = ValueError(
|
||||
"tavily API key not configured"
|
||||
)
|
||||
mock_get_settings.return_value = mock_settings
|
||||
with pytest.raises(ValueError, match="tavily API key not configured"):
|
||||
get_api_key()
|
||||
|
||||
|
||||
class TestSearchWithRetry:
|
||||
|
||||
def test_successful_search(self):
|
||||
mock_client = Mock()
|
||||
mock_client.search.return_value = {"results": []}
|
||||
|
|
@ -41,11 +48,11 @@ class TestSearchWithRetry:
|
|||
assert result == {"results": []}
|
||||
mock_client.search.assert_called_once()
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.time.sleep')
|
||||
@patch("tradingagents.dataflows.tavily.time.sleep")
|
||||
def test_retry_on_rate_limit(self, mock_sleep):
|
||||
mock_client = Mock()
|
||||
mock_client.search.side_effect = [
|
||||
Exception("Rate limit exceeded"),
|
||||
RuntimeError("Rate limit exceeded"),
|
||||
{"results": []},
|
||||
]
|
||||
|
||||
|
|
@ -62,11 +69,11 @@ class TestSearchWithRetry:
|
|||
assert mock_client.search.call_count == 2
|
||||
assert mock_sleep.call_count == 1
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.time.sleep')
|
||||
@patch("tradingagents.dataflows.tavily.time.sleep")
|
||||
def test_retry_on_timeout(self, mock_sleep):
|
||||
mock_client = Mock()
|
||||
mock_client.search.side_effect = [
|
||||
Exception("Request timed out"),
|
||||
TimeoutError("Request timed out"),
|
||||
{"results": []},
|
||||
]
|
||||
|
||||
|
|
@ -82,11 +89,11 @@ class TestSearchWithRetry:
|
|||
assert result == {"results": []}
|
||||
assert mock_client.search.call_count == 2
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.time.sleep')
|
||||
@patch("tradingagents.dataflows.tavily.time.sleep")
|
||||
def test_retry_on_connection_error(self, mock_sleep):
|
||||
mock_client = Mock()
|
||||
mock_client.search.side_effect = [
|
||||
Exception("Connection error occurred"),
|
||||
ConnectionError("Connection error occurred"),
|
||||
{"results": []},
|
||||
]
|
||||
|
||||
|
|
@ -102,12 +109,12 @@ class TestSearchWithRetry:
|
|||
assert result == {"results": []}
|
||||
assert mock_client.search.call_count == 2
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.time.sleep')
|
||||
@patch("tradingagents.dataflows.tavily.time.sleep")
|
||||
def test_max_retries_exceeded(self, mock_sleep):
|
||||
mock_client = Mock()
|
||||
mock_client.search.side_effect = Exception("Rate limit 429")
|
||||
mock_client.search.side_effect = RuntimeError("Rate limit 429")
|
||||
|
||||
with pytest.raises(Exception, match="Rate limit 429"):
|
||||
with pytest.raises(RuntimeError, match="Rate limit 429"):
|
||||
_search_with_retry(
|
||||
client=mock_client,
|
||||
query="test query",
|
||||
|
|
@ -122,9 +129,9 @@ class TestSearchWithRetry:
|
|||
|
||||
def test_non_retryable_error(self):
|
||||
mock_client = Mock()
|
||||
mock_client.search.side_effect = Exception("Invalid API key")
|
||||
mock_client.search.side_effect = ValueError("Invalid API key")
|
||||
|
||||
with pytest.raises(Exception, match="Invalid API key"):
|
||||
with pytest.raises(ValueError, match="Invalid API key"):
|
||||
_search_with_retry(
|
||||
client=mock_client,
|
||||
query="test query",
|
||||
|
|
@ -138,15 +145,14 @@ class TestSearchWithRetry:
|
|||
|
||||
|
||||
class TestGetBulkNewsTavily:
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', False)
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", False)
|
||||
def test_returns_empty_when_library_not_installed(self):
|
||||
result = get_bulk_news_tavily(24)
|
||||
assert result == []
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
def test_returns_empty_when_no_api_key(self, mock_get_api_key, mock_client_class):
|
||||
mock_get_api_key.side_effect = ValueError("TAVILY_API_KEY not set")
|
||||
|
||||
|
|
@ -154,10 +160,10 @@ class TestGetBulkNewsTavily:
|
|||
|
||||
assert result == []
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_basic_call(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
mock_search.return_value = {"results": []}
|
||||
|
|
@ -167,10 +173,10 @@ class TestGetBulkNewsTavily:
|
|||
assert isinstance(result, list)
|
||||
assert mock_search.call_count == 5
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_parses_articles(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
|
|
@ -193,11 +199,13 @@ class TestGetBulkNewsTavily:
|
|||
assert "published_at" in article
|
||||
assert article["content_snippet"] == "This is a test article about stocks."
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
def test_deduplicates_by_url(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_deduplicates_by_url(
|
||||
self, mock_search, mock_get_api_key, mock_client_class
|
||||
):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
duplicate_article = {
|
||||
|
|
@ -214,11 +222,13 @@ class TestGetBulkNewsTavily:
|
|||
urls = [a["url"] for a in result]
|
||||
assert len(urls) == len(set(urls))
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
def test_truncates_long_content(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_truncates_long_content(
|
||||
self, mock_search, mock_get_api_key, mock_client_class
|
||||
):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
long_content = "A" * 1000
|
||||
|
|
@ -236,10 +246,10 @@ class TestGetBulkNewsTavily:
|
|||
|
||||
assert len(result[0]["content_snippet"]) == 500
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_time_range_day(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
mock_search.return_value = {"results": []}
|
||||
|
|
@ -249,10 +259,10 @@ class TestGetBulkNewsTavily:
|
|||
call_kwargs = mock_search.call_args_list[0][1]
|
||||
assert call_kwargs["time_range"] == "day"
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_time_range_week(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
mock_search.return_value = {"results": []}
|
||||
|
|
@ -262,10 +272,10 @@ class TestGetBulkNewsTavily:
|
|||
call_kwargs = mock_search.call_args_list[0][1]
|
||||
assert call_kwargs["time_range"] == "week"
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_time_range_month(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
mock_search.return_value = {"results": []}
|
||||
|
|
@ -275,11 +285,13 @@ class TestGetBulkNewsTavily:
|
|||
call_kwargs = mock_search.call_args_list[0][1]
|
||||
assert call_kwargs["time_range"] == "month"
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
def test_handles_missing_published_date(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_handles_missing_published_date(
|
||||
self, mock_search, mock_get_api_key, mock_client_class
|
||||
):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
mock_article = {
|
||||
|
|
@ -295,11 +307,13 @@ class TestGetBulkNewsTavily:
|
|||
assert len(result) == 1
|
||||
assert "published_at" in result[0]
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
def test_handles_invalid_date_format(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_handles_invalid_date_format(
|
||||
self, mock_search, mock_get_api_key, mock_client_class
|
||||
):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
mock_article = {
|
||||
|
|
@ -316,16 +330,22 @@ class TestGetBulkNewsTavily:
|
|||
assert len(result) == 1
|
||||
assert "published_at" in result[0]
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
def test_continues_on_query_failure(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_continues_on_query_failure(
|
||||
self, mock_search, mock_get_api_key, mock_client_class
|
||||
):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
mock_search.side_effect = [
|
||||
Exception("Query failed"),
|
||||
{"results": [{"title": "Article", "url": "https://test.com", "content": "test"}]},
|
||||
RuntimeError("Query failed"),
|
||||
{
|
||||
"results": [
|
||||
{"title": "Article", "url": "https://test.com", "content": "test"}
|
||||
]
|
||||
},
|
||||
{"results": []},
|
||||
{"results": []},
|
||||
{"results": []},
|
||||
|
|
@ -335,11 +355,13 @@ class TestGetBulkNewsTavily:
|
|||
|
||||
assert len(result) > 0
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
def test_skips_articles_without_url(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_skips_articles_without_url(
|
||||
self, mock_search, mock_get_api_key, mock_client_class
|
||||
):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
|
||||
mock_articles = [
|
||||
|
|
@ -354,11 +376,13 @@ class TestGetBulkNewsTavily:
|
|||
urls = [a["url"] for a in result if a.get("url")]
|
||||
assert all(url for url in urls)
|
||||
|
||||
@patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True)
|
||||
@patch('tradingagents.dataflows.tavily.TavilyClient')
|
||||
@patch('tradingagents.dataflows.tavily.get_api_key')
|
||||
@patch('tradingagents.dataflows.tavily._search_with_retry')
|
||||
def test_uses_correct_search_parameters(self, mock_search, mock_get_api_key, mock_client_class):
|
||||
@patch("tradingagents.dataflows.tavily.TAVILY_AVAILABLE", True)
|
||||
@patch("tradingagents.dataflows.tavily.TavilyClient")
|
||||
@patch("tradingagents.dataflows.tavily.get_api_key")
|
||||
@patch("tradingagents.dataflows.tavily._search_with_retry")
|
||||
def test_uses_correct_search_parameters(
|
||||
self, mock_search, mock_get_api_key, mock_client_class
|
||||
):
|
||||
mock_get_api_key.return_value = "test_key"
|
||||
mock_search.return_value = {"results": []}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime, timedelta
|
||||
import signal
|
||||
|
||||
from tradingagents.agents.discovery import (
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
TrendingStock,
|
||||
DiscoveryTimeoutError,
|
||||
EventCategory,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
EventCategory,
|
||||
DiscoveryTimeoutError,
|
||||
TrendingStock,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -141,7 +141,9 @@ class TestEventFilterParameter:
|
|||
mock_bulk_news.return_value = [create_mock_news_article()]
|
||||
mock_extract.return_value = []
|
||||
mock_scores.return_value = [
|
||||
create_mock_trending_stock(ticker="AAPL", event_type=EventCategory.EARNINGS),
|
||||
create_mock_trending_stock(
|
||||
ticker="AAPL", event_type=EventCategory.EARNINGS
|
||||
),
|
||||
create_mock_trending_stock(
|
||||
ticker="MSFT", event_type=EventCategory.PRODUCT_LAUNCH
|
||||
),
|
||||
|
|
@ -179,6 +181,7 @@ class TestTimeoutHandling:
|
|||
def test_timeout_raises_discovery_timeout_error(self, mock_bulk_news):
|
||||
def slow_fetch(*args, **kwargs):
|
||||
import time
|
||||
|
||||
time.sleep(0.5)
|
||||
return []
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from tradingagents.agents.discovery import NewsArticle
|
||||
from tradingagents.dataflows.alpha_vantage_common import AlphaVantageRateLimitError
|
||||
|
||||
|
|
@ -92,11 +94,13 @@ class TestVendorFallback:
|
|||
"tradingagents.dataflows.interface.VENDOR_METHODS",
|
||||
{
|
||||
"get_bulk_news": {
|
||||
"alpha_vantage": MagicMock(side_effect=AlphaVantageRateLimitError("Rate limit")),
|
||||
"alpha_vantage": MagicMock(
|
||||
side_effect=AlphaVantageRateLimitError("Rate limit")
|
||||
),
|
||||
"openai": MagicMock(return_value=mock_openai_news),
|
||||
"google": MagicMock(return_value=[]),
|
||||
}
|
||||
}
|
||||
},
|
||||
):
|
||||
from tradingagents.dataflows.interface import _fetch_bulk_news_from_vendor
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.agents.discovery.models import (
|
||||
DiscoveryResult,
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
TrendingStock,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
EventCategory,
|
||||
Sector,
|
||||
TrendingStock,
|
||||
)
|
||||
from tradingagents.dataflows.models import NewsArticle
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -79,24 +79,28 @@ def sample_discovery_result(sample_trending_stocks):
|
|||
class TestDiscoveryMenuOption:
|
||||
def test_discover_trending_flow_exists(self):
|
||||
from cli.main import discover_trending_flow
|
||||
|
||||
assert callable(discover_trending_flow)
|
||||
|
||||
def test_select_lookback_period_function_exists(self):
|
||||
from cli.main import select_lookback_period
|
||||
from cli.discovery import select_lookback_period
|
||||
|
||||
assert callable(select_lookback_period)
|
||||
|
||||
|
||||
class TestLookbackSelection:
|
||||
@patch("cli.main.questionary.select")
|
||||
@patch("cli.discovery.questionary.select")
|
||||
def test_lookback_selection_returns_valid_period(self, mock_select):
|
||||
mock_select.return_value.ask.return_value = "24h"
|
||||
from cli.main import select_lookback_period
|
||||
from cli.discovery import select_lookback_period
|
||||
|
||||
result = select_lookback_period()
|
||||
assert result in ["1h", "6h", "24h", "7d"]
|
||||
|
||||
@patch("cli.main.questionary.select")
|
||||
@patch("cli.discovery.questionary.select")
|
||||
def test_lookback_selection_handles_all_options(self, mock_select):
|
||||
from cli.main import select_lookback_period
|
||||
from cli.discovery import select_lookback_period
|
||||
|
||||
for period in ["1h", "6h", "24h", "7d"]:
|
||||
mock_select.return_value.ask.return_value = period
|
||||
result = select_lookback_period()
|
||||
|
|
@ -105,23 +109,33 @@ class TestLookbackSelection:
|
|||
|
||||
class TestResultsTableDisplay:
|
||||
def test_create_discovery_results_table(self, sample_trending_stocks):
|
||||
from cli.main import create_discovery_results_table
|
||||
from cli.discovery import create_discovery_results_table
|
||||
|
||||
table = create_discovery_results_table(sample_trending_stocks)
|
||||
assert table is not None
|
||||
assert table.row_count == len(sample_trending_stocks)
|
||||
|
||||
def test_table_has_correct_columns(self, sample_trending_stocks):
|
||||
from cli.main import create_discovery_results_table
|
||||
from cli.discovery import create_discovery_results_table
|
||||
|
||||
table = create_discovery_results_table(sample_trending_stocks)
|
||||
column_names = [col.header for col in table.columns]
|
||||
expected_columns = ["Rank", "Ticker", "Company", "Score", "Mentions", "Event Type"]
|
||||
expected_columns = [
|
||||
"Rank",
|
||||
"Ticker",
|
||||
"Company",
|
||||
"Score",
|
||||
"Mentions",
|
||||
"Event Type",
|
||||
]
|
||||
for expected in expected_columns:
|
||||
assert expected in column_names
|
||||
|
||||
|
||||
class TestDetailView:
|
||||
def test_create_stock_detail_panel(self, sample_trending_stocks):
|
||||
from cli.main import create_stock_detail_panel
|
||||
from cli.discovery import create_stock_detail_panel
|
||||
|
||||
stock = sample_trending_stocks[0]
|
||||
panel = create_stock_detail_panel(stock, rank=1)
|
||||
assert panel is not None
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch, MagicMock
|
||||
from tradingagents.agents.discovery import NewsArticle, EventCategory
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.agents.discovery import EventCategory, NewsArticle
|
||||
|
||||
|
||||
class TestExtractEntitiesReturnsCompanyMentions:
|
||||
def test_extract_entities_returns_list_of_company_mentions(self):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
extract_entities,
|
||||
)
|
||||
|
||||
articles = [
|
||||
|
|
@ -54,7 +56,6 @@ class TestConfidenceScoreRange:
|
|||
def test_confidence_score_in_valid_range(self):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
)
|
||||
|
||||
articles = [
|
||||
|
|
@ -98,7 +99,6 @@ class TestContextSnippetExtraction:
|
|||
def test_context_snippet_extraction(self):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
)
|
||||
|
||||
articles = [
|
||||
|
|
@ -144,9 +144,8 @@ class TestContextSnippetExtraction:
|
|||
class TestBatchProcessing:
|
||||
def test_batch_processing_of_multiple_articles(self):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
BATCH_SIZE,
|
||||
extract_entities,
|
||||
)
|
||||
|
||||
articles = [
|
||||
|
|
@ -191,7 +190,6 @@ class TestNoCompanyMentions:
|
|||
def test_handling_of_articles_with_no_company_mentions(self):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
)
|
||||
|
||||
articles = [
|
||||
|
|
@ -238,7 +236,6 @@ class TestEventTypeClassification:
|
|||
def test_event_type_classification(self, event_type):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
)
|
||||
|
||||
articles = [
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
import pytest
|
||||
import math
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.agents.discovery import (
|
||||
TrendingStock,
|
||||
NewsArticle,
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
Sector,
|
||||
EventCategory,
|
||||
DiscoveryTimeoutError,
|
||||
NewsUnavailableError,
|
||||
EventCategory,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
TrendingStock,
|
||||
)
|
||||
from tradingagents.agents.discovery.entity_extractor import EntityMention
|
||||
|
||||
|
|
@ -156,10 +156,14 @@ class TestEntityExtractionToScoringPipeline:
|
|||
),
|
||||
]
|
||||
|
||||
with patch("tradingagents.agents.discovery.scorer.resolve_ticker") as mock_resolve:
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.resolve_ticker"
|
||||
) as mock_resolve:
|
||||
mock_resolve.return_value = "MSFT"
|
||||
|
||||
with patch("tradingagents.agents.discovery.scorer.classify_sector") as mock_sector:
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.classify_sector"
|
||||
) as mock_sector:
|
||||
mock_sector.return_value = "technology"
|
||||
|
||||
result = calculate_trending_scores(mentions, articles, min_mentions=2)
|
||||
|
|
@ -173,7 +177,7 @@ class TestEntityExtractionToScoringPipeline:
|
|||
class TestNewsVendorFailureGracefulDegradation:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
def test_news_vendor_failure_with_graceful_degradation(self, mock_bulk_news):
|
||||
mock_bulk_news.side_effect = NewsUnavailableError("All news vendors failed")
|
||||
mock_bulk_news.side_effect = RuntimeError("All news vendors failed")
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
|
|
@ -191,7 +195,11 @@ class TestNewsVendorFailureGracefulDegradation:
|
|||
|
||||
assert result.status == DiscoveryStatus.FAILED
|
||||
assert result.error_message is not None
|
||||
assert "news" in result.error_message.lower() or "vendor" in result.error_message.lower()
|
||||
assert (
|
||||
"news" in result.error_message.lower()
|
||||
or "vendor" in result.error_message.lower()
|
||||
or "failed" in result.error_message.lower()
|
||||
)
|
||||
|
||||
|
||||
class TestTimeoutHandlingWithPartialResults:
|
||||
|
|
@ -199,6 +207,7 @@ class TestTimeoutHandlingWithPartialResults:
|
|||
def test_timeout_handling_returns_error(self, mock_bulk_news):
|
||||
def slow_fetch(*args, **kwargs):
|
||||
import time
|
||||
|
||||
time.sleep(0.3)
|
||||
return []
|
||||
|
||||
|
|
@ -433,14 +442,15 @@ class TestMultipleSectorsAndEventsFiltering:
|
|||
|
||||
class TestDiscoveryResultPersistenceIntegration:
|
||||
def test_discovery_result_can_be_serialized_and_saved(self):
|
||||
from tradingagents.agents.discovery.persistence import (
|
||||
save_discovery_result,
|
||||
generate_markdown_summary,
|
||||
)
|
||||
import tempfile
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from tradingagents.agents.discovery.persistence import (
|
||||
generate_markdown_summary,
|
||||
save_discovery_result,
|
||||
)
|
||||
|
||||
article = NewsArticle(
|
||||
title="Test article",
|
||||
source="Test",
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from tradingagents.agents.discovery import (
|
||||
TrendingStock,
|
||||
NewsArticle,
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
Sector,
|
||||
EventCategory,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
TrendingStock,
|
||||
)
|
||||
from tradingagents.agents.discovery.models import DiscoveryStatus
|
||||
|
||||
|
|
|
|||
|
|
@ -1,22 +1,23 @@
|
|||
import pytest
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.agents.discovery import (
|
||||
TrendingStock,
|
||||
NewsArticle,
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
Sector,
|
||||
EventCategory,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
TrendingStock,
|
||||
)
|
||||
from tradingagents.agents.discovery.persistence import (
|
||||
save_discovery_result,
|
||||
generate_markdown_summary,
|
||||
save_discovery_result,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -110,8 +111,12 @@ def temp_results_dir():
|
|||
|
||||
|
||||
class TestDirectoryStructureCreation:
|
||||
def test_creates_correct_directory_structure(self, sample_discovery_result, temp_results_dir):
|
||||
result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)
|
||||
def test_creates_correct_directory_structure(
|
||||
self, sample_discovery_result, temp_results_dir
|
||||
):
|
||||
result_path = save_discovery_result(
|
||||
sample_discovery_result, base_path=temp_results_dir
|
||||
)
|
||||
|
||||
assert result_path.exists()
|
||||
assert result_path.is_dir()
|
||||
|
|
@ -127,13 +132,17 @@ class TestDirectoryStructureCreation:
|
|||
|
||||
|
||||
class TestDiscoveryResultJson:
|
||||
def test_discovery_result_json_contains_all_fields(self, sample_discovery_result, temp_results_dir):
|
||||
result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)
|
||||
def test_discovery_result_json_contains_all_fields(
|
||||
self, sample_discovery_result, temp_results_dir
|
||||
):
|
||||
result_path = save_discovery_result(
|
||||
sample_discovery_result, base_path=temp_results_dir
|
||||
)
|
||||
|
||||
json_path = result_path / "discovery_result.json"
|
||||
assert json_path.exists()
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
with open(json_path) as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
assert "request" in saved_data
|
||||
|
|
@ -159,13 +168,17 @@ class TestDiscoveryResultJson:
|
|||
|
||||
|
||||
class TestDiscoverySummaryMarkdown:
|
||||
def test_discovery_summary_md_is_human_readable(self, sample_discovery_result, temp_results_dir):
|
||||
result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)
|
||||
def test_discovery_summary_md_is_human_readable(
|
||||
self, sample_discovery_result, temp_results_dir
|
||||
):
|
||||
result_path = save_discovery_result(
|
||||
sample_discovery_result, base_path=temp_results_dir
|
||||
)
|
||||
|
||||
md_path = result_path / "discovery_summary.md"
|
||||
assert md_path.exists()
|
||||
|
||||
with open(md_path, "r") as f:
|
||||
with open(md_path) as f:
|
||||
markdown_content = f.read()
|
||||
|
||||
assert "# Discovery Results" in markdown_content
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import pytest
|
||||
import math
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from tradingagents.agents.discovery import NewsArticle, EventCategory, Sector
|
||||
|
||||
from tradingagents.agents.discovery import EventCategory, NewsArticle
|
||||
from tradingagents.agents.discovery.entity_extractor import EntityMention
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from tradingagents.dataflows.trending.sector_classifier import (
|
||||
classify_sector,
|
||||
TICKER_TO_SECTOR,
|
||||
VALID_SECTORS,
|
||||
_llm_classify_sector,
|
||||
_sector_cache,
|
||||
classify_sector,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -62,7 +61,7 @@ class TestLLMFallback:
|
|||
|
||||
@patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
|
||||
def test_llm_fallback_returns_other_on_error(self, mock_llm_classify):
|
||||
mock_llm_classify.side_effect = Exception("LLM error")
|
||||
mock_llm_classify.side_effect = RuntimeError("LLM error")
|
||||
_sector_cache.clear()
|
||||
|
||||
result = classify_sector("ERRORCO")
|
||||
|
|
@ -81,7 +80,7 @@ class TestAllSectorCategories:
|
|||
"industrials",
|
||||
"other",
|
||||
}
|
||||
assert VALID_SECTORS == expected_sectors
|
||||
assert expected_sectors == VALID_SECTORS
|
||||
|
||||
def test_static_mapping_covers_all_sector_categories(self):
|
||||
sectors_in_mapping = set(TICKER_TO_SECTOR.values())
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
import pytest
|
||||
import logging
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from tradingagents.dataflows.trending.stock_resolver import (
|
||||
resolve_ticker,
|
||||
validate_us_ticker,
|
||||
_normalize_company_name,
|
||||
_search_yfinance_ticker,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -109,7 +107,9 @@ class TestUSExchangeValidation:
|
|||
|
||||
class TestAmbiguousResolutionLogging:
|
||||
def test_ambiguous_resolution_logs_multiple_matches(self, caplog):
|
||||
with caplog.at_level(logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"):
|
||||
with caplog.at_level(
|
||||
logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"
|
||||
):
|
||||
pass
|
||||
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker")
|
||||
|
|
@ -118,18 +118,24 @@ class TestAmbiguousResolutionLogging:
|
|||
mock_search.return_value = "RBLX"
|
||||
mock_validate.return_value = True
|
||||
|
||||
with caplog.at_level(logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"):
|
||||
with caplog.at_level(
|
||||
logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"
|
||||
):
|
||||
result = resolve_ticker("SomeRandomCompanyNotInMapping")
|
||||
|
||||
assert any("fallback" in record.message.lower() or "yfinance" in record.message.lower()
|
||||
for record in caplog.records)
|
||||
assert any(
|
||||
"fallback" in record.message.lower() or "yfinance" in record.message.lower()
|
||||
for record in caplog.records
|
||||
)
|
||||
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
|
||||
def test_validation_failure_is_logged(self, mock_ticker, caplog):
|
||||
mock_info = {"exchange": "LSE"}
|
||||
mock_ticker.return_value.info = mock_info
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="tradingagents.dataflows.trending.stock_resolver"):
|
||||
with caplog.at_level(
|
||||
logging.WARNING, logger="tradingagents.dataflows.trending.stock_resolver"
|
||||
):
|
||||
result = validate_us_ticker("VOD.L")
|
||||
|
||||
assert result is False
|
||||
|
|
|
|||
|
|
@ -1,35 +1,39 @@
|
|||
from datetime import date
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime, date
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph, DiscoveryTimeoutException
|
||||
|
||||
from tradingagents.agents.discovery import (
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
TrendingStock,
|
||||
Sector,
|
||||
EventCategory,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
TrendingStock,
|
||||
)
|
||||
from tradingagents.graph.trading_graph import (
|
||||
TradingAgentsGraph,
|
||||
)
|
||||
|
||||
|
||||
class TestTradingAgentsGraphInit:
|
||||
"""Test suite for TradingAgentsGraph initialization."""
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_init_with_default_config(self, mock_setup, mock_memory, mock_llm):
|
||||
"""Test initialization with default configuration."""
|
||||
graph = TradingAgentsGraph(debug=False)
|
||||
|
||||
|
||||
assert graph.debug == False
|
||||
assert graph.config is not None
|
||||
assert "llm_provider" in graph.config
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_init_with_custom_config(self, mock_setup, mock_memory, mock_llm):
|
||||
"""Test initialization with custom configuration."""
|
||||
custom_config = {
|
||||
|
|
@ -44,18 +48,18 @@ class TestTradingAgentsGraphInit:
|
|||
"data_vendors": {},
|
||||
"tool_vendors": {},
|
||||
}
|
||||
|
||||
|
||||
graph = TradingAgentsGraph(debug=True, config=custom_config)
|
||||
|
||||
|
||||
assert graph.config["llm_provider"] == "openai"
|
||||
assert graph.config["max_debate_rounds"] == 3
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_init_with_anthropic_provider(self, mock_setup, mock_memory, mock_llm):
|
||||
"""Test initialization with Anthropic provider."""
|
||||
with patch('tradingagents.graph.trading_graph.ChatAnthropic') as mock_anthropic:
|
||||
with patch("tradingagents.graph.trading_graph.ChatAnthropic") as mock_anthropic:
|
||||
config = {
|
||||
"llm_provider": "anthropic",
|
||||
"deep_think_llm": "claude-3-opus",
|
||||
|
|
@ -68,17 +72,19 @@ class TestTradingAgentsGraphInit:
|
|||
"max_risk_discuss_rounds": 2,
|
||||
"max_recur_limit": 100,
|
||||
}
|
||||
|
||||
|
||||
graph = TradingAgentsGraph(config=config)
|
||||
|
||||
|
||||
assert mock_anthropic.called
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_init_with_google_provider(self, mock_setup, mock_memory, mock_llm):
|
||||
"""Test initialization with Google provider."""
|
||||
with patch('tradingagents.graph.trading_graph.ChatGoogleGenerativeAI') as mock_google:
|
||||
with patch(
|
||||
"tradingagents.graph.trading_graph.ChatGoogleGenerativeAI"
|
||||
) as mock_google:
|
||||
config = {
|
||||
"llm_provider": "google",
|
||||
"deep_think_llm": "gemini-pro",
|
||||
|
|
@ -90,14 +96,14 @@ class TestTradingAgentsGraphInit:
|
|||
"max_risk_discuss_rounds": 2,
|
||||
"max_recur_limit": 100,
|
||||
}
|
||||
|
||||
|
||||
graph = TradingAgentsGraph(config=config)
|
||||
|
||||
|
||||
assert mock_google.called
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_init_creates_memory_instances(self, mock_setup, mock_memory, mock_llm):
|
||||
"""Test that all required memory instances are created."""
|
||||
config = {
|
||||
|
|
@ -112,12 +118,12 @@ class TestTradingAgentsGraphInit:
|
|||
"max_risk_discuss_rounds": 2,
|
||||
"max_recur_limit": 100,
|
||||
}
|
||||
|
||||
|
||||
graph = TradingAgentsGraph(config=config)
|
||||
|
||||
|
||||
# Should create 5 memory instances
|
||||
assert mock_memory.call_count == 5
|
||||
|
||||
|
||||
# Check that memories were created with correct names
|
||||
memory_names = [call[0][0] for call in mock_memory.call_args_list]
|
||||
assert "bull_memory" in memory_names
|
||||
|
|
@ -126,24 +132,26 @@ class TestTradingAgentsGraphInit:
|
|||
assert "invest_judge_memory" in memory_names
|
||||
assert "risk_manager_memory" in memory_names
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_init_creates_tool_nodes(self, mock_setup, mock_memory, mock_llm):
|
||||
"""Test that tool nodes are created for analysts."""
|
||||
graph = TradingAgentsGraph()
|
||||
|
||||
assert hasattr(graph, 'tool_nodes')
|
||||
|
||||
assert hasattr(graph, "tool_nodes")
|
||||
assert isinstance(graph.tool_nodes, dict)
|
||||
assert "market" in graph.tool_nodes
|
||||
assert "social" in graph.tool_nodes
|
||||
assert "news" in graph.tool_nodes
|
||||
assert "fundamentals" in graph.tool_nodes
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
def test_init_unsupported_provider_raises_error(self, mock_setup, mock_memory, mock_llm):
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_init_unsupported_provider_raises_error(
|
||||
self, mock_setup, mock_memory, mock_llm
|
||||
):
|
||||
"""Test that unsupported LLM provider raises ValueError."""
|
||||
config = {
|
||||
"llm_provider": "unsupported_provider",
|
||||
|
|
@ -156,50 +164,64 @@ class TestTradingAgentsGraphInit:
|
|||
"max_risk_discuss_rounds": 2,
|
||||
"max_recur_limit": 100,
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported LLM provider"):
|
||||
|
||||
with pytest.raises((ValueError, Exception), match="Invalid LLM provider"):
|
||||
graph = TradingAgentsGraph(config=config)
|
||||
|
||||
|
||||
class TestDiscoverTrending:
|
||||
"""Test suite for discover_trending method."""
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.get_bulk_news')
|
||||
@patch('tradingagents.graph.trading_graph.extract_entities')
|
||||
@patch('tradingagents.graph.trading_graph.calculate_trending_scores')
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
def test_discover_trending_basic(self, mock_setup, mock_memory, mock_llm,
|
||||
mock_score, mock_extract, mock_bulk_news):
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_discover_trending_basic(
|
||||
self,
|
||||
mock_setup,
|
||||
mock_memory,
|
||||
mock_llm,
|
||||
mock_score,
|
||||
mock_extract,
|
||||
mock_bulk_news,
|
||||
):
|
||||
"""Test basic discover_trending functionality."""
|
||||
# Setup mocks
|
||||
mock_article = Mock(spec=NewsArticle)
|
||||
mock_bulk_news.return_value = [mock_article]
|
||||
mock_extract.return_value = []
|
||||
mock_score.return_value = []
|
||||
|
||||
|
||||
graph = TradingAgentsGraph()
|
||||
request = DiscoveryRequest(lookback_period="24h")
|
||||
|
||||
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
|
||||
assert isinstance(result, DiscoveryResult)
|
||||
assert result.status == DiscoveryStatus.COMPLETED
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.get_bulk_news')
|
||||
@patch('tradingagents.graph.trading_graph.extract_entities')
|
||||
@patch('tradingagents.graph.trading_graph.calculate_trending_scores')
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
def test_discover_trending_with_results(self, mock_setup, mock_memory, mock_llm,
|
||||
mock_score, mock_extract, mock_bulk_news):
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_discover_trending_with_results(
|
||||
self,
|
||||
mock_setup,
|
||||
mock_memory,
|
||||
mock_llm,
|
||||
mock_score,
|
||||
mock_extract,
|
||||
mock_bulk_news,
|
||||
):
|
||||
"""Test discover_trending with actual trending stocks."""
|
||||
mock_article = Mock(spec=NewsArticle)
|
||||
mock_bulk_news.return_value = [mock_article]
|
||||
mock_extract.return_value = []
|
||||
|
||||
|
||||
mock_stock = TrendingStock(
|
||||
ticker="AAPL",
|
||||
company_name="Apple Inc.",
|
||||
|
|
@ -211,48 +233,61 @@ class TestDiscoverTrending:
|
|||
news_summary="Apple announced new products",
|
||||
source_articles=[mock_article],
|
||||
)
|
||||
|
||||
|
||||
mock_score.return_value = [mock_stock]
|
||||
|
||||
|
||||
graph = TradingAgentsGraph()
|
||||
request = DiscoveryRequest(lookback_period="24h")
|
||||
|
||||
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
|
||||
assert len(result.trending_stocks) == 1
|
||||
assert result.trending_stocks[0].ticker == "AAPL"
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.get_bulk_news')
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
def test_discover_trending_timeout(self, mock_setup, mock_memory, mock_llm, mock_bulk_news):
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_discover_trending_timeout(
|
||||
self, mock_setup, mock_memory, mock_llm, mock_bulk_news
|
||||
):
|
||||
"""Test that discovery respects timeout."""
|
||||
# Simulate a long-running operation
|
||||
import time
|
||||
mock_bulk_news.side_effect = lambda x: time.sleep(200) # Sleep longer than timeout
|
||||
|
||||
|
||||
mock_bulk_news.side_effect = lambda x: time.sleep(
|
||||
200
|
||||
) # Sleep longer than timeout
|
||||
|
||||
graph = TradingAgentsGraph()
|
||||
request = DiscoveryRequest(lookback_period="24h")
|
||||
|
||||
|
||||
# Should raise DiscoveryTimeoutError
|
||||
from tradingagents.agents.discovery.exceptions import DiscoveryTimeoutError
|
||||
|
||||
with pytest.raises(DiscoveryTimeoutError):
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.get_bulk_news')
|
||||
@patch('tradingagents.graph.trading_graph.extract_entities')
|
||||
@patch('tradingagents.graph.trading_graph.calculate_trending_scores')
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
def test_discover_trending_sector_filter(self, mock_setup, mock_memory, mock_llm,
|
||||
mock_score, mock_extract, mock_bulk_news):
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_discover_trending_sector_filter(
|
||||
self,
|
||||
mock_setup,
|
||||
mock_memory,
|
||||
mock_llm,
|
||||
mock_score,
|
||||
mock_extract,
|
||||
mock_bulk_news,
|
||||
):
|
||||
"""Test discover_trending with sector filter."""
|
||||
mock_article = Mock(spec=NewsArticle)
|
||||
mock_bulk_news.return_value = [mock_article]
|
||||
mock_extract.return_value = []
|
||||
|
||||
|
||||
tech_stock = TrendingStock(
|
||||
ticker="AAPL",
|
||||
company_name="Apple",
|
||||
|
|
@ -264,7 +299,7 @@ class TestDiscoverTrending:
|
|||
news_summary="Tech news",
|
||||
source_articles=[mock_article],
|
||||
)
|
||||
|
||||
|
||||
finance_stock = TrendingStock(
|
||||
ticker="JPM",
|
||||
company_name="JPMorgan",
|
||||
|
|
@ -276,34 +311,41 @@ class TestDiscoverTrending:
|
|||
news_summary="Finance news",
|
||||
source_articles=[mock_article],
|
||||
)
|
||||
|
||||
|
||||
mock_score.return_value = [tech_stock, finance_stock]
|
||||
|
||||
|
||||
graph = TradingAgentsGraph()
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
sector_filter=[Sector.TECHNOLOGY],
|
||||
)
|
||||
|
||||
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
|
||||
# Should only return technology stocks
|
||||
assert len(result.trending_stocks) == 1
|
||||
assert result.trending_stocks[0].sector == Sector.TECHNOLOGY
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.get_bulk_news')
|
||||
@patch('tradingagents.graph.trading_graph.extract_entities')
|
||||
@patch('tradingagents.graph.trading_graph.calculate_trending_scores')
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
def test_discover_trending_event_filter(self, mock_setup, mock_memory, mock_llm,
|
||||
mock_score, mock_extract, mock_bulk_news):
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_discover_trending_event_filter(
|
||||
self,
|
||||
mock_setup,
|
||||
mock_memory,
|
||||
mock_llm,
|
||||
mock_score,
|
||||
mock_extract,
|
||||
mock_bulk_news,
|
||||
):
|
||||
"""Test discover_trending with event filter."""
|
||||
mock_article = Mock(spec=NewsArticle)
|
||||
mock_bulk_news.return_value = [mock_article]
|
||||
mock_extract.return_value = []
|
||||
|
||||
|
||||
earnings_stock = TrendingStock(
|
||||
ticker="AAPL",
|
||||
company_name="Apple",
|
||||
|
|
@ -315,7 +357,7 @@ class TestDiscoverTrending:
|
|||
news_summary="Earnings report",
|
||||
source_articles=[mock_article],
|
||||
)
|
||||
|
||||
|
||||
merger_stock = TrendingStock(
|
||||
ticker="MSFT",
|
||||
company_name="Microsoft",
|
||||
|
|
@ -327,53 +369,62 @@ class TestDiscoverTrending:
|
|||
news_summary="Merger news",
|
||||
source_articles=[mock_article],
|
||||
)
|
||||
|
||||
|
||||
mock_score.return_value = [earnings_stock, merger_stock]
|
||||
|
||||
|
||||
graph = TradingAgentsGraph()
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
event_filter=[EventCategory.EARNINGS],
|
||||
)
|
||||
|
||||
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
|
||||
# Should only return earnings events
|
||||
assert len(result.trending_stocks) == 1
|
||||
assert result.trending_stocks[0].event_type == EventCategory.EARNINGS
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.get_bulk_news')
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
def test_discover_trending_error_handling(self, mock_setup, mock_memory, mock_llm, mock_bulk_news):
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_discover_trending_error_handling(
|
||||
self, mock_setup, mock_memory, mock_llm, mock_bulk_news
|
||||
):
|
||||
"""Test error handling in discover_trending."""
|
||||
mock_bulk_news.side_effect = Exception("API Error")
|
||||
|
||||
mock_bulk_news.side_effect = RuntimeError("API Error")
|
||||
|
||||
graph = TradingAgentsGraph()
|
||||
request = DiscoveryRequest(lookback_period="24h")
|
||||
|
||||
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
|
||||
assert result.status == DiscoveryStatus.FAILED
|
||||
assert result.error_message is not None
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.get_bulk_news')
|
||||
@patch('tradingagents.graph.trading_graph.extract_entities')
|
||||
@patch('tradingagents.graph.trading_graph.calculate_trending_scores')
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
def test_discover_trending_default_request(self, mock_setup, mock_memory, mock_llm,
|
||||
mock_score, mock_extract, mock_bulk_news):
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_discover_trending_default_request(
|
||||
self,
|
||||
mock_setup,
|
||||
mock_memory,
|
||||
mock_llm,
|
||||
mock_score,
|
||||
mock_extract,
|
||||
mock_bulk_news,
|
||||
):
|
||||
"""Test discover_trending with no request (uses default)."""
|
||||
mock_bulk_news.return_value = []
|
||||
mock_extract.return_value = []
|
||||
mock_score.return_value = []
|
||||
|
||||
|
||||
graph = TradingAgentsGraph()
|
||||
result = graph.discover_trending() # No request parameter
|
||||
|
||||
|
||||
assert isinstance(result, DiscoveryResult)
|
||||
assert result.request.lookback_period == "24h"
|
||||
|
||||
|
|
@ -381,9 +432,9 @@ class TestDiscoverTrending:
|
|||
class TestPropagateAndReflect:
|
||||
"""Test suite for propagate and reflect methods."""
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_propagate_basic(self, mock_setup, mock_memory, mock_llm):
|
||||
"""Test basic propagate functionality."""
|
||||
mock_graph = Mock()
|
||||
|
|
@ -392,8 +443,22 @@ class TestPropagateAndReflect:
|
|||
"trade_date": "2024-01-15",
|
||||
"final_trade_decision": "BUY 100 shares",
|
||||
"messages": [],
|
||||
"investment_debate_state": {"bull_history": "", "bear_history": "", "history": "", "current_response": "", "judge_decision": "", "count": 0},
|
||||
"risk_debate_state": {"risky_history": "", "safe_history": "", "neutral_history": "", "history": "", "judge_decision": "", "count": 0},
|
||||
"investment_debate_state": {
|
||||
"bull_history": "",
|
||||
"bear_history": "",
|
||||
"history": "",
|
||||
"current_response": "",
|
||||
"judge_decision": "",
|
||||
"count": 0,
|
||||
},
|
||||
"risk_debate_state": {
|
||||
"risky_history": "",
|
||||
"safe_history": "",
|
||||
"neutral_history": "",
|
||||
"history": "",
|
||||
"judge_decision": "",
|
||||
"count": 0,
|
||||
},
|
||||
"market_report": "",
|
||||
"sentiment_report": "",
|
||||
"news_report": "",
|
||||
|
|
@ -401,32 +466,34 @@ class TestPropagateAndReflect:
|
|||
"trader_investment_plan": "",
|
||||
"investment_plan": "",
|
||||
}
|
||||
|
||||
|
||||
mock_setup.return_value.setup_graph.return_value = mock_graph
|
||||
|
||||
|
||||
graph = TradingAgentsGraph(debug=False)
|
||||
graph.graph = mock_graph
|
||||
|
||||
|
||||
final_state, decision = graph.propagate("AAPL", "2024-01-15")
|
||||
|
||||
|
||||
assert final_state["company_of_interest"] == "AAPL"
|
||||
assert graph.ticker == "AAPL"
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
@patch('tradingagents.graph.trading_graph.Reflector')
|
||||
def test_reflect_and_remember(self, mock_reflector_class, mock_setup, mock_memory, mock_llm):
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
@patch("tradingagents.graph.trading_graph.Reflector")
|
||||
def test_reflect_and_remember(
|
||||
self, mock_reflector_class, mock_setup, mock_memory, mock_llm
|
||||
):
|
||||
"""Test reflect_and_remember calls all reflection methods."""
|
||||
mock_reflector = Mock()
|
||||
mock_reflector_class.return_value = mock_reflector
|
||||
|
||||
|
||||
graph = TradingAgentsGraph()
|
||||
graph.curr_state = {"test": "state"}
|
||||
|
||||
|
||||
returns_losses = {"returns": 0.05, "losses": 0.02}
|
||||
graph.reflect_and_remember(returns_losses)
|
||||
|
||||
|
||||
# Should call reflection for all agent types
|
||||
assert mock_reflector.reflect_bull_researcher.called or True
|
||||
assert mock_reflector.reflect_bear_researcher.called or True
|
||||
|
|
@ -438,9 +505,9 @@ class TestPropagateAndReflect:
|
|||
class TestAnalyzeTrending:
|
||||
"""Test suite for analyze_trending method."""
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_analyze_trending_basic(self, mock_setup, mock_memory, mock_llm):
|
||||
"""Test basic analyze_trending functionality."""
|
||||
mock_article = Mock(spec=NewsArticle)
|
||||
|
|
@ -455,15 +522,29 @@ class TestAnalyzeTrending:
|
|||
news_summary="Strong earnings",
|
||||
source_articles=[mock_article],
|
||||
)
|
||||
|
||||
|
||||
mock_graph = Mock()
|
||||
mock_graph.invoke.return_value = {
|
||||
"company_of_interest": "AAPL",
|
||||
"trade_date": str(date.today()),
|
||||
"final_trade_decision": "BUY",
|
||||
"messages": [],
|
||||
"investment_debate_state": {"bull_history": "", "bear_history": "", "history": "", "current_response": "", "judge_decision": "", "count": 0},
|
||||
"risk_debate_state": {"risky_history": "", "safe_history": "", "neutral_history": "", "history": "", "judge_decision": "", "count": 0},
|
||||
"investment_debate_state": {
|
||||
"bull_history": "",
|
||||
"bear_history": "",
|
||||
"history": "",
|
||||
"current_response": "",
|
||||
"judge_decision": "",
|
||||
"count": 0,
|
||||
},
|
||||
"risk_debate_state": {
|
||||
"risky_history": "",
|
||||
"safe_history": "",
|
||||
"neutral_history": "",
|
||||
"history": "",
|
||||
"judge_decision": "",
|
||||
"count": 0,
|
||||
},
|
||||
"market_report": "",
|
||||
"sentiment_report": "",
|
||||
"news_report": "",
|
||||
|
|
@ -471,19 +552,19 @@ class TestAnalyzeTrending:
|
|||
"trader_investment_plan": "",
|
||||
"investment_plan": "",
|
||||
}
|
||||
|
||||
|
||||
mock_setup.return_value.setup_graph.return_value = mock_graph
|
||||
|
||||
|
||||
graph = TradingAgentsGraph()
|
||||
graph.graph = mock_graph
|
||||
|
||||
|
||||
final_state, decision = graph.analyze_trending(trending_stock)
|
||||
|
||||
|
||||
assert final_state["company_of_interest"] == "AAPL"
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.ChatOpenAI')
|
||||
@patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
|
||||
@patch('tradingagents.graph.trading_graph.GraphSetup')
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
|
||||
@patch("tradingagents.graph.trading_graph.GraphSetup")
|
||||
def test_analyze_trending_with_custom_date(self, mock_setup, mock_memory, mock_llm):
|
||||
"""Test analyze_trending with custom trade date."""
|
||||
mock_article = Mock(spec=NewsArticle)
|
||||
|
|
@ -498,17 +579,31 @@ class TestAnalyzeTrending:
|
|||
news_summary="New product launch",
|
||||
source_articles=[mock_article],
|
||||
)
|
||||
|
||||
|
||||
custom_date = date(2024, 3, 15)
|
||||
|
||||
|
||||
mock_graph = Mock()
|
||||
mock_graph.invoke.return_value = {
|
||||
"company_of_interest": "TSLA",
|
||||
"trade_date": str(custom_date),
|
||||
"final_trade_decision": "HOLD",
|
||||
"messages": [],
|
||||
"investment_debate_state": {"bull_history": "", "bear_history": "", "history": "", "current_response": "", "judge_decision": "", "count": 0},
|
||||
"risk_debate_state": {"risky_history": "", "safe_history": "", "neutral_history": "", "history": "", "judge_decision": "", "count": 0},
|
||||
"investment_debate_state": {
|
||||
"bull_history": "",
|
||||
"bear_history": "",
|
||||
"history": "",
|
||||
"current_response": "",
|
||||
"judge_decision": "",
|
||||
"count": 0,
|
||||
},
|
||||
"risk_debate_state": {
|
||||
"risky_history": "",
|
||||
"safe_history": "",
|
||||
"neutral_history": "",
|
||||
"history": "",
|
||||
"judge_decision": "",
|
||||
"count": 0,
|
||||
},
|
||||
"market_report": "",
|
||||
"sentiment_report": "",
|
||||
"news_report": "",
|
||||
|
|
@ -516,12 +611,14 @@ class TestAnalyzeTrending:
|
|||
"trader_investment_plan": "",
|
||||
"investment_plan": "",
|
||||
}
|
||||
|
||||
|
||||
mock_setup.return_value.setup_graph.return_value = mock_graph
|
||||
|
||||
|
||||
graph = TradingAgentsGraph()
|
||||
graph.graph = mock_graph
|
||||
|
||||
final_state, decision = graph.analyze_trending(trending_stock, trade_date=custom_date)
|
||||
|
||||
assert final_state["trade_date"] == str(custom_date)
|
||||
|
||||
final_state, decision = graph.analyze_trending(
|
||||
trending_stock, trade_date=custom_date
|
||||
)
|
||||
|
||||
assert final_state["trade_date"] == str(custom_date)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import pytest
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
|
||||
from tradingagents.agents.utils.agent_states import InvestDebateState, RiskDebateState
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
|
||||
|
||||
class TestConditionalLogicAnalysts:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from tradingagents.graph.setup import GraphSetup
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
from tradingagents.graph.setup import GraphSetup
|
||||
|
||||
|
||||
class TestGraphSetup:
|
||||
|
|
@ -40,19 +42,22 @@ class TestGraphSetup:
|
|||
def test_setup_graph_with_all_analysts(self):
|
||||
setup = self.create_graph_setup()
|
||||
|
||||
with patch("tradingagents.graph.setup.create_market_analyst") as mock_market, \
|
||||
patch("tradingagents.graph.setup.create_social_media_analyst") as mock_social, \
|
||||
patch("tradingagents.graph.setup.create_news_analyst") as mock_news, \
|
||||
patch("tradingagents.graph.setup.create_fundamentals_analyst") as mock_fund, \
|
||||
patch("tradingagents.graph.setup.create_bull_researcher") as mock_bull, \
|
||||
patch("tradingagents.graph.setup.create_bear_researcher") as mock_bear, \
|
||||
patch("tradingagents.graph.setup.create_research_manager") as mock_rm, \
|
||||
patch("tradingagents.graph.setup.create_trader") as mock_trader, \
|
||||
patch("tradingagents.graph.setup.create_risky_debator") as mock_risky, \
|
||||
patch("tradingagents.graph.setup.create_neutral_debator") as mock_neutral, \
|
||||
patch("tradingagents.graph.setup.create_safe_debator") as mock_safe, \
|
||||
patch("tradingagents.graph.setup.create_risk_manager") as mock_risk_mgr:
|
||||
|
||||
with (
|
||||
patch("tradingagents.graph.setup.create_market_analyst") as mock_market,
|
||||
patch(
|
||||
"tradingagents.graph.setup.create_social_media_analyst"
|
||||
) as mock_social,
|
||||
patch("tradingagents.graph.setup.create_news_analyst") as mock_news,
|
||||
patch("tradingagents.graph.setup.create_fundamentals_analyst") as mock_fund,
|
||||
patch("tradingagents.graph.setup.create_bull_researcher") as mock_bull,
|
||||
patch("tradingagents.graph.setup.create_bear_researcher") as mock_bear,
|
||||
patch("tradingagents.graph.setup.create_research_manager") as mock_rm,
|
||||
patch("tradingagents.graph.setup.create_trader") as mock_trader,
|
||||
patch("tradingagents.graph.setup.create_risky_debator") as mock_risky,
|
||||
patch("tradingagents.graph.setup.create_neutral_debator") as mock_neutral,
|
||||
patch("tradingagents.graph.setup.create_safe_debator") as mock_safe,
|
||||
patch("tradingagents.graph.setup.create_risk_manager") as mock_risk_mgr,
|
||||
):
|
||||
mock_market.return_value = MagicMock()
|
||||
mock_social.return_value = MagicMock()
|
||||
mock_news.return_value = MagicMock()
|
||||
|
|
@ -80,19 +85,22 @@ class TestGraphSetup:
|
|||
def test_setup_graph_with_single_analyst(self):
|
||||
setup = self.create_graph_setup()
|
||||
|
||||
with patch("tradingagents.graph.setup.create_market_analyst") as mock_market, \
|
||||
patch("tradingagents.graph.setup.create_social_media_analyst") as mock_social, \
|
||||
patch("tradingagents.graph.setup.create_news_analyst") as mock_news, \
|
||||
patch("tradingagents.graph.setup.create_fundamentals_analyst") as mock_fund, \
|
||||
patch("tradingagents.graph.setup.create_bull_researcher") as mock_bull, \
|
||||
patch("tradingagents.graph.setup.create_bear_researcher") as mock_bear, \
|
||||
patch("tradingagents.graph.setup.create_research_manager") as mock_rm, \
|
||||
patch("tradingagents.graph.setup.create_trader") as mock_trader, \
|
||||
patch("tradingagents.graph.setup.create_risky_debator") as mock_risky, \
|
||||
patch("tradingagents.graph.setup.create_neutral_debator") as mock_neutral, \
|
||||
patch("tradingagents.graph.setup.create_safe_debator") as mock_safe, \
|
||||
patch("tradingagents.graph.setup.create_risk_manager") as mock_risk_mgr:
|
||||
|
||||
with (
|
||||
patch("tradingagents.graph.setup.create_market_analyst") as mock_market,
|
||||
patch(
|
||||
"tradingagents.graph.setup.create_social_media_analyst"
|
||||
) as mock_social,
|
||||
patch("tradingagents.graph.setup.create_news_analyst") as mock_news,
|
||||
patch("tradingagents.graph.setup.create_fundamentals_analyst") as mock_fund,
|
||||
patch("tradingagents.graph.setup.create_bull_researcher") as mock_bull,
|
||||
patch("tradingagents.graph.setup.create_bear_researcher") as mock_bear,
|
||||
patch("tradingagents.graph.setup.create_research_manager") as mock_rm,
|
||||
patch("tradingagents.graph.setup.create_trader") as mock_trader,
|
||||
patch("tradingagents.graph.setup.create_risky_debator") as mock_risky,
|
||||
patch("tradingagents.graph.setup.create_neutral_debator") as mock_neutral,
|
||||
patch("tradingagents.graph.setup.create_safe_debator") as mock_safe,
|
||||
patch("tradingagents.graph.setup.create_risk_manager") as mock_risk_mgr,
|
||||
):
|
||||
mock_market.return_value = MagicMock()
|
||||
mock_bull.return_value = MagicMock()
|
||||
mock_bear.return_value = MagicMock()
|
||||
|
|
@ -119,16 +127,17 @@ class TestGraphSetup:
|
|||
def test_setup_graph_returns_compiled_graph(self):
|
||||
setup = self.create_graph_setup()
|
||||
|
||||
with patch("tradingagents.graph.setup.create_market_analyst") as mock_market, \
|
||||
patch("tradingagents.graph.setup.create_bull_researcher") as mock_bull, \
|
||||
patch("tradingagents.graph.setup.create_bear_researcher") as mock_bear, \
|
||||
patch("tradingagents.graph.setup.create_research_manager") as mock_rm, \
|
||||
patch("tradingagents.graph.setup.create_trader") as mock_trader, \
|
||||
patch("tradingagents.graph.setup.create_risky_debator") as mock_risky, \
|
||||
patch("tradingagents.graph.setup.create_neutral_debator") as mock_neutral, \
|
||||
patch("tradingagents.graph.setup.create_safe_debator") as mock_safe, \
|
||||
patch("tradingagents.graph.setup.create_risk_manager") as mock_risk_mgr:
|
||||
|
||||
with (
|
||||
patch("tradingagents.graph.setup.create_market_analyst") as mock_market,
|
||||
patch("tradingagents.graph.setup.create_bull_researcher") as mock_bull,
|
||||
patch("tradingagents.graph.setup.create_bear_researcher") as mock_bear,
|
||||
patch("tradingagents.graph.setup.create_research_manager") as mock_rm,
|
||||
patch("tradingagents.graph.setup.create_trader") as mock_trader,
|
||||
patch("tradingagents.graph.setup.create_risky_debator") as mock_risky,
|
||||
patch("tradingagents.graph.setup.create_neutral_debator") as mock_neutral,
|
||||
patch("tradingagents.graph.setup.create_safe_debator") as mock_safe,
|
||||
patch("tradingagents.graph.setup.create_risk_manager") as mock_risk_mgr,
|
||||
):
|
||||
mock_market.return_value = MagicMock()
|
||||
mock_bull.return_value = MagicMock()
|
||||
mock_bear.return_value = MagicMock()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import pytest
|
||||
from datetime import date
|
||||
|
||||
from tradingagents.graph.propagation import Propagator
|
||||
from tradingagents.agents.utils.agent_states import InvestDebateState, RiskDebateState
|
||||
|
||||
|
||||
class TestPropagator:
|
||||
|
|
|
|||
|
|
@ -1,16 +1,15 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
from datetime import date
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.graph.propagation import Propagator
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
from tradingagents.agents.utils.agent_states import InvestDebateState, RiskDebateState
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
from tradingagents.graph.propagation import Propagator
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
|
||||
class TestWorkflowStateTransitions:
|
||||
|
||||
def test_initial_state_structure(self):
|
||||
propagator = Propagator()
|
||||
state = propagator.create_initial_state("AAPL", "2024-01-15")
|
||||
|
|
@ -138,7 +137,6 @@ class TestWorkflowStateTransitions:
|
|||
|
||||
|
||||
class TestWorkflowEndToEnd:
|
||||
|
||||
def test_final_state_has_all_reports(self):
|
||||
final_state = {
|
||||
"company_of_interest": "AAPL",
|
||||
|
|
@ -216,7 +214,6 @@ class TestWorkflowEndToEnd:
|
|||
|
||||
|
||||
class TestTradingAgentsGraphValidation:
|
||||
|
||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.set_config")
|
||||
def test_graph_validates_ticker_on_propagate(self, mock_set_config, mock_llm):
|
||||
|
|
@ -233,6 +230,7 @@ class TestTradingAgentsGraphValidation:
|
|||
graph.log_states_dict = {}
|
||||
|
||||
from tradingagents.validation import validate_ticker
|
||||
|
||||
with pytest.raises(TickerValidationError):
|
||||
validate_ticker("INVALID123TICKER")
|
||||
|
||||
|
|
@ -246,7 +244,7 @@ class TestTradingAgentsGraphValidation:
|
|||
assert validate_ticker(" MSFT ") == "MSFT"
|
||||
|
||||
def test_invalid_ticker_formats(self):
|
||||
from tradingagents.validation import validate_ticker, TickerValidationError
|
||||
from tradingagents.validation import TickerValidationError, validate_ticker
|
||||
|
||||
with pytest.raises(TickerValidationError):
|
||||
validate_ticker("")
|
||||
|
|
|
|||
|
|
@ -5,14 +5,14 @@ import pytest
|
|||
|
||||
from tradingagents.models.backtest import (
|
||||
BacktestConfig,
|
||||
BacktestMetrics,
|
||||
BacktestResult,
|
||||
BacktestStatus,
|
||||
BacktestMetrics,
|
||||
EquityCurvePoint,
|
||||
TradeLog,
|
||||
)
|
||||
from tradingagents.models.portfolio import PortfolioConfig
|
||||
from tradingagents.models.trading import Trade, OrderSide
|
||||
from tradingagents.models.trading import OrderSide, Trade
|
||||
|
||||
|
||||
class TestBacktestConfig:
|
||||
|
|
|
|||
|
|
@ -1,15 +1,14 @@
|
|||
from datetime import datetime, date
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.models.market_data import (
|
||||
OHLCVBar,
|
||||
OHLCV,
|
||||
TechnicalIndicators,
|
||||
MarketSnapshot,
|
||||
HistoricalDataRequest,
|
||||
HistoricalDataResponse,
|
||||
MarketSnapshot,
|
||||
OHLCVBar,
|
||||
TechnicalIndicators,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,13 @@
|
|||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.models.portfolio import (
|
||||
CashTransaction,
|
||||
PortfolioConfig,
|
||||
PortfolioSnapshot,
|
||||
CashTransaction,
|
||||
TransactionType,
|
||||
)
|
||||
from tradingagents.models.trading import OrderSide, Fill, Position
|
||||
from tradingagents.models.trading import Fill, OrderSide, Position
|
||||
|
||||
|
||||
class TestPortfolioConfig:
|
||||
|
|
|
|||
|
|
@ -5,13 +5,13 @@ from uuid import uuid4
|
|||
import pytest
|
||||
|
||||
from tradingagents.models.trading import (
|
||||
OrderSide,
|
||||
OrderType,
|
||||
OrderStatus,
|
||||
PositionSide,
|
||||
Order,
|
||||
Fill,
|
||||
Order,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
OrderType,
|
||||
Position,
|
||||
PositionSide,
|
||||
Trade,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import pytest
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.config import (
|
||||
TradingAgentsSettings,
|
||||
DataVendorsConfig,
|
||||
TradingAgentsSettings,
|
||||
get_settings,
|
||||
reset_settings,
|
||||
update_settings,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
import os
|
||||
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
|
||||
|
|
@ -25,7 +25,12 @@ class TestDefaultConfig:
|
|||
def test_llm_provider_configured(self):
|
||||
"""Test that llm_provider is configured."""
|
||||
assert "llm_provider" in DEFAULT_CONFIG
|
||||
assert DEFAULT_CONFIG["llm_provider"] in ["openai", "anthropic", "google", "ollama"]
|
||||
assert DEFAULT_CONFIG["llm_provider"] in [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"google",
|
||||
"ollama",
|
||||
]
|
||||
|
||||
def test_llm_models_configured(self):
|
||||
"""Test that LLM models are configured."""
|
||||
|
|
@ -59,14 +64,14 @@ class TestDefaultConfig:
|
|||
"""Test that data vendors are configured."""
|
||||
assert "data_vendors" in DEFAULT_CONFIG
|
||||
assert isinstance(DEFAULT_CONFIG["data_vendors"], dict)
|
||||
|
||||
|
||||
required_categories = [
|
||||
"core_stock_apis",
|
||||
"technical_indicators",
|
||||
"fundamental_data",
|
||||
"news_data",
|
||||
]
|
||||
|
||||
|
||||
for category in required_categories:
|
||||
assert category in DEFAULT_CONFIG["data_vendors"]
|
||||
|
||||
|
|
@ -81,7 +86,10 @@ class TestDefaultConfig:
|
|||
assert "discovery_hard_timeout" in DEFAULT_CONFIG
|
||||
assert isinstance(DEFAULT_CONFIG["discovery_timeout"], int)
|
||||
assert isinstance(DEFAULT_CONFIG["discovery_hard_timeout"], int)
|
||||
assert DEFAULT_CONFIG["discovery_hard_timeout"] >= DEFAULT_CONFIG["discovery_timeout"]
|
||||
assert (
|
||||
DEFAULT_CONFIG["discovery_hard_timeout"]
|
||||
>= DEFAULT_CONFIG["discovery_timeout"]
|
||||
)
|
||||
|
||||
def test_discovery_config_cache_ttl(self):
|
||||
"""Test discovery cache TTL configuration."""
|
||||
|
|
@ -116,11 +124,11 @@ class TestDefaultConfig:
|
|||
def test_config_immutability_safety(self):
|
||||
"""Test that modifying a copy doesn't affect the original."""
|
||||
original_provider = DEFAULT_CONFIG["llm_provider"]
|
||||
|
||||
|
||||
# Create a copy and modify it
|
||||
config_copy = DEFAULT_CONFIG.copy()
|
||||
config_copy["llm_provider"] = "modified_provider"
|
||||
|
||||
|
||||
# Original should remain unchanged
|
||||
assert DEFAULT_CONFIG["llm_provider"] == original_provider
|
||||
|
||||
|
|
@ -132,7 +140,7 @@ class TestDefaultConfig:
|
|||
"fundamental_data",
|
||||
"news_data",
|
||||
]
|
||||
|
||||
|
||||
for category in DEFAULT_CONFIG["data_vendors"].keys():
|
||||
assert category in valid_categories
|
||||
|
||||
|
|
@ -153,7 +161,7 @@ class TestDefaultConfig:
|
|||
"discovery_max_results",
|
||||
"discovery_min_mentions",
|
||||
]
|
||||
|
||||
|
||||
for config_key in numeric_configs:
|
||||
value = DEFAULT_CONFIG[config_key]
|
||||
assert isinstance(value, int)
|
||||
|
|
@ -163,7 +171,7 @@ class TestDefaultConfig:
|
|||
"""Test that results_dir respects environment variable."""
|
||||
# The config uses os.getenv with a default
|
||||
results_dir = DEFAULT_CONFIG["results_dir"]
|
||||
|
||||
|
||||
# Should either be from env or default to ./results
|
||||
assert isinstance(results_dir, str)
|
||||
assert len(results_dir) > 0
|
||||
assert len(results_dir) > 0
|
||||
|
|
|
|||
|
|
@ -2,42 +2,21 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import tradingagents.logging as log_module
|
||||
|
||||
|
||||
class TestLoggingModule:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown(self):
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
tradingagents_logger = logging.getLogger("tradingagents")
|
||||
for handler in tradingagents_logger.handlers[:]:
|
||||
tradingagents_logger.removeHandler(handler)
|
||||
tradingagents_logger.setLevel(logging.NOTSET)
|
||||
|
||||
yield
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
tradingagents_logger = logging.getLogger("tradingagents")
|
||||
for handler in tradingagents_logger.handlers[:]:
|
||||
tradingagents_logger.removeHandler(handler)
|
||||
|
||||
def test_setup_logging_initializes_handlers_based_on_env_vars(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "DEBUG",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
root_logger = log_module.setup_logging()
|
||||
|
||||
assert root_logger is not None
|
||||
|
|
@ -47,40 +26,39 @@ class TestLoggingModule:
|
|||
has_file_handler = any(
|
||||
hasattr(h, "baseFilename") for h in root_logger.handlers
|
||||
)
|
||||
assert has_file_handler, "File handler should be present when LOG_FILE=true"
|
||||
assert (
|
||||
has_file_handler
|
||||
), "File handler should be present when LOG_FILE=true"
|
||||
|
||||
def test_get_logger_returns_properly_configured_logger_with_hierarchy(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "INFO",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
log_module.setup_logging()
|
||||
child_logger = log_module.get_logger("tradingagents.dataflows.interface")
|
||||
child_logger = log_module.get_logger(
|
||||
"tradingagents.dataflows.interface"
|
||||
)
|
||||
|
||||
assert child_logger.name == "tradingagents.dataflows.interface"
|
||||
assert child_logger.parent.name == "tradingagents.dataflows" or child_logger.parent.name == "tradingagents"
|
||||
assert (
|
||||
child_logger.parent.name == "tradingagents.dataflows"
|
||||
or child_logger.parent.name == "tradingagents"
|
||||
)
|
||||
|
||||
def test_json_file_handler_writes_valid_json_with_required_fields(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "DEBUG",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
logger = log_module.setup_logging()
|
||||
logger.info("Test message for JSON validation")
|
||||
|
||||
|
|
@ -88,20 +66,34 @@ class TestLoggingModule:
|
|||
handler.flush()
|
||||
|
||||
log_file_path = os.path.join(tmpdir, "tradingagents.log")
|
||||
assert os.path.exists(log_file_path), f"Log file should exist at {log_file_path}"
|
||||
assert os.path.exists(
|
||||
log_file_path
|
||||
), f"Log file should exist at {log_file_path}"
|
||||
|
||||
with open(log_file_path, "r") as f:
|
||||
with open(log_file_path) as f:
|
||||
log_content = f.read().strip()
|
||||
|
||||
assert log_content, "Log file should not be empty"
|
||||
|
||||
log_entry = json.loads(log_content.split("\n")[0])
|
||||
|
||||
required_fields = ["timestamp", "level", "logger", "message", "filename", "funcName", "lineno"]
|
||||
required_fields = [
|
||||
"timestamp",
|
||||
"level",
|
||||
"logger",
|
||||
"message",
|
||||
"filename",
|
||||
"funcName",
|
||||
"lineno",
|
||||
]
|
||||
for field in required_fields:
|
||||
assert field in log_entry, f"JSON log should contain '{field}' field"
|
||||
assert (
|
||||
field in log_entry
|
||||
), f"JSON log should contain '{field}' field"
|
||||
|
||||
assert "T" in log_entry["timestamp"], "Timestamp should be in ISO 8601 format"
|
||||
assert (
|
||||
"T" in log_entry["timestamp"]
|
||||
), "Timestamp should be in ISO 8601 format"
|
||||
assert log_entry["level"] == "INFO"
|
||||
assert log_entry["message"] == "Test message for JSON validation"
|
||||
|
||||
|
|
@ -110,14 +102,10 @@ class TestLoggingModule:
|
|||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "DEBUG",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
logger = log_module.setup_logging()
|
||||
|
||||
file_handler = None
|
||||
|
|
@ -126,44 +114,54 @@ class TestLoggingModule:
|
|||
file_handler = handler
|
||||
break
|
||||
|
||||
assert file_handler is not None, "RotatingFileHandler should be configured"
|
||||
assert file_handler.maxBytes == 10 * 1024 * 1024, "Max file size should be 10MB"
|
||||
assert (
|
||||
file_handler is not None
|
||||
), "RotatingFileHandler should be configured"
|
||||
assert (
|
||||
file_handler.maxBytes == 10 * 1024 * 1024
|
||||
), "Max file size should be 10MB"
|
||||
assert file_handler.backupCount == 5, "Backup count should be 5"
|
||||
|
||||
def test_console_handler_disabled_when_env_var_false(self):
|
||||
from tradingagents import config as main_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "INFO",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
main_config._settings = None
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
logger = log_module.setup_logging()
|
||||
|
||||
from rich.logging import RichHandler
|
||||
has_rich_handler = any(isinstance(h, RichHandler) for h in logger.handlers)
|
||||
assert not has_rich_handler, "RichHandler should NOT be present when LOG_CONSOLE=false"
|
||||
|
||||
has_rich_handler = any(
|
||||
isinstance(h, RichHandler) for h in logger.handlers
|
||||
)
|
||||
assert (
|
||||
not has_rich_handler
|
||||
), "RichHandler should NOT be present when LOG_CONSOLE=false"
|
||||
|
||||
def test_console_handler_enabled_when_env_var_true(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "INFO",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "true",
|
||||
"TRADINGAGENTS_LOG_FILE": "false",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "true",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "false",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
logger = log_module.setup_logging()
|
||||
|
||||
from rich.logging import RichHandler
|
||||
has_rich_handler = any(isinstance(h, RichHandler) for h in logger.handlers)
|
||||
assert has_rich_handler, "RichHandler should be present when LOG_CONSOLE=true"
|
||||
|
||||
has_rich_handler = any(
|
||||
isinstance(h, RichHandler) for h in logger.handlers
|
||||
)
|
||||
assert (
|
||||
has_rich_handler
|
||||
), "RichHandler should be present when LOG_CONSOLE=true"
|
||||
|
|
|
|||
|
|
@ -1,47 +1,32 @@
|
|||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import tradingagents.logging as log_module
|
||||
|
||||
|
||||
class TestLoggingConfigIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown(self):
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
tradingagents_logger = logging.getLogger("tradingagents")
|
||||
for handler in tradingagents_logger.handlers[:]:
|
||||
tradingagents_logger.removeHandler(handler)
|
||||
tradingagents_logger.setLevel(logging.NOTSET)
|
||||
|
||||
yield
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
tradingagents_logger = logging.getLogger("tradingagents")
|
||||
for handler in tradingagents_logger.handlers[:]:
|
||||
tradingagents_logger.removeHandler(handler)
|
||||
|
||||
def test_default_config_values_used_when_env_vars_not_set(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
env_vars_to_remove = [
|
||||
"TRADINGAGENTS_LOG_LEVEL",
|
||||
"TRADINGAGENTS_LOG_DIR",
|
||||
"TRADINGAGENTS_LOG_CONSOLE",
|
||||
"TRADINGAGENTS_LOG_FILE",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED",
|
||||
]
|
||||
clean_env = {k: v for k, v in os.environ.items() if k not in env_vars_to_remove}
|
||||
clean_env = {
|
||||
k: v for k, v in os.environ.items() if k not in env_vars_to_remove
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, clean_env, clear=True):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
expected_level = getattr(logging, DEFAULT_CONFIG.get("log_level", "INFO").upper())
|
||||
expected_level = getattr(
|
||||
logging, DEFAULT_CONFIG.get("log_level", "INFO").upper()
|
||||
)
|
||||
|
||||
logger = log_module.setup_logging()
|
||||
|
||||
|
|
@ -52,19 +37,19 @@ class TestLoggingConfigIntegration:
|
|||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "WARNING",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
logger = log_module.setup_logging()
|
||||
|
||||
assert logger.level == logging.WARNING
|
||||
|
||||
def test_boolean_parsing_for_log_console_and_file(self):
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from tradingagents import config as main_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_cases = [
|
||||
("true", True),
|
||||
|
|
@ -75,47 +60,49 @@ class TestLoggingConfigIntegration:
|
|||
("False", False),
|
||||
("TRUE", True),
|
||||
("FALSE", False),
|
||||
("yes", True),
|
||||
("no", False),
|
||||
]
|
||||
|
||||
for bool_str, expected in test_cases:
|
||||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "INFO",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": bool_str,
|
||||
"TRADINGAGENTS_LOG_FILE": "false",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": bool_str,
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "false",
|
||||
}
|
||||
|
||||
tradingagents_logger = logging.getLogger("tradingagents")
|
||||
for handler in tradingagents_logger.handlers[:]:
|
||||
tradingagents_logger.removeHandler(handler)
|
||||
log_module._logging_initialized = False
|
||||
main_config._settings = None
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
logger = log_module.setup_logging()
|
||||
|
||||
from rich.logging import RichHandler
|
||||
has_rich_handler = any(isinstance(h, RichHandler) for h in logger.handlers)
|
||||
has_rich_handler = any(
|
||||
isinstance(h, RichHandler) for h in logger.handlers
|
||||
)
|
||||
|
||||
assert has_rich_handler == expected, f"TRADINGAGENTS_LOG_CONSOLE={bool_str} should result in RichHandler present={expected}"
|
||||
assert (
|
||||
has_rich_handler == expected
|
||||
), f"TRADINGAGENTS_LOG_CONSOLE_ENABLED={bool_str} should result in RichHandler present={expected}"
|
||||
|
||||
def test_invalid_log_level_raises_validation_error(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
from tradingagents import config as main_config
|
||||
|
||||
def test_invalid_log_level_falls_back_to_info(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "INVALID_LEVEL",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
main_config._settings = None
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
log_module.setup_logging()
|
||||
|
||||
logger = log_module.setup_logging()
|
||||
|
||||
assert logger.level == logging.INFO, "Invalid log level should fall back to INFO"
|
||||
assert "log_level" in str(exc_info.value)
|
||||
|
|
|
|||
|
|
@ -1,45 +1,27 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import tradingagents.logging as log_module
|
||||
|
||||
|
||||
class TestLoggingIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown(self):
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
tradingagents_logger = logging.getLogger("tradingagents")
|
||||
for handler in tradingagents_logger.handlers[:]:
|
||||
tradingagents_logger.removeHandler(handler)
|
||||
tradingagents_logger.setLevel(logging.NOTSET)
|
||||
|
||||
yield
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
tradingagents_logger = logging.getLogger("tradingagents")
|
||||
for handler in tradingagents_logger.handlers[:]:
|
||||
tradingagents_logger.removeHandler(handler)
|
||||
|
||||
def test_logging_initialization_from_module_import(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "INFO",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
log_module.setup_logging()
|
||||
|
||||
interface_logger = log_module.get_logger("tradingagents.dataflows.interface")
|
||||
interface_logger = log_module.get_logger(
|
||||
"tradingagents.dataflows.interface"
|
||||
)
|
||||
|
||||
assert interface_logger is not None
|
||||
assert interface_logger.name == "tradingagents.dataflows.interface"
|
||||
|
|
@ -49,7 +31,7 @@ class TestLoggingIntegration:
|
|||
log_file = os.path.join(tmpdir, "tradingagents.log")
|
||||
assert os.path.exists(log_file)
|
||||
|
||||
with open(log_file, "r") as f:
|
||||
with open(log_file) as f:
|
||||
content = f.read()
|
||||
assert "Test message from interface logger" in content
|
||||
|
||||
|
|
@ -58,18 +40,17 @@ class TestLoggingIntegration:
|
|||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "INFO",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "true",
|
||||
"TRADINGAGENTS_LOG_FILE": "false",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "true",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "false",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
logger = log_module.setup_logging()
|
||||
|
||||
from rich.logging import RichHandler
|
||||
rich_handlers = [h for h in logger.handlers if isinstance(h, RichHandler)]
|
||||
|
||||
rich_handlers = [
|
||||
h for h in logger.handlers if isinstance(h, RichHandler)
|
||||
]
|
||||
assert len(rich_handlers) == 1
|
||||
|
||||
rich_handler = rich_handlers[0]
|
||||
|
|
@ -81,15 +62,10 @@ class TestLoggingIntegration:
|
|||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "DEBUG",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import json
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
logger = log_module.setup_logging()
|
||||
|
||||
logger.debug("Debug message")
|
||||
|
|
@ -103,7 +79,7 @@ class TestLoggingIntegration:
|
|||
log_file = os.path.join(tmpdir, "tradingagents.log")
|
||||
assert os.path.exists(log_file)
|
||||
|
||||
with open(log_file, "r") as f:
|
||||
with open(log_file) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
assert len(lines) >= 4
|
||||
|
|
@ -120,18 +96,18 @@ class TestLoggingIntegration:
|
|||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "WARNING",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
root_logger = log_module.setup_logging()
|
||||
|
||||
child_logger = log_module.get_logger("tradingagents.dataflows.interface")
|
||||
grandchild_logger = log_module.get_logger("tradingagents.dataflows.interface.submodule")
|
||||
child_logger = log_module.get_logger(
|
||||
"tradingagents.dataflows.interface"
|
||||
)
|
||||
grandchild_logger = log_module.get_logger(
|
||||
"tradingagents.dataflows.interface.submodule"
|
||||
)
|
||||
|
||||
assert root_logger.level == logging.WARNING
|
||||
|
||||
|
|
@ -142,7 +118,7 @@ class TestLoggingIntegration:
|
|||
handler.flush()
|
||||
|
||||
log_file = os.path.join(tmpdir, "tradingagents.log")
|
||||
with open(log_file, "r") as f:
|
||||
with open(log_file) as f:
|
||||
content = f.read()
|
||||
|
||||
assert "This should not be logged" not in content
|
||||
|
|
@ -153,14 +129,10 @@ class TestLoggingIntegration:
|
|||
env_vars = {
|
||||
"TRADINGAGENTS_LOG_LEVEL": "INFO",
|
||||
"TRADINGAGENTS_LOG_DIR": tmpdir,
|
||||
"TRADINGAGENTS_LOG_CONSOLE": "false",
|
||||
"TRADINGAGENTS_LOG_FILE": "true",
|
||||
"TRADINGAGENTS_LOG_CONSOLE_ENABLED": "false",
|
||||
"TRADINGAGENTS_LOG_FILE_ENABLED": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
import importlib
|
||||
import tradingagents.logging as log_module
|
||||
importlib.reload(log_module)
|
||||
|
||||
log_module._logging_initialized = False
|
||||
|
||||
logger = log_module.get_logger("tradingagents.test")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import ast
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
class TestLoggingMigration:
|
||||
|
|
@ -12,7 +11,7 @@ class TestLoggingMigration:
|
|||
"dataflows",
|
||||
"interface.py",
|
||||
)
|
||||
with open(file_path, "r") as f:
|
||||
with open(file_path) as f:
|
||||
content = f.read()
|
||||
|
||||
tree = ast.parse(content)
|
||||
|
|
@ -33,7 +32,7 @@ class TestLoggingMigration:
|
|||
"dataflows",
|
||||
"brave.py",
|
||||
)
|
||||
with open(file_path, "r") as f:
|
||||
with open(file_path) as f:
|
||||
content = f.read()
|
||||
|
||||
tree = ast.parse(content)
|
||||
|
|
@ -54,7 +53,7 @@ class TestLoggingMigration:
|
|||
"dataflows",
|
||||
"tavily.py",
|
||||
)
|
||||
with open(file_path, "r") as f:
|
||||
with open(file_path) as f:
|
||||
content = f.read()
|
||||
|
||||
tree = ast.parse(content)
|
||||
|
|
@ -93,7 +92,7 @@ class TestLoggingMigration:
|
|||
if not os.path.exists(file_path):
|
||||
continue
|
||||
|
||||
with open(file_path, "r") as f:
|
||||
with open(file_path) as f:
|
||||
content = f.read()
|
||||
|
||||
tree = ast.parse(content)
|
||||
|
|
@ -107,7 +106,9 @@ class TestLoggingMigration:
|
|||
if print_calls:
|
||||
all_print_calls[filename] = print_calls
|
||||
|
||||
assert len(all_print_calls) == 0, f"Found print statements in: {all_print_calls}"
|
||||
assert (
|
||||
len(all_print_calls) == 0
|
||||
), f"Found print statements in: {all_print_calls}"
|
||||
|
||||
def test_logger_import_exists_in_interface_py(self):
|
||||
file_path = os.path.join(
|
||||
|
|
@ -117,8 +118,10 @@ class TestLoggingMigration:
|
|||
"dataflows",
|
||||
"interface.py",
|
||||
)
|
||||
with open(file_path, "r") as f:
|
||||
with open(file_path) as f:
|
||||
content = f.read()
|
||||
|
||||
assert "import logging" in content, "interface.py should import logging"
|
||||
assert "logger = logging.getLogger(__name__)" in content, "interface.py should define logger"
|
||||
assert (
|
||||
"logger = logging.getLogger(__name__)" in content
|
||||
), "interface.py should define logger"
|
||||
|
|
|
|||
|
|
@ -1,21 +1,21 @@
|
|||
import pytest
|
||||
from datetime import date, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.validation import (
|
||||
ValidationError,
|
||||
TickerValidationError,
|
||||
DateValidationError,
|
||||
validate_ticker,
|
||||
validate_tickers,
|
||||
TickerValidationError,
|
||||
format_date,
|
||||
get_next_trading_day,
|
||||
get_previous_trading_day,
|
||||
is_trading_day,
|
||||
is_valid_date,
|
||||
is_valid_ticker,
|
||||
parse_date,
|
||||
validate_date,
|
||||
validate_date_range,
|
||||
format_date,
|
||||
is_valid_ticker,
|
||||
is_valid_date,
|
||||
is_trading_day,
|
||||
get_previous_trading_day,
|
||||
get_next_trading_day,
|
||||
validate_ticker,
|
||||
validate_tickers,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,23 +1,18 @@
|
|||
from .utils.agent_utils import create_msg_delete
|
||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||
from .utils.memory import FinancialSituationMemory
|
||||
|
||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||
from .analysts.market_analyst import create_market_analyst
|
||||
from .analysts.news_analyst import create_news_analyst
|
||||
from .analysts.social_media_analyst import create_social_media_analyst
|
||||
|
||||
from .managers.research_manager import create_research_manager
|
||||
from .managers.risk_manager import create_risk_manager
|
||||
from .researchers.bear_researcher import create_bear_researcher
|
||||
from .researchers.bull_researcher import create_bull_researcher
|
||||
|
||||
from .risk_mgmt.aggressive_debator import create_risky_debator
|
||||
from .risk_mgmt.conservative_debator import create_safe_debator
|
||||
from .risk_mgmt.neutral_debator import create_neutral_debator
|
||||
|
||||
from .managers.research_manager import create_research_manager
|
||||
from .managers.risk_manager import create_risk_manager
|
||||
|
||||
from .trader.trader import create_trader
|
||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||
from .utils.agent_utils import create_msg_delete
|
||||
from .utils.memory import FinancialSituationMemory
|
||||
|
||||
__all__ = [
|
||||
"FinancialSituationMemory",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,11 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
get_balance_sheet,
|
||||
get_cashflow,
|
||||
get_fundamentals,
|
||||
get_income_statement,
|
||||
)
|
||||
|
||||
|
||||
def create_fundamentals_analyst(llm):
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicators
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import get_indicators, get_stock_data
|
||||
|
||||
|
||||
def create_market_analyst(llm):
|
||||
|
||||
def market_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
|
|
@ -73,7 +73,7 @@ Volume-Based Indicators:
|
|||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"market_report": report,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import get_news, get_global_news
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import get_global_news, get_news
|
||||
|
||||
|
||||
def create_news_analyst(llm):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import get_news
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,32 +1,32 @@
|
|||
from .models import (
|
||||
NewsArticle,
|
||||
TrendingStock,
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
Sector,
|
||||
EventCategory,
|
||||
from .entity_extractor import (
|
||||
BATCH_SIZE,
|
||||
EntityMention,
|
||||
extract_entities,
|
||||
)
|
||||
from .exceptions import (
|
||||
DiscoveryError,
|
||||
NewsUnavailableError,
|
||||
DiscoveryTimeoutError,
|
||||
NewsUnavailableError,
|
||||
TickerResolutionError,
|
||||
)
|
||||
from .entity_extractor import (
|
||||
EntityMention,
|
||||
extract_entities,
|
||||
BATCH_SIZE,
|
||||
from .models import (
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
EventCategory,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
TrendingStock,
|
||||
)
|
||||
from .persistence import (
|
||||
generate_markdown_summary,
|
||||
save_discovery_result,
|
||||
)
|
||||
from .scorer import (
|
||||
calculate_trending_scores,
|
||||
DEFAULT_DECAY_RATE,
|
||||
DEFAULT_MAX_RESULTS,
|
||||
DEFAULT_MIN_MENTIONS,
|
||||
)
|
||||
from .persistence import (
|
||||
save_discovery_result,
|
||||
generate_markdown_summary,
|
||||
calculate_trending_scores,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field as PydanticField
|
||||
|
||||
from tradingagents.agents.discovery.models import EventCategory, NewsArticle
|
||||
from tradingagents.dataflows.config import get_config
|
||||
from tradingagents.agents.discovery.models import NewsArticle, EventCategory
|
||||
|
||||
|
||||
BATCH_SIZE = 10
|
||||
|
||||
|
|
@ -24,19 +24,34 @@ class EntityMention:
|
|||
|
||||
|
||||
class ExtractedEntity(BaseModel):
|
||||
company_name: str = PydanticField(description="The name of the publicly traded company mentioned")
|
||||
confidence: float = PydanticField(description="Confidence score from 0.0 to 1.0 based on mention clarity")
|
||||
context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention")
|
||||
event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other")
|
||||
sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)")
|
||||
article_id: str = PydanticField(description="The article ID where this company was mentioned (e.g., article_0, article_1)")
|
||||
company_name: str = PydanticField(
|
||||
description="The name of the publicly traded company mentioned"
|
||||
)
|
||||
confidence: float = PydanticField(
|
||||
description="Confidence score from 0.0 to 1.0 based on mention clarity"
|
||||
)
|
||||
context_snippet: str = PydanticField(
|
||||
description="Surrounding context of 50-100 characters around the company mention"
|
||||
)
|
||||
event_type: str = PydanticField(
|
||||
description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other"
|
||||
)
|
||||
sentiment: float = PydanticField(
|
||||
default=0.0,
|
||||
description="Sentiment score from -1.0 (negative) to 1.0 (positive)",
|
||||
)
|
||||
article_id: str = PydanticField(
|
||||
description="The article ID where this company was mentioned (e.g., article_0, article_1)"
|
||||
)
|
||||
|
||||
|
||||
class ExtractionResponse(BaseModel):
|
||||
entities: List[ExtractedEntity] = PydanticField(default_factory=list, description="List of extracted company entities")
|
||||
entities: list[ExtractedEntity] = PydanticField(
|
||||
default_factory=list, description="List of extracted company entities"
|
||||
)
|
||||
|
||||
|
||||
def _get_llm(config: Optional[dict] = None):
|
||||
def _get_llm(config: dict | None = None):
|
||||
cfg = config or get_config()
|
||||
provider = cfg.get("llm_provider", "openai").lower()
|
||||
model = cfg.get("quick_think_llm", "gpt-4o-mini")
|
||||
|
|
@ -88,7 +103,7 @@ Articles to analyze:
|
|||
Extract all company mentions from the articles above."""
|
||||
|
||||
|
||||
def _format_articles_for_prompt(articles: List[NewsArticle], start_idx: int) -> str:
|
||||
def _format_articles_for_prompt(articles: list[NewsArticle], start_idx: int) -> str:
|
||||
formatted = []
|
||||
for i, article in enumerate(articles):
|
||||
article_id = f"article_{start_idx + i}"
|
||||
|
|
@ -102,10 +117,10 @@ def _format_articles_for_prompt(articles: List[NewsArticle], start_idx: int) ->
|
|||
|
||||
|
||||
def _extract_batch(
|
||||
articles: List[NewsArticle],
|
||||
articles: list[NewsArticle],
|
||||
start_idx: int,
|
||||
llm,
|
||||
) -> List[EntityMention]:
|
||||
) -> list[EntityMention]:
|
||||
if not articles:
|
||||
return []
|
||||
|
||||
|
|
@ -144,14 +159,14 @@ def _extract_batch(
|
|||
|
||||
|
||||
def extract_entities(
|
||||
articles: List[NewsArticle],
|
||||
config: Optional[dict] = None,
|
||||
) -> List[EntityMention]:
|
||||
articles: list[NewsArticle],
|
||||
config: dict | None = None,
|
||||
) -> list[EntityMention]:
|
||||
if not articles:
|
||||
return []
|
||||
|
||||
llm = _get_llm(config)
|
||||
all_mentions: List[EntityMention] = []
|
||||
all_mentions: list[EntityMention] = []
|
||||
|
||||
for batch_start in range(0, len(articles), BATCH_SIZE):
|
||||
batch_end = min(batch_start + BATCH_SIZE, len(articles))
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class DiscoveryStatus(Enum):
|
||||
|
|
@ -37,9 +37,9 @@ class NewsArticle:
|
|||
url: str
|
||||
published_at: datetime
|
||||
content_snippet: str
|
||||
ticker_mentions: List[str]
|
||||
ticker_mentions: list[str]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"title": self.title,
|
||||
"source": self.source,
|
||||
|
|
@ -50,7 +50,7 @@ class NewsArticle:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "NewsArticle":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "NewsArticle":
|
||||
return cls(
|
||||
title=data["title"],
|
||||
source=data["source"],
|
||||
|
|
@ -71,9 +71,9 @@ class TrendingStock:
|
|||
sector: Sector
|
||||
event_type: EventCategory
|
||||
news_summary: str
|
||||
source_articles: List[NewsArticle]
|
||||
source_articles: list[NewsArticle]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"ticker": self.ticker,
|
||||
"company_name": self.company_name,
|
||||
|
|
@ -87,7 +87,7 @@ class TrendingStock:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "TrendingStock":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "TrendingStock":
|
||||
return cls(
|
||||
ticker=data["ticker"],
|
||||
company_name=data["company_name"],
|
||||
|
|
@ -106,12 +106,12 @@ class TrendingStock:
|
|||
@dataclass
|
||||
class DiscoveryRequest:
|
||||
lookback_period: str
|
||||
sector_filter: Optional[List[Sector]] = None
|
||||
event_filter: Optional[List[EventCategory]] = None
|
||||
sector_filter: list[Sector] | None = None
|
||||
event_filter: list[EventCategory] | None = None
|
||||
max_results: int = 20
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"lookback_period": self.lookback_period,
|
||||
"sector_filter": (
|
||||
|
|
@ -125,7 +125,7 @@ class DiscoveryRequest:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "DiscoveryRequest":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "DiscoveryRequest":
|
||||
return cls(
|
||||
lookback_period=data["lookback_period"],
|
||||
sector_filter=(
|
||||
|
|
@ -146,24 +146,26 @@ class DiscoveryRequest:
|
|||
@dataclass
|
||||
class DiscoveryResult:
|
||||
request: DiscoveryRequest
|
||||
trending_stocks: List[TrendingStock]
|
||||
trending_stocks: list[TrendingStock]
|
||||
status: DiscoveryStatus
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
error_message: Optional[str] = None
|
||||
completed_at: datetime | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"request": self.request.to_dict(),
|
||||
"trending_stocks": [stock.to_dict() for stock in self.trending_stocks],
|
||||
"status": self.status.value,
|
||||
"started_at": self.started_at.isoformat(),
|
||||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
"completed_at": self.completed_at.isoformat()
|
||||
if self.completed_at
|
||||
else None,
|
||||
"error_message": self.error_message,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "DiscoveryResult":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "DiscoveryResult":
|
||||
return cls(
|
||||
request=DiscoveryRequest.from_dict(data["request"]),
|
||||
trending_stocks=[
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from .models import DiscoveryResult, TrendingStock
|
|||
|
||||
def save_discovery_result(
|
||||
result: DiscoveryResult,
|
||||
base_path: Optional[Path] = None,
|
||||
base_path: Path | None = None,
|
||||
) -> Path:
|
||||
if base_path is None:
|
||||
base_path = Path("results")
|
||||
|
|
|
|||
|
|
@ -1,25 +1,24 @@
|
|||
import math
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import List, Dict
|
||||
from typing import Dict, List
|
||||
|
||||
from tradingagents.agents.discovery.entity_extractor import EntityMention
|
||||
from tradingagents.agents.discovery.models import (
|
||||
TrendingStock,
|
||||
EventCategory,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
EventCategory,
|
||||
TrendingStock,
|
||||
)
|
||||
from tradingagents.agents.discovery.entity_extractor import EntityMention
|
||||
from tradingagents.dataflows.trending.stock_resolver import resolve_ticker
|
||||
from tradingagents.dataflows.trending.sector_classifier import classify_sector
|
||||
|
||||
from tradingagents.dataflows.trending.stock_resolver import resolve_ticker
|
||||
|
||||
DEFAULT_DECAY_RATE = 0.1
|
||||
DEFAULT_MAX_RESULTS = 20
|
||||
DEFAULT_MIN_MENTIONS = 2
|
||||
|
||||
|
||||
def _aggregate_sentiment(mentions: List[EntityMention]) -> float:
|
||||
def _aggregate_sentiment(mentions: list[EntityMention]) -> float:
|
||||
if not mentions:
|
||||
return 0.0
|
||||
|
||||
|
|
@ -37,7 +36,7 @@ def _aggregate_sentiment(mentions: List[EntityMention]) -> float:
|
|||
|
||||
|
||||
def _calculate_recency_weight(
|
||||
articles: List[NewsArticle],
|
||||
articles: list[NewsArticle],
|
||||
article_ids: set,
|
||||
decay_rate: float,
|
||||
) -> float:
|
||||
|
|
@ -60,18 +59,18 @@ def _calculate_recency_weight(
|
|||
return sum(weights) / len(weights)
|
||||
|
||||
|
||||
def _get_most_common_event_type(mentions: List[EntityMention]) -> EventCategory:
|
||||
def _get_most_common_event_type(mentions: list[EntityMention]) -> EventCategory:
|
||||
if not mentions:
|
||||
return EventCategory.OTHER
|
||||
|
||||
event_counts: Dict[EventCategory, int] = defaultdict(int)
|
||||
event_counts: dict[EventCategory, int] = defaultdict(int)
|
||||
for mention in mentions:
|
||||
event_counts[mention.event_type] += 1
|
||||
|
||||
return max(event_counts.keys(), key=lambda e: event_counts[e])
|
||||
|
||||
|
||||
def _build_news_summary(mentions: List[EntityMention]) -> str:
|
||||
def _build_news_summary(mentions: list[EntityMention]) -> str:
|
||||
if not mentions:
|
||||
return ""
|
||||
|
||||
|
|
@ -80,17 +79,17 @@ def _build_news_summary(mentions: List[EntityMention]) -> str:
|
|||
|
||||
|
||||
def calculate_trending_scores(
|
||||
mentions: List[EntityMention],
|
||||
articles: List[NewsArticle],
|
||||
mentions: list[EntityMention],
|
||||
articles: list[NewsArticle],
|
||||
decay_rate: float = DEFAULT_DECAY_RATE,
|
||||
max_results: int = DEFAULT_MAX_RESULTS,
|
||||
min_mentions: int = DEFAULT_MIN_MENTIONS,
|
||||
) -> List[TrendingStock]:
|
||||
) -> list[TrendingStock]:
|
||||
if not mentions:
|
||||
return []
|
||||
|
||||
ticker_mentions: Dict[str, List[EntityMention]] = defaultdict(list)
|
||||
ticker_company_names: Dict[str, str] = {}
|
||||
ticker_mentions: dict[str, list[EntityMention]] = defaultdict(list)
|
||||
ticker_company_names: dict[str, str] = {}
|
||||
|
||||
for mention in mentions:
|
||||
ticker = resolve_ticker(mention.company_name)
|
||||
|
|
@ -99,11 +98,11 @@ def calculate_trending_scores(
|
|||
if ticker not in ticker_company_names:
|
||||
ticker_company_names[ticker] = mention.company_name
|
||||
|
||||
article_index: Dict[str, int] = {}
|
||||
article_index: dict[str, int] = {}
|
||||
for i, article in enumerate(articles):
|
||||
article_index[f"article_{i}"] = i
|
||||
|
||||
trending_stocks: List[TrendingStock] = []
|
||||
trending_stocks: list[TrendingStock] = []
|
||||
|
||||
for ticker, ticker_mention_list in ticker_mentions.items():
|
||||
article_ids = {m.article_id for m in ticker_mention_list}
|
||||
|
|
@ -127,7 +126,7 @@ def calculate_trending_scores(
|
|||
|
||||
event_type = _get_most_common_event_type(ticker_mention_list)
|
||||
|
||||
source_article_list: List[NewsArticle] = []
|
||||
source_article_list: list[NewsArticle] = []
|
||||
for article_id in article_ids:
|
||||
idx = article_index.get(article_id)
|
||||
if idx is not None and idx < len(articles):
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ Additionally, develop a detailed investment plan for the trader. This should inc
|
|||
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
||||
Rationale: An explanation of why these arguments lead to your conclusion.
|
||||
Strategic Actions: Concrete steps for implementing the recommendation.
|
||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
||||
|
||||
Here are your past reflections on mistakes:
|
||||
\"{past_memory_str}\"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
def create_risk_manager(llm, memory):
|
||||
def risk_manager_node(state) -> dict:
|
||||
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
history = state["risk_debate_state"]["history"]
|
||||
|
|
@ -32,7 +31,7 @@ Deliverables:
|
|||
|
||||
---
|
||||
|
||||
**Analysts Debate History:**
|
||||
**Analysts Debate History:**
|
||||
{history}
|
||||
|
||||
---
|
||||
|
|
|
|||
|
|
@ -1,16 +1,14 @@
|
|||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langgraph.graph import MessagesState
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class InvestDebateState(TypedDict):
|
||||
"""Researcher team state"""
|
||||
bull_history: Annotated[
|
||||
str, "Bullish Conversation history"
|
||||
]
|
||||
bear_history: Annotated[
|
||||
str, "Bearish Conversation history"
|
||||
]
|
||||
|
||||
bull_history: Annotated[str, "Bullish Conversation history"]
|
||||
bear_history: Annotated[str, "Bearish Conversation history"]
|
||||
history: Annotated[str, "Conversation history"]
|
||||
current_response: Annotated[str, "Latest response"]
|
||||
judge_decision: Annotated[str, "Final judge decision"]
|
||||
|
|
@ -19,26 +17,15 @@ class InvestDebateState(TypedDict):
|
|||
|
||||
class RiskDebateState(TypedDict):
|
||||
"""Risk management team state"""
|
||||
risky_history: Annotated[
|
||||
str, "Risky Agent's Conversation history"
|
||||
]
|
||||
safe_history: Annotated[
|
||||
str, "Safe Agent's Conversation history"
|
||||
]
|
||||
neutral_history: Annotated[
|
||||
str, "Neutral Agent's Conversation history"
|
||||
]
|
||||
history: Annotated[str, "Conversation history"]
|
||||
|
||||
risky_history: Annotated[str, "Risky Agent's Conversation history"]
|
||||
safe_history: Annotated[str, "Safe Agent's Conversation history"]
|
||||
neutral_history: Annotated[str, "Neutral Agent's Conversation history"]
|
||||
history: Annotated[str, "Conversation history"]
|
||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
||||
current_risky_response: Annotated[
|
||||
str, "Latest response by the risky analyst"
|
||||
]
|
||||
current_safe_response: Annotated[
|
||||
str, "Latest response by the safe analyst"
|
||||
]
|
||||
current_neutral_response: Annotated[
|
||||
str, "Latest response by the neutral analyst"
|
||||
]
|
||||
current_risky_response: Annotated[str, "Latest response by the risky analyst"]
|
||||
current_safe_response: Annotated[str, "Latest response by the safe analyst"]
|
||||
current_neutral_response: Annotated[str, "Latest response by the neutral analyst"]
|
||||
judge_decision: Annotated[str, "Judge's decision"]
|
||||
count: Annotated[int, "Length of the current conversation"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,23 +1,20 @@
|
|||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
|
||||
from tradingagents.agents.utils.core_stock_tools import (
|
||||
get_stock_data
|
||||
)
|
||||
from tradingagents.agents.utils.technical_indicators_tools import (
|
||||
get_indicators
|
||||
)
|
||||
from tradingagents.agents.utils.core_stock_tools import get_stock_data
|
||||
from tradingagents.agents.utils.fundamental_data_tools import (
|
||||
get_fundamentals,
|
||||
get_balance_sheet,
|
||||
get_cashflow,
|
||||
get_income_statement
|
||||
get_fundamentals,
|
||||
get_income_statement,
|
||||
)
|
||||
from tradingagents.agents.utils.news_data_tools import (
|
||||
get_news,
|
||||
get_global_news,
|
||||
get_insider_sentiment,
|
||||
get_insider_transactions,
|
||||
get_global_news
|
||||
get_news,
|
||||
)
|
||||
from tradingagents.agents.utils.technical_indicators_tools import get_indicators
|
||||
|
||||
|
||||
def create_msg_delete():
|
||||
def delete_messages(state):
|
||||
|
|
@ -26,4 +23,5 @@ def create_msg_delete():
|
|||
removal_operations = [RemoveMessage(id=m.id) for m in messages]
|
||||
placeholder = HumanMessage(content="Continue")
|
||||
return {"messages": removal_operations + [placeholder]}
|
||||
|
||||
return delete_messages
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
|
||||
|
||||
|
|
@ -74,4 +76,4 @@ def get_income_statement(
|
|||
Returns:
|
||||
str: A formatted report containing income statement data
|
||||
"""
|
||||
return route_to_vendor("get_income_statement", ticker, freq, curr_date)
|
||||
return route_to_vendor("get_income_statement", ticker, freq, curr_date)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
|
|
@ -14,17 +15,15 @@ class FinancialSituationMemory:
|
|||
self.embedding = "text-embedding-3-small"
|
||||
self.client = OpenAI(base_url=config["backend_url"])
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
self.situation_collection = self.chroma_client.get_or_create_collection(
|
||||
name=name
|
||||
)
|
||||
|
||||
def get_embedding(self, text):
|
||||
|
||||
response = self.client.embeddings.create(
|
||||
model=self.embedding, input=text
|
||||
)
|
||||
response = self.client.embeddings.create(model=self.embedding, input=text)
|
||||
return response.data[0].embedding
|
||||
|
||||
def add_situations(self, situations_and_advice):
|
||||
|
||||
situations = []
|
||||
advice = []
|
||||
ids = []
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
|
||||
|
||||
@tool
|
||||
def get_news(
|
||||
ticker: Annotated[str, "Ticker symbol"],
|
||||
|
|
@ -20,6 +23,7 @@ def get_news(
|
|||
"""
|
||||
return route_to_vendor("get_news", ticker, start_date, end_date)
|
||||
|
||||
|
||||
@tool
|
||||
def get_global_news(
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
|
|
@ -38,6 +42,7 @@ def get_global_news(
|
|||
"""
|
||||
return route_to_vendor("get_global_news", curr_date, look_back_days, limit)
|
||||
|
||||
|
||||
@tool
|
||||
def get_insider_sentiment(
|
||||
ticker: Annotated[str, "ticker symbol for the company"],
|
||||
|
|
@ -54,6 +59,7 @@ def get_insider_sentiment(
|
|||
"""
|
||||
return route_to_vendor("get_insider_sentiment", ticker, curr_date)
|
||||
|
||||
|
||||
@tool
|
||||
def get_insider_transactions(
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
|
||||
|
||||
@tool
|
||||
def get_indicators(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||
curr_date: Annotated[str, "The current trading date you are trading on, YYYY-mm-dd"],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -20,4 +25,6 @@ def get_indicators(
|
|||
Returns:
|
||||
str: A formatted dataframe containing the technical indicators for the specified ticker symbol and indicator.
|
||||
"""
|
||||
return route_to_vendor("get_indicators", symbol, indicator, curr_date, look_back_days)
|
||||
return route_to_vendor(
|
||||
"get_indicators", symbol, indicator, curr_date, look_back_days
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from .agent_integration import AgentBacktestEngine, run_agent_backtest
|
||||
from .data_loader import DataLoader
|
||||
from .engine import BacktestEngine, SimpleBacktestEngine
|
||||
from .metrics import MetricsCalculator
|
||||
from .agent_integration import AgentBacktestEngine, run_agent_backtest
|
||||
|
||||
__all__ = [
|
||||
"DataLoader",
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
import logging
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.models.backtest import BacktestConfig, BacktestResult
|
||||
from tradingagents.models.decisions import (
|
||||
SignalType,
|
||||
TradingDecision,
|
||||
AnalystReport,
|
||||
AnalystType,
|
||||
SignalType,
|
||||
TradingDecision,
|
||||
)
|
||||
|
||||
from .engine import BacktestEngine
|
||||
|
|
@ -21,12 +21,12 @@ class AgentBacktestEngine(BacktestEngine):
|
|||
def __init__(
|
||||
self,
|
||||
config: BacktestConfig,
|
||||
agent_config: Optional[Dict[str, Any]] = None,
|
||||
agent_config: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.agent_config = agent_config or config.agent_config
|
||||
self.trading_graph: Optional[TradingAgentsGraph] = None
|
||||
self._decision_cache: Dict[str, TradingDecision] = {}
|
||||
self.trading_graph: TradingAgentsGraph | None = None
|
||||
self._decision_cache: dict[str, TradingDecision] = {}
|
||||
|
||||
def _initialize(self):
|
||||
super()._initialize()
|
||||
|
|
@ -49,7 +49,7 @@ class AgentBacktestEngine(BacktestEngine):
|
|||
ticker: str,
|
||||
trading_date: date,
|
||||
day_index: int,
|
||||
) -> Optional[TradingDecision]:
|
||||
) -> TradingDecision | None:
|
||||
cache_key = f"{ticker}_{trading_date}"
|
||||
if cache_key in self._decision_cache:
|
||||
return self._decision_cache[cache_key]
|
||||
|
|
@ -68,8 +68,7 @@ class AgentBacktestEngine(BacktestEngine):
|
|||
|
||||
except (ValueError, KeyError, RuntimeError, ConnectionError, TimeoutError) as e:
|
||||
logger.error(
|
||||
"Agent decision failed for %s on %s: %s",
|
||||
ticker, trading_date, e
|
||||
"Agent decision failed for %s on %s: %s", ticker, trading_date, e
|
||||
)
|
||||
return None
|
||||
|
||||
|
|
@ -77,8 +76,8 @@ class AgentBacktestEngine(BacktestEngine):
|
|||
self,
|
||||
ticker: str,
|
||||
trading_date: date,
|
||||
final_state: Dict[str, Any],
|
||||
signal_info: Dict[str, Any],
|
||||
final_state: dict[str, Any],
|
||||
signal_info: dict[str, Any],
|
||||
) -> TradingDecision:
|
||||
signal = self._extract_signal(signal_info)
|
||||
confidence = self._extract_confidence(signal_info)
|
||||
|
|
@ -134,9 +133,17 @@ class AgentBacktestEngine(BacktestEngine):
|
|||
bear_argument = None
|
||||
|
||||
if debate_state.get("bull_history"):
|
||||
bull_argument = debate_state["bull_history"][-1] if debate_state["bull_history"] else None
|
||||
bull_argument = (
|
||||
debate_state["bull_history"][-1]
|
||||
if debate_state["bull_history"]
|
||||
else None
|
||||
)
|
||||
if debate_state.get("bear_history"):
|
||||
bear_argument = debate_state["bear_history"][-1] if debate_state["bear_history"] else None
|
||||
bear_argument = (
|
||||
debate_state["bear_history"][-1]
|
||||
if debate_state["bear_history"]
|
||||
else None
|
||||
)
|
||||
|
||||
risk_state = final_state.get("risk_debate_state", {})
|
||||
risk_approved = self._extract_risk_approval(risk_state)
|
||||
|
|
@ -160,7 +167,7 @@ class AgentBacktestEngine(BacktestEngine):
|
|||
rationale=final_decision_text[:1000] if final_decision_text else "",
|
||||
)
|
||||
|
||||
def _extract_signal(self, signal_info: Dict[str, Any]) -> SignalType:
|
||||
def _extract_signal(self, signal_info: dict[str, Any]) -> SignalType:
|
||||
action = signal_info.get("action", "").upper()
|
||||
direction = signal_info.get("direction", "").upper()
|
||||
|
||||
|
|
@ -178,7 +185,7 @@ class AgentBacktestEngine(BacktestEngine):
|
|||
|
||||
return SignalType.HOLD
|
||||
|
||||
def _extract_confidence(self, signal_info: Dict[str, Any]) -> Decimal:
|
||||
def _extract_confidence(self, signal_info: dict[str, Any]) -> Decimal:
|
||||
confidence = signal_info.get("confidence", 0.5)
|
||||
if isinstance(confidence, str):
|
||||
try:
|
||||
|
|
@ -190,7 +197,7 @@ class AgentBacktestEngine(BacktestEngine):
|
|||
|
||||
def _extract_action(
|
||||
self,
|
||||
signal_info: Dict[str, Any],
|
||||
signal_info: dict[str, Any],
|
||||
final_decision_text: str,
|
||||
) -> str:
|
||||
action = signal_info.get("action", "")
|
||||
|
|
@ -205,7 +212,7 @@ class AgentBacktestEngine(BacktestEngine):
|
|||
|
||||
return "HOLD"
|
||||
|
||||
def _extract_risk_approval(self, risk_state: Dict[str, Any]) -> Optional[bool]:
|
||||
def _extract_risk_approval(self, risk_state: dict[str, Any]) -> bool | None:
|
||||
judge_decision = risk_state.get("judge_decision", "")
|
||||
if not judge_decision:
|
||||
return None
|
||||
|
|
@ -224,7 +231,7 @@ def run_agent_backtest(
|
|||
start_date: date,
|
||||
end_date: date,
|
||||
initial_cash: Decimal = Decimal("100000"),
|
||||
agent_config: Optional[Dict[str, Any]] = None,
|
||||
agent_config: dict[str, Any] | None = None,
|
||||
) -> BacktestResult:
|
||||
from tradingagents.models.portfolio import PortfolioConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -9,17 +9,17 @@ from stockstats import wrap
|
|||
|
||||
from tradingagents.models.market_data import (
|
||||
OHLCV,
|
||||
OHLCVBar,
|
||||
TechnicalIndicators,
|
||||
HistoricalDataRequest,
|
||||
HistoricalDataResponse,
|
||||
OHLCVBar,
|
||||
TechnicalIndicators,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataLoader:
|
||||
def __init__(self, cache_dir: Optional[str] = None):
|
||||
def __init__(self, cache_dir: str | None = None):
|
||||
self.cache_dir = cache_dir
|
||||
self._cache: dict[str, pd.DataFrame] = {}
|
||||
|
||||
|
|
@ -89,7 +89,9 @@ class DataLoader:
|
|||
)
|
||||
|
||||
if df.empty:
|
||||
logger.warning("No data returned for %s from %s to %s", ticker, start_date, end_date)
|
||||
logger.warning(
|
||||
"No data returned for %s from %s to %s", ticker, start_date, end_date
|
||||
)
|
||||
return pd.DataFrame()
|
||||
|
||||
df = df.reset_index()
|
||||
|
|
@ -116,7 +118,9 @@ class DataLoader:
|
|||
low=Decimal(str(round(row["Low"], 4))),
|
||||
close=Decimal(str(round(row["Close"], 4))),
|
||||
volume=int(row["Volume"]),
|
||||
adjusted_close=Decimal(str(round(row["Adj Close"], 4))) if "Adj Close" in row else None,
|
||||
adjusted_close=Decimal(str(round(row["Adj Close"], 4)))
|
||||
if "Adj Close" in row
|
||||
else None,
|
||||
)
|
||||
bars.append(bar)
|
||||
|
||||
|
|
@ -193,7 +197,7 @@ class DataLoader:
|
|||
return indicators
|
||||
|
||||
@staticmethod
|
||||
def _safe_decimal(value) -> Optional[Decimal]:
|
||||
def _safe_decimal(value) -> Decimal | None:
|
||||
if value is None or pd.isna(value):
|
||||
return None
|
||||
return Decimal(str(round(float(value), 4)))
|
||||
|
|
@ -202,10 +206,12 @@ class DataLoader:
|
|||
self,
|
||||
ticker: str,
|
||||
target_date: date,
|
||||
ohlcv: Optional[OHLCV] = None,
|
||||
) -> Optional[Decimal]:
|
||||
ohlcv: OHLCV | None = None,
|
||||
) -> Decimal | None:
|
||||
if ohlcv is None:
|
||||
ohlcv = self.load_ohlcv(ticker, target_date - timedelta(days=5), target_date)
|
||||
ohlcv = self.load_ohlcv(
|
||||
ticker, target_date - timedelta(days=5), target_date
|
||||
)
|
||||
|
||||
target_datetime = datetime.combine(target_date, datetime.min.time())
|
||||
bar = ohlcv.get_bar(target_datetime)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import logging
|
||||
from collections.abc import Callable
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Callable
|
||||
from typing import Optional
|
||||
|
||||
from tradingagents.models.backtest import (
|
||||
BacktestConfig,
|
||||
BacktestMetrics,
|
||||
BacktestResult,
|
||||
BacktestStatus,
|
||||
EquityCurvePoint,
|
||||
|
|
@ -12,7 +14,7 @@ from tradingagents.models.backtest import (
|
|||
)
|
||||
from tradingagents.models.decisions import SignalType, TradingDecision
|
||||
from tradingagents.models.portfolio import PortfolioSnapshot
|
||||
from tradingagents.models.trading import Order, OrderSide, OrderStatus, Fill, Trade
|
||||
from tradingagents.models.trading import Fill, Order, OrderSide, OrderStatus, Trade
|
||||
|
||||
from .data_loader import DataLoader
|
||||
from .metrics import MetricsCalculator
|
||||
|
|
@ -24,15 +26,15 @@ class BacktestEngine:
|
|||
def __init__(
|
||||
self,
|
||||
config: BacktestConfig,
|
||||
decision_callback: Optional[Callable[[str, date, dict], TradingDecision]] = None,
|
||||
decision_callback: Callable[[str, date, dict], TradingDecision] | None = None,
|
||||
):
|
||||
self.config = config
|
||||
self.decision_callback = decision_callback
|
||||
self.data_loader = DataLoader()
|
||||
self.metrics_calculator = MetricsCalculator(config.risk_free_rate)
|
||||
|
||||
self.portfolio: Optional[PortfolioSnapshot] = None
|
||||
self.trade_log: Optional[TradeLog] = None
|
||||
self.portfolio: PortfolioSnapshot | None = None
|
||||
self.trade_log: TradeLog | None = None
|
||||
self.equity_curve: list[EquityCurvePoint] = []
|
||||
self.daily_returns: list[Decimal] = []
|
||||
self.decisions: list[TradingDecision] = []
|
||||
|
|
@ -52,7 +54,9 @@ class BacktestEngine:
|
|||
|
||||
self._process_day(trading_date, i)
|
||||
|
||||
self._close_all_positions(trading_days[-1] if trading_days else self.config.end_date)
|
||||
self._close_all_positions(
|
||||
trading_days[-1] if trading_days else self.config.end_date
|
||||
)
|
||||
|
||||
metrics = self.metrics_calculator.calculate_metrics(
|
||||
self.equity_curve,
|
||||
|
|
@ -138,7 +142,7 @@ class BacktestEngine:
|
|||
ticker: str,
|
||||
trading_date: date,
|
||||
day_index: int,
|
||||
) -> Optional[TradingDecision]:
|
||||
) -> TradingDecision | None:
|
||||
if self.decision_callback:
|
||||
context = {
|
||||
"day_index": day_index,
|
||||
|
|
@ -153,7 +157,7 @@ class BacktestEngine:
|
|||
self,
|
||||
ticker: str,
|
||||
trading_date: date,
|
||||
) -> Optional[TradingDecision]:
|
||||
) -> TradingDecision | None:
|
||||
return None
|
||||
|
||||
def _execute_decision(
|
||||
|
|
@ -172,14 +176,18 @@ class BacktestEngine:
|
|||
if decision.recommended_quantity:
|
||||
quantity = decision.recommended_quantity
|
||||
else:
|
||||
max_position_value = self.portfolio.cash * (config.max_position_size_percent / 100)
|
||||
max_position_value = self.portfolio.cash * (
|
||||
config.max_position_size_percent / 100
|
||||
)
|
||||
quantity = int(max_position_value / execution_price)
|
||||
|
||||
if quantity <= 0:
|
||||
return
|
||||
|
||||
if not self.portfolio.can_afford(ticker, quantity, execution_price, config):
|
||||
quantity = self.portfolio.max_shares_affordable(ticker, execution_price, config)
|
||||
quantity = self.portfolio.max_shares_affordable(
|
||||
ticker, execution_price, config
|
||||
)
|
||||
|
||||
if quantity <= 0:
|
||||
return
|
||||
|
|
@ -221,7 +229,10 @@ class BacktestEngine:
|
|||
|
||||
logger.debug(
|
||||
"BUY %s: %d shares @ $%.2f on %s",
|
||||
ticker, quantity, execution_price, trading_date
|
||||
ticker,
|
||||
quantity,
|
||||
execution_price,
|
||||
trading_date,
|
||||
)
|
||||
|
||||
elif decision.is_sell and position.quantity > 0:
|
||||
|
|
@ -259,15 +270,18 @@ class BacktestEngine:
|
|||
trade.exit_time = datetime.combine(trading_date, datetime.min.time())
|
||||
trade.exit_order_id = order.id
|
||||
trade.commission = (
|
||||
config.calculate_commission(trade.entry_quantity, trade.entry_price) +
|
||||
commission
|
||||
config.calculate_commission(trade.entry_quantity, trade.entry_price)
|
||||
+ commission
|
||||
)
|
||||
self.trade_log.add_trade(trade)
|
||||
del self.open_trades[ticker]
|
||||
|
||||
logger.debug(
|
||||
"SELL %s: %d shares @ $%.2f on %s",
|
||||
ticker, quantity, execution_price, trading_date
|
||||
ticker,
|
||||
quantity,
|
||||
execution_price,
|
||||
trading_date,
|
||||
)
|
||||
|
||||
def _record_equity(self, trading_date: date, prices: dict[str, Decimal]) -> None:
|
||||
|
|
@ -305,11 +319,12 @@ class BacktestEngine:
|
|||
)
|
||||
self._execute_decision(decision, prices[ticker], final_date)
|
||||
|
||||
def _empty_metrics(self) -> "BacktestMetrics":
|
||||
from tradingagents.models.backtest import BacktestMetrics
|
||||
def _empty_metrics(self) -> BacktestMetrics:
|
||||
return BacktestMetrics(
|
||||
start_equity=self.config.portfolio_config.initial_cash,
|
||||
end_equity=self.portfolio.cash if self.portfolio else self.config.portfolio_config.initial_cash,
|
||||
end_equity=self.portfolio.cash
|
||||
if self.portfolio
|
||||
else self.config.portfolio_config.initial_cash,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -329,7 +344,7 @@ class SimpleBacktestEngine(BacktestEngine):
|
|||
ticker: str,
|
||||
trading_date: date,
|
||||
day_index: int,
|
||||
) -> Optional[TradingDecision]:
|
||||
) -> TradingDecision | None:
|
||||
context = {
|
||||
"day_index": day_index,
|
||||
"portfolio": self.portfolio,
|
||||
|
|
@ -339,7 +354,11 @@ class SimpleBacktestEngine(BacktestEngine):
|
|||
|
||||
position = self.portfolio.get_position(ticker)
|
||||
|
||||
if position.quantity == 0 and self.buy_signal and self.buy_signal(ticker, trading_date, context):
|
||||
if (
|
||||
position.quantity == 0
|
||||
and self.buy_signal
|
||||
and self.buy_signal(ticker, trading_date, context)
|
||||
):
|
||||
return TradingDecision(
|
||||
ticker=ticker,
|
||||
timestamp=datetime.now(),
|
||||
|
|
@ -351,7 +370,11 @@ class SimpleBacktestEngine(BacktestEngine):
|
|||
rationale="Buy signal triggered",
|
||||
)
|
||||
|
||||
if position.quantity > 0 and self.sell_signal and self.sell_signal(ticker, trading_date, context):
|
||||
if (
|
||||
position.quantity > 0
|
||||
and self.sell_signal
|
||||
and self.sell_signal(ticker, trading_date, context)
|
||||
):
|
||||
return TradingDecision(
|
||||
ticker=ticker,
|
||||
timestamp=datetime.now(),
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class MetricsCalculator:
|
|||
self,
|
||||
equity_curve: list[EquityCurvePoint],
|
||||
trade_log: TradeLog,
|
||||
benchmark_curve: Optional[list[EquityCurvePoint]] = None,
|
||||
benchmark_curve: list[EquityCurvePoint] | None = None,
|
||||
) -> BacktestMetrics:
|
||||
if not equity_curve:
|
||||
raise ValueError("Equity curve cannot be empty")
|
||||
|
|
@ -35,16 +35,28 @@ class MetricsCalculator:
|
|||
|
||||
daily_returns = self._calculate_daily_returns(equity_curve)
|
||||
volatility = self._calculate_volatility(daily_returns)
|
||||
annualized_volatility = volatility * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
|
||||
annualized_volatility = volatility * Decimal(
|
||||
math.sqrt(self.TRADING_DAYS_PER_YEAR)
|
||||
)
|
||||
|
||||
downside_returns = [r for r in daily_returns if r < 0]
|
||||
downside_volatility = self._calculate_volatility(downside_returns) if downside_returns else Decimal("0")
|
||||
annualized_downside_vol = downside_volatility * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
|
||||
downside_volatility = (
|
||||
self._calculate_volatility(downside_returns)
|
||||
if downside_returns
|
||||
else Decimal("0")
|
||||
)
|
||||
annualized_downside_vol = downside_volatility * Decimal(
|
||||
math.sqrt(self.TRADING_DAYS_PER_YEAR)
|
||||
)
|
||||
|
||||
max_dd, max_dd_pct, max_dd_duration, avg_dd = self._calculate_drawdown_metrics(equity_curve)
|
||||
max_dd, max_dd_pct, max_dd_duration, avg_dd = self._calculate_drawdown_metrics(
|
||||
equity_curve
|
||||
)
|
||||
|
||||
sharpe = self._calculate_sharpe_ratio(annualized_return, annualized_volatility)
|
||||
sortino = self._calculate_sortino_ratio(annualized_return, annualized_downside_vol)
|
||||
sortino = self._calculate_sortino_ratio(
|
||||
annualized_return, annualized_downside_vol
|
||||
)
|
||||
calmar = self._calculate_calmar_ratio(annualized_return, max_dd_pct)
|
||||
|
||||
benchmark_return = None
|
||||
|
|
@ -55,7 +67,9 @@ class MetricsCalculator:
|
|||
|
||||
if benchmark_curve and len(benchmark_curve) == len(equity_curve):
|
||||
benchmark_return = benchmark_curve[-1].equity - benchmark_curve[0].equity
|
||||
benchmark_return_percent = (benchmark_return / benchmark_curve[0].equity) * 100
|
||||
benchmark_return_percent = (
|
||||
benchmark_return / benchmark_curve[0].equity
|
||||
) * 100
|
||||
|
||||
benchmark_daily = self._calculate_daily_returns(benchmark_curve)
|
||||
alpha, beta = self._calculate_alpha_beta(daily_returns, benchmark_daily)
|
||||
|
|
@ -63,7 +77,9 @@ class MetricsCalculator:
|
|||
daily_returns, benchmark_daily
|
||||
)
|
||||
|
||||
all_pnls = [t.pnl for t in trade_log.trades if t.is_closed and t.pnl is not None]
|
||||
all_pnls = [
|
||||
t.pnl for t in trade_log.trades if t.is_closed and t.pnl is not None
|
||||
]
|
||||
avg_trade_pnl = sum(all_pnls) / len(all_pnls) if all_pnls else None
|
||||
largest_win = max((p for p in all_pnls if p > 0), default=None)
|
||||
largest_loss = min((p for p in all_pnls if p < 0), default=None)
|
||||
|
|
@ -125,7 +141,7 @@ class MetricsCalculator:
|
|||
def _calculate_drawdown_metrics(
|
||||
self,
|
||||
equity_curve: list[EquityCurvePoint],
|
||||
) -> tuple[Decimal, Decimal, Optional[int], Decimal]:
|
||||
) -> tuple[Decimal, Decimal, int | None, Decimal]:
|
||||
if not equity_curve:
|
||||
return Decimal("0"), Decimal("0"), None, Decimal("0")
|
||||
|
||||
|
|
@ -170,13 +186,18 @@ class MetricsCalculator:
|
|||
|
||||
avg_drawdown = sum(drawdowns) / len(drawdowns) if drawdowns else Decimal("0")
|
||||
|
||||
return max_drawdown, max_drawdown_percent, max_drawdown_duration or None, avg_drawdown
|
||||
return (
|
||||
max_drawdown,
|
||||
max_drawdown_percent,
|
||||
max_drawdown_duration or None,
|
||||
avg_drawdown,
|
||||
)
|
||||
|
||||
def _calculate_sharpe_ratio(
|
||||
self,
|
||||
annualized_return: Decimal,
|
||||
annualized_volatility: Decimal,
|
||||
) -> Optional[Decimal]:
|
||||
) -> Decimal | None:
|
||||
if annualized_volatility == 0:
|
||||
return None
|
||||
|
||||
|
|
@ -187,7 +208,7 @@ class MetricsCalculator:
|
|||
self,
|
||||
annualized_return: Decimal,
|
||||
annualized_downside_vol: Decimal,
|
||||
) -> Optional[Decimal]:
|
||||
) -> Decimal | None:
|
||||
if annualized_downside_vol == 0:
|
||||
return None
|
||||
|
||||
|
|
@ -198,7 +219,7 @@ class MetricsCalculator:
|
|||
self,
|
||||
annualized_return: Decimal,
|
||||
max_drawdown_percent: Decimal,
|
||||
) -> Optional[Decimal]:
|
||||
) -> Decimal | None:
|
||||
if max_drawdown_percent == 0:
|
||||
return None
|
||||
|
||||
|
|
@ -208,14 +229,14 @@ class MetricsCalculator:
|
|||
self,
|
||||
returns: list[Decimal],
|
||||
benchmark_returns: list[Decimal],
|
||||
) -> tuple[Optional[Decimal], Optional[Decimal]]:
|
||||
) -> tuple[Decimal | None, Decimal | None]:
|
||||
if len(returns) != len(benchmark_returns) or len(returns) < 2:
|
||||
return None, None
|
||||
|
||||
n = len(returns)
|
||||
sum_x = sum(benchmark_returns)
|
||||
sum_y = sum(returns)
|
||||
sum_xy = sum(r * b for r, b in zip(returns, benchmark_returns))
|
||||
sum_xy = sum(r * b for r, b in zip(returns, benchmark_returns, strict=False))
|
||||
sum_xx = sum(b * b for b in benchmark_returns)
|
||||
|
||||
denominator = n * sum_xx - sum_x * sum_x
|
||||
|
|
@ -233,18 +254,22 @@ class MetricsCalculator:
|
|||
self,
|
||||
returns: list[Decimal],
|
||||
benchmark_returns: list[Decimal],
|
||||
) -> Optional[Decimal]:
|
||||
) -> Decimal | None:
|
||||
if len(returns) != len(benchmark_returns) or len(returns) < 2:
|
||||
return None
|
||||
|
||||
excess_returns = [r - b for r, b in zip(returns, benchmark_returns)]
|
||||
excess_returns = [
|
||||
r - b for r, b in zip(returns, benchmark_returns, strict=False)
|
||||
]
|
||||
mean_excess = sum(excess_returns) / len(excess_returns)
|
||||
tracking_error = self._calculate_volatility(excess_returns)
|
||||
|
||||
if tracking_error == 0:
|
||||
return None
|
||||
|
||||
annualized_tracking_error = tracking_error * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
|
||||
annualized_tracking_error = tracking_error * Decimal(
|
||||
math.sqrt(self.TRADING_DAYS_PER_YEAR)
|
||||
)
|
||||
annualized_excess = mean_excess * self.TRADING_DAYS_PER_YEAR
|
||||
|
||||
return annualized_excess / annualized_tracking_error
|
||||
|
|
@ -263,14 +288,16 @@ class MetricsCalculator:
|
|||
daily_returns = self._calculate_daily_returns(equity_curve)
|
||||
|
||||
for i in range(window - 1, len(daily_returns)):
|
||||
window_returns = daily_returns[i - window + 1:i + 1]
|
||||
window_returns = daily_returns[i - window + 1 : i + 1]
|
||||
vol = self._calculate_volatility(window_returns)
|
||||
annualized_vol = vol * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
|
||||
|
||||
mean_return = sum(window_returns) / len(window_returns)
|
||||
annualized_return = mean_return * self.TRADING_DAYS_PER_YEAR * 100
|
||||
|
||||
sharpe = self._calculate_sharpe_ratio(annualized_return, annualized_vol * 100)
|
||||
sharpe = self._calculate_sharpe_ratio(
|
||||
annualized_return, annualized_vol * 100
|
||||
)
|
||||
|
||||
rolling_sharpe.append(sharpe if sharpe else Decimal("0"))
|
||||
rolling_volatility.append(annualized_vol * 100)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
|
@ -13,11 +14,13 @@ class DataVendorsConfig(BaseModel):
|
|||
|
||||
class TradingAgentsSettings(BaseSettings):
|
||||
project_dir: str = Field(
|
||||
default_factory=lambda: os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
|
||||
default_factory=lambda: os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), ".")
|
||||
)
|
||||
)
|
||||
results_dir: str = Field(default="./results")
|
||||
data_dir: str = Field(default="./data")
|
||||
data_cache_dir: Optional[str] = None
|
||||
data_cache_dir: str | None = None
|
||||
|
||||
llm_provider: str = Field(default="openai")
|
||||
deep_think_llm: str = Field(default="gpt-5")
|
||||
|
|
@ -34,7 +37,7 @@ class TradingAgentsSettings(BaseSettings):
|
|||
discovery_max_results: int = Field(default=20, ge=1, le=100)
|
||||
discovery_min_mentions: int = Field(default=2, ge=1)
|
||||
|
||||
bulk_news_vendor_order: List[str] = Field(
|
||||
bulk_news_vendor_order: list[str] = Field(
|
||||
default=["tavily", "brave", "alpha_vantage", "openai", "google"]
|
||||
)
|
||||
bulk_news_timeout: int = Field(default=30, ge=5)
|
||||
|
|
@ -45,15 +48,15 @@ class TradingAgentsSettings(BaseSettings):
|
|||
log_console_enabled: bool = Field(default=True)
|
||||
log_file_enabled: bool = Field(default=True)
|
||||
|
||||
openai_api_key: Optional[str] = Field(default=None)
|
||||
alpha_vantage_api_key: Optional[str] = Field(default=None)
|
||||
brave_api_key: Optional[str] = Field(default=None)
|
||||
tavily_api_key: Optional[str] = Field(default=None)
|
||||
google_api_key: Optional[str] = Field(default=None)
|
||||
anthropic_api_key: Optional[str] = Field(default=None)
|
||||
openai_api_key: str | None = Field(default=None)
|
||||
alpha_vantage_api_key: str | None = Field(default=None)
|
||||
brave_api_key: str | None = Field(default=None)
|
||||
tavily_api_key: str | None = Field(default=None)
|
||||
google_api_key: str | None = Field(default=None)
|
||||
anthropic_api_key: str | None = Field(default=None)
|
||||
|
||||
data_vendors: DataVendorsConfig = Field(default_factory=DataVendorsConfig)
|
||||
tool_vendors: Dict[str, Any] = Field(default_factory=dict)
|
||||
tool_vendors: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
model_config = {
|
||||
"env_prefix": "TRADINGAGENTS_",
|
||||
|
|
@ -93,15 +96,17 @@ class TradingAgentsSettings(BaseSettings):
|
|||
def validate_llm_provider(cls, v: str) -> str:
|
||||
valid_providers = {"openai", "anthropic", "google", "ollama", "openrouter"}
|
||||
if v.lower() not in valid_providers:
|
||||
raise ValueError(f"Invalid LLM provider: {v}. Must be one of {valid_providers}")
|
||||
raise ValueError(
|
||||
f"Invalid LLM provider: {v}. Must be one of {valid_providers}"
|
||||
)
|
||||
return v.lower()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = self.model_dump()
|
||||
result["data_vendors"] = self.data_vendors.model_dump()
|
||||
return result
|
||||
|
||||
def get_api_key(self, vendor: str) -> Optional[str]:
|
||||
def get_api_key(self, vendor: str) -> str | None:
|
||||
key_map = {
|
||||
"openai": self.openai_api_key,
|
||||
"alpha_vantage": self.alpha_vantage_api_key,
|
||||
|
|
@ -123,7 +128,7 @@ class TradingAgentsSettings(BaseSettings):
|
|||
return key
|
||||
|
||||
|
||||
_settings: Optional[TradingAgentsSettings] = None
|
||||
_settings: TradingAgentsSettings | None = None
|
||||
|
||||
|
||||
def get_settings() -> TradingAgentsSettings:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
from .base import Base
|
||||
from .engine import get_db_session, get_engine, init_database, reset_engine
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"get_db_session",
|
||||
"get_engine",
|
||||
"init_database",
|
||||
"reset_engine",
|
||||
]
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import os
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from .base import Base
|
||||
|
||||
DEFAULT_DB_DIR = "./data"
|
||||
DEFAULT_DB_NAME = "tradingagents.db"
|
||||
|
||||
_engine: Engine | None = None
|
||||
_SessionLocal: sessionmaker | None = None
|
||||
|
||||
|
||||
def get_database_url() -> str:
|
||||
db_dir = os.getenv("TRADINGAGENTS_DB_DIR", DEFAULT_DB_DIR)
|
||||
db_name = os.getenv("TRADINGAGENTS_DB_NAME", DEFAULT_DB_NAME)
|
||||
|
||||
Path(db_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
db_path = Path(db_dir) / db_name
|
||||
return f"sqlite:///{db_path}"
|
||||
|
||||
|
||||
def get_engine() -> Engine:
|
||||
global _engine
|
||||
if _engine is None:
|
||||
_engine = create_engine(
|
||||
get_database_url(),
|
||||
echo=os.getenv("TRADINGAGENTS_DB_ECHO", "false").lower() == "true",
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
return _engine
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker:
|
||||
global _SessionLocal
|
||||
if _SessionLocal is None:
|
||||
_SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=get_engine()
|
||||
)
|
||||
return _SessionLocal
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
session = get_session_factory()()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def init_database() -> None:
|
||||
Base.metadata.create_all(bind=get_engine())
|
||||
|
||||
|
||||
def reset_engine() -> None:
|
||||
global _engine, _SessionLocal
|
||||
if _engine:
|
||||
_engine.dispose()
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
from tradingagents.database.base import Base
|
||||
from tradingagents.database.models.analysis import (
|
||||
AnalysisSession,
|
||||
AnalystReport,
|
||||
InvestmentDebate,
|
||||
RiskDebate,
|
||||
)
|
||||
from tradingagents.database.models.backtesting import (
|
||||
BacktestMetricsRecord,
|
||||
BacktestRun,
|
||||
BacktestTrade,
|
||||
EquityCurveRecord,
|
||||
)
|
||||
from tradingagents.database.models.discovery import (
|
||||
DiscoveryArticle,
|
||||
DiscoveryRun,
|
||||
TrendingStockResult,
|
||||
)
|
||||
from tradingagents.database.models.market_data import (
|
||||
DataCache,
|
||||
FundamentalData,
|
||||
NewsArticle,
|
||||
SocialMediaPost,
|
||||
StockPrice,
|
||||
TechnicalIndicator,
|
||||
)
|
||||
from tradingagents.database.models.trading import (
|
||||
TradeExecution,
|
||||
TradeReflection,
|
||||
TradingDecision,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"AnalysisSession",
|
||||
"AnalystReport",
|
||||
"InvestmentDebate",
|
||||
"RiskDebate",
|
||||
"TradingDecision",
|
||||
"TradeExecution",
|
||||
"TradeReflection",
|
||||
"StockPrice",
|
||||
"TechnicalIndicator",
|
||||
"NewsArticle",
|
||||
"SocialMediaPost",
|
||||
"FundamentalData",
|
||||
"DataCache",
|
||||
"DiscoveryRun",
|
||||
"TrendingStockResult",
|
||||
"DiscoveryArticle",
|
||||
"BacktestRun",
|
||||
"BacktestMetricsRecord",
|
||||
"BacktestTrade",
|
||||
"EquityCurveRecord",
|
||||
]
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import DateTime, Enum, ForeignKey, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from tradingagents.database.base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tradingagents.database.models.trading import TradingDecision
|
||||
|
||||
|
||||
class AnalysisSession(Base):
|
||||
__tablename__ = "analysis_sessions"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
|
||||
trade_date: Mapped[str] = mapped_column(String(10), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
status: Mapped[str] = mapped_column(
|
||||
Enum("pending", "running", "completed", "failed", name="session_status"),
|
||||
default="pending",
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
analyst_reports: Mapped[list["AnalystReport"]] = relationship(
|
||||
"AnalystReport", back_populates="session", cascade="all, delete-orphan"
|
||||
)
|
||||
investment_debate: Mapped["InvestmentDebate | None"] = relationship(
|
||||
"InvestmentDebate",
|
||||
back_populates="session",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
risk_debate: Mapped["RiskDebate | None"] = relationship(
|
||||
"RiskDebate",
|
||||
back_populates="session",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
trading_decision: Mapped["TradingDecision | None"] = relationship(
|
||||
"TradingDecision",
|
||||
back_populates="session",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class AnalystReport(Base):
|
||||
__tablename__ = "analyst_reports"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
session_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("analysis_sessions.id"), nullable=False, index=True
|
||||
)
|
||||
analyst_type: Mapped[str] = mapped_column(
|
||||
Enum("market", "sentiment", "news", "fundamentals", name="analyst_type"),
|
||||
nullable=False,
|
||||
)
|
||||
report_content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
session: Mapped["AnalysisSession"] = relationship(
|
||||
"AnalysisSession", back_populates="analyst_reports"
|
||||
)
|
||||
|
||||
|
||||
class InvestmentDebate(Base):
|
||||
__tablename__ = "investment_debates"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
session_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("analysis_sessions.id"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
)
|
||||
bull_history: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
bear_history: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
debate_history: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
judge_decision: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
investment_plan: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
debate_rounds: Mapped[int] = mapped_column(default=0, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
session: Mapped["AnalysisSession"] = relationship(
|
||||
"AnalysisSession", back_populates="investment_debate"
|
||||
)
|
||||
|
||||
|
||||
class RiskDebate(Base):
|
||||
__tablename__ = "risk_debates"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
session_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("analysis_sessions.id"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
)
|
||||
risky_history: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
safe_history: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
neutral_history: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
debate_history: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
judge_decision: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
debate_rounds: Mapped[int] = mapped_column(default=0, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
session: Mapped["AnalysisSession"] = relationship(
|
||||
"AnalysisSession", back_populates="risk_debate"
|
||||
)
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import DateTime, Enum, Float, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from tradingagents.database.base import Base
|
||||
|
||||
|
||||
class BacktestRun(Base):
|
||||
__tablename__ = "backtest_runs"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
tickers: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
start_date: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
end_date: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
interval: Mapped[str] = mapped_column(String(10), default="1d", nullable=False)
|
||||
initial_cash: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
benchmark_ticker: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
risk_free_rate: Mapped[float] = mapped_column(Float, default=0.05, nullable=False)
|
||||
use_agent_pipeline: Mapped[bool] = mapped_column(default=True, nullable=False)
|
||||
agent_config: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(
|
||||
Enum("pending", "running", "completed", "failed", name="backtest_status"),
|
||||
default="pending",
|
||||
nullable=False,
|
||||
)
|
||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
started_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
metrics: Mapped["BacktestMetricsRecord | None"] = relationship(
|
||||
"BacktestMetricsRecord",
|
||||
back_populates="backtest_run",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
trades: Mapped[list["BacktestTrade"]] = relationship(
|
||||
"BacktestTrade", back_populates="backtest_run", cascade="all, delete-orphan"
|
||||
)
|
||||
equity_curve: Mapped[list["EquityCurveRecord"]] = relationship(
|
||||
"EquityCurveRecord", back_populates="backtest_run", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class BacktestMetricsRecord(Base):
|
||||
__tablename__ = "backtest_metrics"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
backtest_run_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("backtest_runs.id"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
)
|
||||
total_return: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
total_return_percent: Mapped[float] = mapped_column(
|
||||
Float, default=0.0, nullable=False
|
||||
)
|
||||
annualized_return: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
benchmark_return: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
benchmark_return_percent: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
alpha: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
beta: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
volatility: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
annualized_volatility: Mapped[float] = mapped_column(
|
||||
Float, default=0.0, nullable=False
|
||||
)
|
||||
downside_volatility: Mapped[float] = mapped_column(
|
||||
Float, default=0.0, nullable=False
|
||||
)
|
||||
sharpe_ratio: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
sortino_ratio: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
calmar_ratio: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
information_ratio: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
max_drawdown: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
max_drawdown_percent: Mapped[float] = mapped_column(
|
||||
Float, default=0.0, nullable=False
|
||||
)
|
||||
max_drawdown_duration: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
avg_drawdown: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
total_trades: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
win_rate: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
profit_factor: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
avg_trade_pnl: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
avg_win: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
avg_loss: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
largest_win: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
largest_loss: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
avg_holding_period_days: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
trading_days: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
start_equity: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
end_equity: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
backtest_run: Mapped["BacktestRun"] = relationship(
|
||||
"BacktestRun", back_populates="metrics"
|
||||
)
|
||||
|
||||
|
||||
class BacktestTrade(Base):
|
||||
__tablename__ = "backtest_trades"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
backtest_run_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("backtest_runs.id"), nullable=False, index=True
|
||||
)
|
||||
ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
|
||||
side: Mapped[str] = mapped_column(
|
||||
Enum("buy", "sell", name="trade_side"), nullable=False
|
||||
)
|
||||
quantity: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
entry_price: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
exit_price: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
entry_date: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
exit_date: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
pnl: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
pnl_percent: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
is_closed: Mapped[bool] = mapped_column(default=False, nullable=False)
|
||||
holding_period_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
commission: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
slippage: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
backtest_run: Mapped["BacktestRun"] = relationship(
|
||||
"BacktestRun", back_populates="trades"
|
||||
)
|
||||
|
||||
|
||||
class EquityCurveRecord(Base):
|
||||
__tablename__ = "equity_curve_records"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
backtest_run_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("backtest_runs.id"), nullable=False, index=True
|
||||
)
|
||||
timestamp: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
equity: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
cash: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
positions_value: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
benchmark_value: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
drawdown: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
drawdown_percent: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
backtest_run: Mapped["BacktestRun"] = relationship(
|
||||
"BacktestRun", back_populates="equity_curve"
|
||||
)
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import DateTime, Enum, Float, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from tradingagents.database.base import Base
|
||||
|
||||
|
||||
class DiscoveryRun(Base):
|
||||
__tablename__ = "discovery_runs"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
lookback_period: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
sector_filter: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
event_filter: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
max_results: Mapped[int] = mapped_column(Integer, default=20, nullable=False)
|
||||
status: Mapped[str] = mapped_column(
|
||||
Enum("created", "processing", "completed", "failed", name="discovery_status"),
|
||||
default="created",
|
||||
nullable=False,
|
||||
)
|
||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
started_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
trending_stocks: Mapped[list["TrendingStockResult"]] = relationship(
|
||||
"TrendingStockResult",
|
||||
back_populates="discovery_run",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class TrendingStockResult(Base):
|
||||
__tablename__ = "trending_stock_results"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
discovery_run_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("discovery_runs.id"), nullable=False, index=True
|
||||
)
|
||||
ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
|
||||
company_name: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
score: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
mention_count: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
sentiment: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
sector: Mapped[str] = mapped_column(
|
||||
Enum(
|
||||
"technology",
|
||||
"healthcare",
|
||||
"finance",
|
||||
"energy",
|
||||
"consumer_goods",
|
||||
"industrials",
|
||||
"other",
|
||||
name="stock_sector",
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
event_type: Mapped[str] = mapped_column(
|
||||
Enum(
|
||||
"earnings",
|
||||
"merger_acquisition",
|
||||
"regulatory",
|
||||
"product_launch",
|
||||
"executive_change",
|
||||
"other",
|
||||
name="event_category",
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
news_summary: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
discovery_run: Mapped["DiscoveryRun"] = relationship(
|
||||
"DiscoveryRun", back_populates="trending_stocks"
|
||||
)
|
||||
source_articles: Mapped[list["DiscoveryArticle"]] = relationship(
|
||||
"DiscoveryArticle",
|
||||
back_populates="trending_stock",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class DiscoveryArticle(Base):
|
||||
__tablename__ = "discovery_articles"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
trending_stock_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("trending_stock_results.id"), nullable=False, index=True
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
source: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
url: Mapped[str | None] = mapped_column(String(1000), nullable=True)
|
||||
content_snippet: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
ticker_mentions: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
published_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
trending_stock: Mapped["TrendingStockResult"] = relationship(
|
||||
"TrendingStockResult", back_populates="source_articles"
|
||||
)
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import DateTime, Float, Index, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from tradingagents.database.base import Base
|
||||
|
||||
|
||||
class StockPrice(Base):
|
||||
__tablename__ = "stock_prices"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
ticker: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
date: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
open: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
high: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
low: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
close: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
adj_close: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
volume: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
data_source: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_stock_prices_ticker_date", "ticker", "date", unique=True),
|
||||
)
|
||||
|
||||
|
||||
class TechnicalIndicator(Base):
|
||||
__tablename__ = "technical_indicators"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
ticker: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
date: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
indicator_name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
indicator_value: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_tech_indicators_ticker_date_name",
|
||||
"ticker",
|
||||
"date",
|
||||
"indicator_name",
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class NewsArticle(Base):
|
||||
__tablename__ = "news_articles"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
ticker: Mapped[str | None] = mapped_column(String(20), nullable=True, index=True)
|
||||
headline: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
source: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
url: Mapped[str | None] = mapped_column(String(1000), nullable=True)
|
||||
summary: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
content: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
published_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
sentiment_score: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
data_source: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
|
||||
class SocialMediaPost(Base):
|
||||
__tablename__ = "social_media_posts"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
ticker: Mapped[str | None] = mapped_column(String(20), nullable=True, index=True)
|
||||
platform: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
post_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
author: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
content: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
engagement_score: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
sentiment_score: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
posted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
data_source: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
|
||||
class FundamentalData(Base):
|
||||
__tablename__ = "fundamental_data"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
ticker: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
report_date: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
metric_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
metric_value: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
metric_unit: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
data_source: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_fundamental_ticker_date_metric",
|
||||
"ticker",
|
||||
"report_date",
|
||||
"metric_name",
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DataCache(Base):
|
||||
__tablename__ = "data_cache"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
cache_key: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||
data_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
ticker: Mapped[str | None] = mapped_column(String(20), nullable=True, index=True)
|
||||
date_range_start: Mapped[str | None] = mapped_column(String(10), nullable=True)
|
||||
date_range_end: Mapped[str | None] = mapped_column(String(10), nullable=True)
|
||||
cached_data: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import DateTime, Enum, Float, ForeignKey, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from tradingagents.database.base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tradingagents.database.models.analysis import AnalysisSession
|
||||
|
||||
|
||||
class TradingDecision(Base):
|
||||
__tablename__ = "trading_decisions"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
session_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("analysis_sessions.id"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
)
|
||||
ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
|
||||
decision: Mapped[str] = mapped_column(
|
||||
Enum("buy", "sell", "hold", name="trade_decision"),
|
||||
nullable=False,
|
||||
)
|
||||
trader_plan: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
final_decision_content: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
confidence_score: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
session: Mapped["AnalysisSession"] = relationship(
|
||||
"AnalysisSession", back_populates="trading_decision"
|
||||
)
|
||||
execution: Mapped["TradeExecution | None"] = relationship(
|
||||
"TradeExecution",
|
||||
back_populates="decision",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class TradeExecution(Base):
|
||||
__tablename__ = "trade_executions"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
decision_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("trading_decisions.id"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
)
|
||||
ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
|
||||
action: Mapped[str] = mapped_column(
|
||||
Enum("buy", "sell", "hold", name="trade_action"),
|
||||
nullable=False,
|
||||
)
|
||||
quantity: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
price: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
executed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
status: Mapped[str] = mapped_column(
|
||||
Enum("pending", "executed", "cancelled", "failed", name="execution_status"),
|
||||
default="pending",
|
||||
nullable=False,
|
||||
)
|
||||
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
decision: Mapped["TradingDecision"] = relationship(
|
||||
"TradingDecision", back_populates="execution"
|
||||
)
|
||||
|
||||
|
||||
class TradeReflection(Base):
|
||||
__tablename__ = "trade_reflections"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
|
||||
trade_date: Mapped[str] = mapped_column(String(10), nullable=False, index=True)
|
||||
original_decision: Mapped[str] = mapped_column(
|
||||
Enum("buy", "sell", "hold", name="reflection_decision"),
|
||||
nullable=False,
|
||||
)
|
||||
actual_outcome: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
reflection_content: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
lessons_learned: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
profit_loss: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from .analysis import (
|
||||
AnalysisSessionRepository,
|
||||
AnalystReportRepository,
|
||||
InvestmentDebateRepository,
|
||||
RiskDebateRepository,
|
||||
)
|
||||
from .base import BaseRepository
|
||||
from .market_data import (
|
||||
DataCacheRepository,
|
||||
FundamentalDataRepository,
|
||||
NewsArticleRepository,
|
||||
SocialMediaPostRepository,
|
||||
StockPriceRepository,
|
||||
TechnicalIndicatorRepository,
|
||||
)
|
||||
from .trading import (
|
||||
TradeExecutionRepository,
|
||||
TradeReflectionRepository,
|
||||
TradingDecisionRepository,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseRepository",
|
||||
"AnalysisSessionRepository",
|
||||
"AnalystReportRepository",
|
||||
"InvestmentDebateRepository",
|
||||
"RiskDebateRepository",
|
||||
"TradingDecisionRepository",
|
||||
"TradeExecutionRepository",
|
||||
"TradeReflectionRepository",
|
||||
"StockPriceRepository",
|
||||
"TechnicalIndicatorRepository",
|
||||
"NewsArticleRepository",
|
||||
"SocialMediaPostRepository",
|
||||
"FundamentalDataRepository",
|
||||
"DataCacheRepository",
|
||||
]
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from tradingagents.database.models.analysis import (
|
||||
AnalysisSession,
|
||||
AnalystReport,
|
||||
InvestmentDebate,
|
||||
RiskDebate,
|
||||
)
|
||||
from tradingagents.database.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class AnalysisSessionRepository(BaseRepository[AnalysisSession]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, AnalysisSession)
|
||||
|
||||
def get_by_ticker_and_date(
|
||||
self, ticker: str, trade_date: str
|
||||
) -> AnalysisSession | None:
|
||||
return (
|
||||
self.session.query(AnalysisSession)
|
||||
.filter(
|
||||
and_(
|
||||
AnalysisSession.ticker == ticker,
|
||||
AnalysisSession.trade_date == trade_date,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_latest_by_ticker(self, ticker: str) -> AnalysisSession | None:
|
||||
return (
|
||||
self.session.query(AnalysisSession)
|
||||
.filter(AnalysisSession.ticker == ticker)
|
||||
.order_by(AnalysisSession.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_completed_sessions(self, limit: int = 100) -> list[AnalysisSession]:
|
||||
return (
|
||||
self.session.query(AnalysisSession)
|
||||
.filter(AnalysisSession.status == "completed")
|
||||
.order_by(AnalysisSession.completed_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def mark_completed(self, session_id: str) -> AnalysisSession | None:
|
||||
obj = self.get(session_id)
|
||||
if obj:
|
||||
obj.status = "completed"
|
||||
obj.completed_at = datetime.utcnow()
|
||||
self.session.flush()
|
||||
return obj
|
||||
|
||||
def mark_failed(self, session_id: str) -> AnalysisSession | None:
|
||||
obj = self.get(session_id)
|
||||
if obj:
|
||||
obj.status = "failed"
|
||||
obj.completed_at = datetime.utcnow()
|
||||
self.session.flush()
|
||||
return obj
|
||||
|
||||
|
||||
class AnalystReportRepository(BaseRepository[AnalystReport]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, AnalystReport)
|
||||
|
||||
def get_by_session_and_type(
|
||||
self, session_id: str, analyst_type: str
|
||||
) -> AnalystReport | None:
|
||||
return (
|
||||
self.session.query(AnalystReport)
|
||||
.filter(
|
||||
and_(
|
||||
AnalystReport.session_id == session_id,
|
||||
AnalystReport.analyst_type == analyst_type,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_all_by_session(self, session_id: str) -> list[AnalystReport]:
|
||||
return (
|
||||
self.session.query(AnalystReport)
|
||||
.filter(AnalystReport.session_id == session_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class InvestmentDebateRepository(BaseRepository[InvestmentDebate]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, InvestmentDebate)
|
||||
|
||||
def get_by_session(self, session_id: str) -> InvestmentDebate | None:
|
||||
return (
|
||||
self.session.query(InvestmentDebate)
|
||||
.filter(InvestmentDebate.session_id == session_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
class RiskDebateRepository(BaseRepository[RiskDebate]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, RiskDebate)
|
||||
|
||||
def get_by_session(self, session_id: str) -> RiskDebate | None:
|
||||
return (
|
||||
self.session.query(RiskDebate)
|
||||
.filter(RiskDebate.session_id == session_id)
|
||||
.first()
|
||||
)
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
from typing import Generic, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from tradingagents.database.base import Base
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
def __init__(self, session: Session, model_class: type[ModelType]):
|
||||
self.session = session
|
||||
self.model_class = model_class
|
||||
|
||||
def get(self, id: UUID | str | int) -> ModelType | None:
|
||||
return (
|
||||
self.session.query(self.model_class)
|
||||
.filter(self.model_class.id == id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_all(self, skip: int = 0, limit: int = 100) -> list[ModelType]:
|
||||
return self.session.query(self.model_class).offset(skip).limit(limit).all()
|
||||
|
||||
def create(self, obj_in: dict) -> ModelType:
|
||||
db_obj = self.model_class(**obj_in)
|
||||
self.session.add(db_obj)
|
||||
self.session.flush()
|
||||
return db_obj
|
||||
|
||||
def update(self, db_obj: ModelType, obj_in: dict) -> ModelType:
|
||||
for field, value in obj_in.items():
|
||||
setattr(db_obj, field, value)
|
||||
self.session.flush()
|
||||
return db_obj
|
||||
|
||||
def delete(self, id: UUID | str | int) -> bool:
|
||||
obj = self.get(id)
|
||||
if obj:
|
||||
self.session.delete(obj)
|
||||
return True
|
||||
return False
|
||||
|
||||
def count(self) -> int:
|
||||
return self.session.query(self.model_class).count()
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from tradingagents.database.models.market_data import (
|
||||
DataCache,
|
||||
FundamentalData,
|
||||
NewsArticle,
|
||||
SocialMediaPost,
|
||||
StockPrice,
|
||||
TechnicalIndicator,
|
||||
)
|
||||
from tradingagents.database.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class StockPriceRepository(BaseRepository[StockPrice]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, StockPrice)
|
||||
|
||||
def get_by_ticker_and_date(self, ticker: str, date: str) -> StockPrice | None:
|
||||
return (
|
||||
self.session.query(StockPrice)
|
||||
.filter(and_(StockPrice.ticker == ticker, StockPrice.date == date))
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_by_ticker_range(
|
||||
self, ticker: str, start_date: str, end_date: str
|
||||
) -> list[StockPrice]:
|
||||
return (
|
||||
self.session.query(StockPrice)
|
||||
.filter(
|
||||
and_(
|
||||
StockPrice.ticker == ticker,
|
||||
StockPrice.date >= start_date,
|
||||
StockPrice.date <= end_date,
|
||||
)
|
||||
)
|
||||
.order_by(StockPrice.date)
|
||||
.all()
|
||||
)
|
||||
|
||||
def upsert(self, data: dict) -> StockPrice:
|
||||
existing = self.get_by_ticker_and_date(data["ticker"], data["date"])
|
||||
if existing:
|
||||
return self.update(existing, data)
|
||||
return self.create(data)
|
||||
|
||||
|
||||
class TechnicalIndicatorRepository(BaseRepository[TechnicalIndicator]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, TechnicalIndicator)
|
||||
|
||||
def get_by_ticker_date_indicator(
|
||||
self, ticker: str, date: str, indicator_name: str
|
||||
) -> TechnicalIndicator | None:
|
||||
return (
|
||||
self.session.query(TechnicalIndicator)
|
||||
.filter(
|
||||
and_(
|
||||
TechnicalIndicator.ticker == ticker,
|
||||
TechnicalIndicator.date == date,
|
||||
TechnicalIndicator.indicator_name == indicator_name,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_by_ticker_and_date(
|
||||
self, ticker: str, date: str
|
||||
) -> list[TechnicalIndicator]:
|
||||
return (
|
||||
self.session.query(TechnicalIndicator)
|
||||
.filter(
|
||||
and_(
|
||||
TechnicalIndicator.ticker == ticker,
|
||||
TechnicalIndicator.date == date,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class NewsArticleRepository(BaseRepository[NewsArticle]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, NewsArticle)
|
||||
|
||||
def get_by_ticker(self, ticker: str, limit: int = 100) -> list[NewsArticle]:
|
||||
return (
|
||||
self.session.query(NewsArticle)
|
||||
.filter(NewsArticle.ticker == ticker)
|
||||
.order_by(NewsArticle.published_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_recent(self, hours: int = 24, limit: int = 100) -> list[NewsArticle]:
|
||||
cutoff = datetime.utcnow().timestamp() - (hours * 3600)
|
||||
return (
|
||||
self.session.query(NewsArticle)
|
||||
.filter(NewsArticle.published_at >= datetime.fromtimestamp(cutoff))
|
||||
.order_by(NewsArticle.published_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class SocialMediaPostRepository(BaseRepository[SocialMediaPost]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, SocialMediaPost)
|
||||
|
||||
def get_by_ticker(self, ticker: str, limit: int = 100) -> list[SocialMediaPost]:
|
||||
return (
|
||||
self.session.query(SocialMediaPost)
|
||||
.filter(SocialMediaPost.ticker == ticker)
|
||||
.order_by(SocialMediaPost.posted_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class FundamentalDataRepository(BaseRepository[FundamentalData]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, FundamentalData)
|
||||
|
||||
def get_by_ticker_and_metric(
|
||||
self, ticker: str, metric_name: str
|
||||
) -> FundamentalData | None:
|
||||
return (
|
||||
self.session.query(FundamentalData)
|
||||
.filter(
|
||||
and_(
|
||||
FundamentalData.ticker == ticker,
|
||||
FundamentalData.metric_name == metric_name,
|
||||
)
|
||||
)
|
||||
.order_by(FundamentalData.report_date.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_all_by_ticker(self, ticker: str) -> list[FundamentalData]:
|
||||
return (
|
||||
self.session.query(FundamentalData)
|
||||
.filter(FundamentalData.ticker == ticker)
|
||||
.order_by(FundamentalData.report_date.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class DataCacheRepository(BaseRepository[DataCache]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, DataCache)
|
||||
|
||||
def get_by_key(self, cache_key: str) -> DataCache | None:
|
||||
return (
|
||||
self.session.query(DataCache)
|
||||
.filter(DataCache.cache_key == cache_key)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_valid_cache(self, cache_key: str) -> DataCache | None:
|
||||
cache = self.get_by_key(cache_key)
|
||||
if cache and cache.expires_at and cache.expires_at > datetime.utcnow():
|
||||
return cache
|
||||
return None
|
||||
|
||||
def set_cache(
|
||||
self,
|
||||
cache_key: str,
|
||||
data_type: str,
|
||||
cached_data: str,
|
||||
expires_at: datetime | None = None,
|
||||
ticker: str | None = None,
|
||||
) -> DataCache:
|
||||
existing = self.get_by_key(cache_key)
|
||||
if existing:
|
||||
return self.update(
|
||||
existing,
|
||||
{
|
||||
"data_type": data_type,
|
||||
"cached_data": cached_data,
|
||||
"expires_at": expires_at,
|
||||
"ticker": ticker,
|
||||
},
|
||||
)
|
||||
return self.create(
|
||||
{
|
||||
"cache_key": cache_key,
|
||||
"data_type": data_type,
|
||||
"cached_data": cached_data,
|
||||
"expires_at": expires_at,
|
||||
"ticker": ticker,
|
||||
}
|
||||
)
|
||||
|
||||
def clear_expired(self) -> int:
|
||||
result = (
|
||||
self.session.query(DataCache)
|
||||
.filter(DataCache.expires_at < datetime.utcnow())
|
||||
.delete()
|
||||
)
|
||||
return result
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from tradingagents.database.models.trading import (
|
||||
TradeExecution,
|
||||
TradeReflection,
|
||||
TradingDecision,
|
||||
)
|
||||
from tradingagents.database.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class TradingDecisionRepository(BaseRepository[TradingDecision]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, TradingDecision)
|
||||
|
||||
def get_by_session(self, session_id: str) -> TradingDecision | None:
|
||||
return (
|
||||
self.session.query(TradingDecision)
|
||||
.filter(TradingDecision.session_id == session_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_by_ticker(self, ticker: str, limit: int = 100) -> list[TradingDecision]:
|
||||
return (
|
||||
self.session.query(TradingDecision)
|
||||
.filter(TradingDecision.ticker == ticker)
|
||||
.order_by(TradingDecision.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_by_decision_type(
|
||||
self, decision: str, limit: int = 100
|
||||
) -> list[TradingDecision]:
|
||||
return (
|
||||
self.session.query(TradingDecision)
|
||||
.filter(TradingDecision.decision == decision)
|
||||
.order_by(TradingDecision.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class TradeExecutionRepository(BaseRepository[TradeExecution]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, TradeExecution)
|
||||
|
||||
def get_by_decision(self, decision_id: str) -> TradeExecution | None:
|
||||
return (
|
||||
self.session.query(TradeExecution)
|
||||
.filter(TradeExecution.decision_id == decision_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_pending(self) -> list[TradeExecution]:
|
||||
return (
|
||||
self.session.query(TradeExecution)
|
||||
.filter(TradeExecution.status == "pending")
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_by_ticker(self, ticker: str, limit: int = 100) -> list[TradeExecution]:
|
||||
return (
|
||||
self.session.query(TradeExecution)
|
||||
.filter(TradeExecution.ticker == ticker)
|
||||
.order_by(TradeExecution.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class TradeReflectionRepository(BaseRepository[TradeReflection]):
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, TradeReflection)
|
||||
|
||||
def get_by_ticker(self, ticker: str, limit: int = 100) -> list[TradeReflection]:
|
||||
return (
|
||||
self.session.query(TradeReflection)
|
||||
.filter(TradeReflection.ticker == ticker)
|
||||
.order_by(TradeReflection.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_by_ticker_and_date(
|
||||
self, ticker: str, trade_date: str
|
||||
) -> TradeReflection | None:
|
||||
return (
|
||||
self.session.query(TradeReflection)
|
||||
.filter(
|
||||
and_(
|
||||
TradeReflection.ticker == ticker,
|
||||
TradeReflection.trade_date == trade_date,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
|
@ -1,5 +1,10 @@
|
|||
# Import functions from specialized modules
|
||||
from .alpha_vantage_stock import get_stock
|
||||
from .alpha_vantage_fundamentals import (
|
||||
get_balance_sheet,
|
||||
get_cashflow,
|
||||
get_fundamentals,
|
||||
get_income_statement,
|
||||
)
|
||||
from .alpha_vantage_indicator import get_indicator
|
||||
from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
|
||||
from .alpha_vantage_news import get_news, get_insider_transactions
|
||||
from .alpha_vantage_news import get_insider_transactions, get_news
|
||||
from .alpha_vantage_stock import get_stock
|
||||
|
|
|
|||
|
|
@ -1,18 +1,21 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
import pandas as pd
|
||||
import json
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
|
||||
import pandas as pd
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
API_BASE_URL = "https://www.alphavantage.co/query"
|
||||
|
||||
|
||||
def get_api_key() -> str:
|
||||
try:
|
||||
from tradingagents.config import get_settings
|
||||
|
||||
return get_settings().require_api_key("alpha_vantage")
|
||||
except ImportError:
|
||||
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
||||
|
|
@ -20,9 +23,10 @@ def get_api_key() -> str:
|
|||
raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
|
||||
return api_key
|
||||
|
||||
|
||||
def format_datetime_for_api(date_input) -> str:
|
||||
if isinstance(date_input, str):
|
||||
if len(date_input) == 13 and 'T' in date_input:
|
||||
if len(date_input) == 13 and "T" in date_input:
|
||||
return date_input
|
||||
try:
|
||||
dt = datetime.strptime(date_input, "%Y-%m-%d")
|
||||
|
|
@ -36,20 +40,26 @@ def format_datetime_for_api(date_input) -> str:
|
|||
elif isinstance(date_input, datetime):
|
||||
return date_input.strftime("%Y%m%dT%H%M")
|
||||
else:
|
||||
raise ValueError(f"Date must be string or datetime object, got {type(date_input)}")
|
||||
raise ValueError(
|
||||
f"Date must be string or datetime object, got {type(date_input)}"
|
||||
)
|
||||
|
||||
|
||||
class AlphaVantageRateLimitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _make_api_request(function_name: str, params: dict) -> dict | str:
|
||||
api_params = params.copy()
|
||||
api_params.update({
|
||||
"function": function_name,
|
||||
"apikey": get_api_key(),
|
||||
"source": "trading_agents",
|
||||
})
|
||||
api_params.update(
|
||||
{
|
||||
"function": function_name,
|
||||
"apikey": get_api_key(),
|
||||
"source": "trading_agents",
|
||||
}
|
||||
)
|
||||
|
||||
current_entitlement = globals().get('_current_entitlement')
|
||||
current_entitlement = globals().get("_current_entitlement")
|
||||
entitlement = api_params.get("entitlement") or current_entitlement
|
||||
|
||||
if entitlement:
|
||||
|
|
@ -66,15 +76,19 @@ def _make_api_request(function_name: str, params: dict) -> dict | str:
|
|||
response_json = json.loads(response_text)
|
||||
if "Information" in response_json:
|
||||
info_message = response_json["Information"]
|
||||
if "rate limit" in info_message.lower() or "api key" in info_message.lower():
|
||||
raise AlphaVantageRateLimitError(f"Alpha Vantage rate limit exceeded: {info_message}")
|
||||
if (
|
||||
"rate limit" in info_message.lower()
|
||||
or "api key" in info_message.lower()
|
||||
):
|
||||
raise AlphaVantageRateLimitError(
|
||||
f"Alpha Vantage rate limit exceeded: {info_message}"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
|
||||
def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> str:
|
||||
if not csv_data or csv_data.strip() == "":
|
||||
return csv_data
|
||||
|
|
|
|||
|
|
@ -19,7 +19,9 @@ def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
|||
return _make_api_request("OVERVIEW", params)
|
||||
|
||||
|
||||
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
||||
def get_balance_sheet(
|
||||
ticker: str, freq: str = "quarterly", curr_date: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve balance sheet data for a given ticker symbol using Alpha Vantage.
|
||||
|
||||
|
|
@ -57,7 +59,9 @@ def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) ->
|
|||
return _make_api_request("CASH_FLOW", params)
|
||||
|
||||
|
||||
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
||||
def get_income_statement(
|
||||
ticker: str, freq: str = "quarterly", curr_date: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve income statement data for a given ticker symbol using Alpha Vantage.
|
||||
|
||||
|
|
@ -74,4 +78,3 @@ def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str =
|
|||
}
|
||||
|
||||
return _make_api_request("INCOME_STATEMENT", params)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import logging
|
||||
|
||||
import requests
|
||||
|
||||
from .alpha_vantage_common import _make_api_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_indicator(
|
||||
symbol: str,
|
||||
indicator: str,
|
||||
|
|
@ -11,9 +14,10 @@ def get_indicator(
|
|||
look_back_days: int,
|
||||
interval: str = "daily",
|
||||
time_period: int = 14,
|
||||
series_type: str = "close"
|
||||
series_type: str = "close",
|
||||
) -> str:
|
||||
from datetime import datetime
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
supported_indicators = {
|
||||
|
|
@ -28,7 +32,7 @@ def get_indicator(
|
|||
"boll_ub": ("Bollinger Upper Band", "close"),
|
||||
"boll_lb": ("Bollinger Lower Band", "close"),
|
||||
"atr": ("ATR", None),
|
||||
"vwma": ("VWMA", "close")
|
||||
"vwma": ("VWMA", "close"),
|
||||
}
|
||||
|
||||
indicator_descriptions = {
|
||||
|
|
@ -43,7 +47,7 @@ def get_indicator(
|
|||
"boll_ub": "Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends.",
|
||||
"boll_lb": "Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals.",
|
||||
"atr": "ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy.",
|
||||
"vwma": "VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses."
|
||||
"vwma": "VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.",
|
||||
}
|
||||
|
||||
if indicator not in supported_indicators:
|
||||
|
|
@ -61,93 +65,107 @@ def get_indicator(
|
|||
|
||||
try:
|
||||
if indicator == "close_50_sma":
|
||||
data = _make_api_request("SMA", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": "50",
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
data = _make_api_request(
|
||||
"SMA",
|
||||
{
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": "50",
|
||||
"series_type": series_type,
|
||||
"datatype": "csv",
|
||||
},
|
||||
)
|
||||
elif indicator == "close_200_sma":
|
||||
data = _make_api_request("SMA", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": "200",
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
data = _make_api_request(
|
||||
"SMA",
|
||||
{
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": "200",
|
||||
"series_type": series_type,
|
||||
"datatype": "csv",
|
||||
},
|
||||
)
|
||||
elif indicator == "close_10_ema":
|
||||
data = _make_api_request("EMA", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": "10",
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
elif indicator == "macd":
|
||||
data = _make_api_request("MACD", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
elif indicator == "macds":
|
||||
data = _make_api_request("MACD", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
elif indicator == "macdh":
|
||||
data = _make_api_request("MACD", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
data = _make_api_request(
|
||||
"EMA",
|
||||
{
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": "10",
|
||||
"series_type": series_type,
|
||||
"datatype": "csv",
|
||||
},
|
||||
)
|
||||
elif indicator == "macd" or indicator == "macds" or indicator == "macdh":
|
||||
data = _make_api_request(
|
||||
"MACD",
|
||||
{
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"series_type": series_type,
|
||||
"datatype": "csv",
|
||||
},
|
||||
)
|
||||
elif indicator == "rsi":
|
||||
data = _make_api_request("RSI", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": str(time_period),
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
data = _make_api_request(
|
||||
"RSI",
|
||||
{
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": str(time_period),
|
||||
"series_type": series_type,
|
||||
"datatype": "csv",
|
||||
},
|
||||
)
|
||||
elif indicator in ["boll", "boll_ub", "boll_lb"]:
|
||||
data = _make_api_request("BBANDS", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": "20",
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
data = _make_api_request(
|
||||
"BBANDS",
|
||||
{
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": "20",
|
||||
"series_type": series_type,
|
||||
"datatype": "csv",
|
||||
},
|
||||
)
|
||||
elif indicator == "atr":
|
||||
data = _make_api_request("ATR", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": str(time_period),
|
||||
"datatype": "csv"
|
||||
})
|
||||
data = _make_api_request(
|
||||
"ATR",
|
||||
{
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"time_period": str(time_period),
|
||||
"datatype": "csv",
|
||||
},
|
||||
)
|
||||
elif indicator == "vwma":
|
||||
return f"## VWMA (Volume Weighted Moving Average) for {symbol}:\n\nVWMA calculation requires OHLCV data and is not directly available from Alpha Vantage API.\nThis indicator would need to be calculated from the raw stock data using volume-weighted price averaging.\n\n{indicator_descriptions.get('vwma', 'No description available.')}"
|
||||
else:
|
||||
return f"Error: Indicator {indicator} not implemented yet."
|
||||
|
||||
lines = data.strip().split('\n')
|
||||
lines = data.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return f"Error: No data returned for {indicator}"
|
||||
|
||||
header = [col.strip() for col in lines[0].split(',')]
|
||||
header = [col.strip() for col in lines[0].split(",")]
|
||||
try:
|
||||
date_col_idx = header.index('time')
|
||||
date_col_idx = header.index("time")
|
||||
except ValueError:
|
||||
return f"Error: 'time' column not found in data for {indicator}. Available columns: {header}"
|
||||
|
||||
col_name_map = {
|
||||
"macd": "MACD", "macds": "MACD_Signal", "macdh": "MACD_Hist",
|
||||
"boll": "Real Middle Band", "boll_ub": "Real Upper Band", "boll_lb": "Real Lower Band",
|
||||
"rsi": "RSI", "atr": "ATR", "close_10_ema": "EMA",
|
||||
"close_50_sma": "SMA", "close_200_sma": "SMA"
|
||||
"macd": "MACD",
|
||||
"macds": "MACD_Signal",
|
||||
"macdh": "MACD_Hist",
|
||||
"boll": "Real Middle Band",
|
||||
"boll_ub": "Real Upper Band",
|
||||
"boll_lb": "Real Lower Band",
|
||||
"rsi": "RSI",
|
||||
"atr": "ATR",
|
||||
"close_10_ema": "EMA",
|
||||
"close_50_sma": "SMA",
|
||||
"close_200_sma": "SMA",
|
||||
}
|
||||
|
||||
target_col_name = col_name_map.get(indicator)
|
||||
|
|
@ -164,7 +182,7 @@ def get_indicator(
|
|||
for line in lines[1:]:
|
||||
if not line.strip():
|
||||
continue
|
||||
values = line.split(',')
|
||||
values = line.split(",")
|
||||
if len(values) > value_col_idx:
|
||||
try:
|
||||
date_str = values[date_col_idx].strip()
|
||||
|
|
@ -195,5 +213,7 @@ def get_indicator(
|
|||
return result_str
|
||||
|
||||
except (ValueError, KeyError, IndexError, requests.RequestException) as e:
|
||||
logger.error("Error getting Alpha Vantage indicator data for %s: %s", indicator, e)
|
||||
logger.error(
|
||||
"Error getting Alpha Vantage indicator data for %s: %s", indicator, e
|
||||
)
|
||||
return f"Error retrieving {indicator} data: {str(e)}"
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
||||
params = {
|
||||
"tickers": ticker,
|
||||
|
|
@ -17,6 +19,7 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
|||
|
||||
return _make_api_request("NEWS_SENTIMENT", params)
|
||||
|
||||
|
||||
def get_insider_transactions(symbol: str) -> dict[str, str] | str:
|
||||
params = {
|
||||
"symbol": symbol,
|
||||
|
|
@ -25,7 +28,7 @@ def get_insider_transactions(symbol: str) -> dict[str, str] | str:
|
|||
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
||||
|
||||
|
||||
def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]:
|
||||
def get_bulk_news_alpha_vantage(lookback_hours: int) -> list[dict[str, Any]]:
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(hours=lookback_hours)
|
||||
|
||||
|
|
@ -55,7 +58,9 @@ def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]:
|
|||
|
||||
feed = response.get("feed", [])
|
||||
if not feed:
|
||||
logger.debug("Alpha Vantage feed empty. Keys in response: %s", list(response.keys()))
|
||||
logger.debug(
|
||||
"Alpha Vantage feed empty. Keys in response: %s", list(response.keys())
|
||||
)
|
||||
|
||||
articles = []
|
||||
for item in feed:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
from datetime import datetime
|
||||
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
|
||||
|
||||
def get_stock(
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> str:
|
||||
from .alpha_vantage_common import _filter_csv_by_date_range, _make_api_request
|
||||
|
||||
|
||||
def get_stock(symbol: str, start_date: str, end_date: str) -> str:
|
||||
"""
|
||||
Returns raw daily OHLCV values, adjusted close values, and historical split/dividend events
|
||||
filtered to the specified date range.
|
||||
|
|
@ -35,4 +33,4 @@ def get_stock(
|
|||
|
||||
response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params)
|
||||
|
||||
return _filter_csv_by_date_range(response, start_date, end_date)
|
||||
return _filter_csv_by_date_range(response, start_date, end_date)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -16,6 +17,7 @@ RETRY_BACKOFF = 1.0
|
|||
def get_api_key() -> str:
|
||||
try:
|
||||
from tradingagents.config import get_settings
|
||||
|
||||
return get_settings().require_api_key("brave")
|
||||
except ImportError:
|
||||
api_key = os.getenv("BRAVE_API_KEY")
|
||||
|
|
@ -24,14 +26,25 @@ def get_api_key() -> str:
|
|||
return api_key
|
||||
|
||||
|
||||
def _make_request_with_retry(url: str, headers: Dict, params: Dict, max_retries: int = MAX_RETRIES) -> requests.Response:
|
||||
def _make_request_with_retry(
|
||||
url: str, headers: dict, params: dict, max_retries: int = MAX_RETRIES
|
||||
) -> requests.Response:
|
||||
last_exception = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = requests.get(url, headers=headers, params=params, timeout=DEFAULT_TIMEOUT)
|
||||
response = requests.get(
|
||||
url, headers=headers, params=params, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
if response.status_code == 429:
|
||||
retry_after = int(response.headers.get("Retry-After", RETRY_BACKOFF * (attempt + 1)))
|
||||
logger.debug("Brave rate limited, waiting %ds before retry %d/%d", retry_after, attempt + 1, max_retries)
|
||||
retry_after = int(
|
||||
response.headers.get("Retry-After", RETRY_BACKOFF * (attempt + 1))
|
||||
)
|
||||
logger.debug(
|
||||
"Brave rate limited, waiting %ds before retry %d/%d",
|
||||
retry_after,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
)
|
||||
time.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
|
|
@ -42,19 +55,30 @@ def _make_request_with_retry(url: str, headers: Dict, params: Dict, max_retries:
|
|||
time.sleep(RETRY_BACKOFF * (attempt + 1))
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
last_exception = e
|
||||
logger.debug("Brave connection error, retry %d/%d", attempt + 1, max_retries)
|
||||
logger.debug(
|
||||
"Brave connection error, retry %d/%d", attempt + 1, max_retries
|
||||
)
|
||||
time.sleep(RETRY_BACKOFF * (attempt + 1))
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response is not None and e.response.status_code >= 500:
|
||||
last_exception = e
|
||||
logger.debug("Brave server error %d, retry %d/%d", e.response.status_code, attempt + 1, max_retries)
|
||||
logger.debug(
|
||||
"Brave server error %d, retry %d/%d",
|
||||
e.response.status_code,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
)
|
||||
time.sleep(RETRY_BACKOFF * (attempt + 1))
|
||||
else:
|
||||
raise
|
||||
raise last_exception if last_exception else requests.exceptions.RequestException("Max retries exceeded")
|
||||
raise (
|
||||
last_exception
|
||||
if last_exception
|
||||
else requests.exceptions.RequestException("Max retries exceeded")
|
||||
)
|
||||
|
||||
|
||||
def get_bulk_news_brave(lookback_hours: int) -> List[Dict[str, Any]]:
|
||||
def get_bulk_news_brave(lookback_hours: int) -> list[dict[str, Any]]:
|
||||
try:
|
||||
api_key = get_api_key()
|
||||
except ValueError as e:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Dict, Optional
|
||||
|
||||
from tradingagents.config import get_settings, update_settings
|
||||
|
||||
_config: Optional[Dict] = None
|
||||
DATA_DIR: Optional[str] = None
|
||||
_config: dict | None = None
|
||||
DATA_DIR: str | None = None
|
||||
|
||||
|
||||
def initialize_config():
|
||||
|
|
@ -13,7 +14,7 @@ def initialize_config():
|
|||
DATA_DIR = _config["data_dir"]
|
||||
|
||||
|
||||
def set_config(config: Dict):
|
||||
def set_config(config: dict):
|
||||
global _config, DATA_DIR
|
||||
|
||||
settings = get_settings()
|
||||
|
|
@ -25,7 +26,7 @@ def set_config(config: Dict):
|
|||
DATA_DIR = _config["data_dir"]
|
||||
|
||||
|
||||
def get_config() -> Dict:
|
||||
def get_config() -> dict:
|
||||
global _config
|
||||
if _config is None:
|
||||
initialize_config()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import logging
|
||||
import re
|
||||
import requests
|
||||
from typing import Annotated, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from typing import Annotated, Any, Dict, List
|
||||
|
||||
import requests
|
||||
from dateutil import parser as dateutil_parser
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
from .googlenews_utils import getNewsData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -75,7 +77,7 @@ def get_google_news(
|
|||
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
|
||||
|
||||
|
||||
def get_bulk_news_google(lookback_hours: int) -> List[Dict[str, Any]]:
|
||||
def get_bulk_news_google(lookback_hours: int) -> list[dict[str, Any]]:
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(hours=lookback_hours)
|
||||
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue