Fix remaining ruff linting errors

- Fixed all F-type errors (undefined names, unused imports)
- Applied automatic fixes for code style issues
- Ensured CI/CD pipeline passes all checks
This commit is contained in:
佐藤優一 2025-08-10 23:13:31 +09:00
parent 4361ed19e4
commit 6f3981412b
41 changed files with 570 additions and 623 deletions

View File

@ -1,31 +1,32 @@
import datetime import datetime
import typer
from pathlib import Path
from functools import wraps
from rich.console import Console
from rich.panel import Panel
from rich.spinner import Spinner
from rich.live import Live
from rich.columns import Columns
from rich.markdown import Markdown
from rich.layout import Layout
from rich.text import Text
from rich.table import Table
from collections import deque from collections import deque
from functools import wraps
from pathlib import Path
import typer
from rich import box from rich import box
from rich.align import Align from rich.align import Align
from rich.columns import Columns
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.spinner import Spinner
from rich.table import Table
from rich.text import Text
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.utils import ( from cli.utils import (
get_ticker,
get_analysis_date, get_analysis_date,
get_ticker,
select_analysts, select_analysts,
select_research_depth,
select_shallow_thinking_agent,
select_deep_thinking_agent, select_deep_thinking_agent,
select_llm_provider, select_llm_provider,
select_research_depth,
select_shallow_thinking_agent,
) )
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
console = Console() console = Console()
@ -136,19 +137,20 @@ class MessageBuffer:
report_parts.append("## Analyst Team Reports") report_parts.append("## Analyst Team Reports")
if self.report_sections["market_report"]: if self.report_sections["market_report"]:
report_parts.append( report_parts.append(
f"### Market Analysis\n{self.report_sections['market_report']}" f"### Market Analysis\n{self.report_sections['market_report']}",
) )
if self.report_sections["sentiment_report"]: if self.report_sections["sentiment_report"]:
report_parts.append( report_parts.append(
f"### Social Sentiment\n{self.report_sections['sentiment_report']}" f"### Social Sentiment\n{self.report_sections['sentiment_report']}",
) )
if self.report_sections["news_report"]: if self.report_sections["news_report"]:
report_parts.append( report_parts.append(
f"### News Analysis\n{self.report_sections['news_report']}" f"### News Analysis\n{self.report_sections['news_report']}",
) )
if self.report_sections["fundamentals_report"]: if self.report_sections["fundamentals_report"]:
fundamentals = self.report_sections['fundamentals_report']
report_parts.append( report_parts.append(
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}" f"### Fundamentals Analysis\n{fundamentals}",
) )
# Research Team Reports # Research Team Reports
@ -180,10 +182,10 @@ def create_layout():
Layout(name="footer", size=3), Layout(name="footer", size=3),
) )
layout["main"].split_column( layout["main"].split_column(
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5) Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5),
) )
layout["upper"].split_row( layout["upper"].split_row(
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3) Layout(name="progress", ratio=2), Layout(name="messages", ratio=3),
) )
return layout return layout
@ -198,7 +200,7 @@ def update_display(layout, spinner_text=None):
border_style="green", border_style="green",
padding=(1, 2), padding=(1, 2),
expand=True, expand=True,
) ),
) )
# Progress panel showing agent status # Progress panel showing agent status
@ -235,7 +237,7 @@ def update_display(layout, spinner_text=None):
status = message_buffer.agent_status[first_agent] status = message_buffer.agent_status[first_agent]
if status == "in_progress": if status == "in_progress":
spinner = Spinner( spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan" "dots", text="[blue]in_progress[/blue]", style="bold cyan",
) )
status_cell = spinner status_cell = spinner
else: else:
@ -252,7 +254,7 @@ def update_display(layout, spinner_text=None):
status = message_buffer.agent_status[agent] status = message_buffer.agent_status[agent]
if status == "in_progress": if status == "in_progress":
spinner = Spinner( spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan" "dots", text="[blue]in_progress[/blue]", style="bold cyan",
) )
status_cell = spinner status_cell = spinner
else: else:
@ -268,7 +270,7 @@ def update_display(layout, spinner_text=None):
progress_table.add_row("" * 20, "" * 20, "" * 20, style="dim") progress_table.add_row("" * 20, "" * 20, "" * 20, style="dim")
layout["progress"].update( layout["progress"].update(
Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2)) Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2)),
) )
# Messages panel showing recent messages and tool calls # Messages panel showing recent messages and tool calls
@ -284,7 +286,7 @@ def update_display(layout, spinner_text=None):
messages_table.add_column("Time", style="cyan", width=8, justify="center") messages_table.add_column("Time", style="cyan", width=8, justify="center")
messages_table.add_column("Type", style="green", width=10, justify="center") messages_table.add_column("Type", style="green", width=10, justify="center")
messages_table.add_column( messages_table.add_column(
"Content", style="white", no_wrap=False, ratio=1 "Content", style="white", no_wrap=False, ratio=1,
) # Make content column expand ) # Make content column expand
# Combine tool calls and messages # Combine tool calls and messages
@ -352,7 +354,7 @@ def update_display(layout, spinner_text=None):
title="Messages & Tools", title="Messages & Tools",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
# Analysis panel showing current report # Analysis panel showing current report
@ -363,7 +365,7 @@ def update_display(layout, spinner_text=None):
title="Current Report", title="Current Report",
border_style="green", border_style="green",
padding=(1, 2), padding=(1, 2),
) ),
) )
else: else:
layout["analysis"].update( layout["analysis"].update(
@ -372,7 +374,7 @@ def update_display(layout, spinner_text=None):
title="Current Report", title="Current Report",
border_style="green", border_style="green",
padding=(1, 2), padding=(1, 2),
) ),
) )
# Footer with statistics # Footer with statistics
@ -386,9 +388,12 @@ def update_display(layout, spinner_text=None):
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True) stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
stats_table.add_column("Stats", justify="center") stats_table.add_column("Stats", justify="center")
stats_table.add_row( stats_text = (
f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}" f"Tool Calls: {tool_calls_count} | "
f"LLM Calls: {llm_calls_count} | "
f"Generated Reports: {reports_count}"
) )
stats_table.add_row(stats_text)
layout["footer"].update(Panel(stats_table, border_style="grey50")) layout["footer"].update(Panel(stats_table, border_style="grey50"))
@ -396,14 +401,20 @@ def update_display(layout, spinner_text=None):
def get_user_selections(): def get_user_selections():
"""Get all user selections before starting the analysis display.""" """Get all user selections before starting the analysis display."""
# Display ASCII art welcome message # Display ASCII art welcome message
with open("./cli/static/welcome.txt", "r") as f: with open("./cli/static/welcome.txt") as f:
welcome_ascii = f.read() welcome_ascii = f.read()
# Create welcome box content # Create welcome box content
welcome_content = f"{welcome_ascii}\n" welcome_content = f"{welcome_ascii}\n"
welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n" welcome_content += (
"[bold green]TradingAgents: "
"Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
)
welcome_content += "[bold]Workflow Steps:[/bold]\n" welcome_content += "[bold]Workflow Steps:[/bold]\n"
welcome_content += "I. Analyst Team → II. Research Team → III. Trader → IV. Risk Management → V. Portfolio Management\n\n" welcome_content += (
"I. Analyst Team → II. Research Team → III. Trader → "
"IV. Risk Management → V. Portfolio Management\n\n"
)
welcome_content += ( welcome_content += (
"[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]" "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]"
) )
@ -430,8 +441,8 @@ def get_user_selections():
# Step 1: Ticker symbol # Step 1: Ticker symbol
console.print( console.print(
create_question_box( create_question_box(
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY" "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY",
) ),
) )
selected_ticker = get_ticker() selected_ticker = get_ticker()
@ -442,40 +453,40 @@ def get_user_selections():
"Step 2: Analysis Date", "Step 2: Analysis Date",
"Enter the analysis date (YYYY-MM-DD)", "Enter the analysis date (YYYY-MM-DD)",
default_date, default_date,
) ),
) )
analysis_date = get_analysis_date() analysis_date = get_analysis_date()
# Step 3: Select analysts # Step 3: Select analysts
console.print( console.print(
create_question_box( create_question_box(
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis" "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis",
) ),
) )
selected_analysts = select_analysts() selected_analysts = select_analysts()
console.print( console.print(
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}" f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}",
) )
# Step 4: Research depth # Step 4: Research depth
console.print( console.print(
create_question_box( create_question_box(
"Step 4: Research Depth", "Select your research depth level" "Step 4: Research Depth", "Select your research depth level",
) ),
) )
selected_research_depth = select_research_depth() selected_research_depth = select_research_depth()
# Step 5: OpenAI backend # Step 5: OpenAI backend
console.print( console.print(
create_question_box("Step 5: OpenAI backend", "Select which service to talk to") create_question_box("Step 5: OpenAI backend", "Select which service to talk to"),
) )
selected_llm_provider, backend_url = select_llm_provider() selected_llm_provider, backend_url = select_llm_provider()
# Step 6: Thinking agents # Step 6: Thinking agents
console.print( console.print(
create_question_box( create_question_box(
"Step 6: Thinking Agents", "Select your thinking agents for analysis" "Step 6: Thinking Agents", "Select your thinking agents for analysis",
) ),
) )
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
@ -492,28 +503,7 @@ def get_user_selections():
} }
def get_ticker(): # Functions get_ticker and get_analysis_date are imported from cli.utils
"""Get ticker symbol from user input."""
return typer.prompt("", default="SPY")
def get_analysis_date():
"""Get the analysis date from user input."""
while True:
date_str = typer.prompt(
"", default=datetime.datetime.now().strftime("%Y-%m-%d")
)
try:
# Validate date format and ensure it's not in the future
analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d")
if analysis_date.date() > datetime.datetime.now().date():
console.print("[red]Error: Analysis date cannot be in the future[/red]")
continue
return date_str
except ValueError:
console.print(
"[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
)
def display_complete_report(final_state): def display_complete_report(final_state):
@ -531,7 +521,7 @@ def display_complete_report(final_state):
title="Market Analyst", title="Market Analyst",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
# Social Analyst Report # Social Analyst Report
@ -542,7 +532,7 @@ def display_complete_report(final_state):
title="Social Analyst", title="Social Analyst",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
# News Analyst Report # News Analyst Report
@ -553,7 +543,7 @@ def display_complete_report(final_state):
title="News Analyst", title="News Analyst",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
# Fundamentals Analyst Report # Fundamentals Analyst Report
@ -564,7 +554,7 @@ def display_complete_report(final_state):
title="Fundamentals Analyst", title="Fundamentals Analyst",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
if analyst_reports: if analyst_reports:
@ -574,7 +564,7 @@ def display_complete_report(final_state):
title="I. Analyst Team Reports", title="I. Analyst Team Reports",
border_style="cyan", border_style="cyan",
padding=(1, 2), padding=(1, 2),
) ),
) )
# II. Research Team Reports # II. Research Team Reports
@ -590,7 +580,7 @@ def display_complete_report(final_state):
title="Bull Researcher", title="Bull Researcher",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
# Bear Researcher Analysis # Bear Researcher Analysis
@ -601,7 +591,7 @@ def display_complete_report(final_state):
title="Bear Researcher", title="Bear Researcher",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
# Research Manager Decision # Research Manager Decision
@ -612,7 +602,7 @@ def display_complete_report(final_state):
title="Research Manager", title="Research Manager",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
if research_reports: if research_reports:
@ -622,7 +612,7 @@ def display_complete_report(final_state):
title="II. Research Team Decision", title="II. Research Team Decision",
border_style="magenta", border_style="magenta",
padding=(1, 2), padding=(1, 2),
) ),
) )
# III. Trading Team Reports # III. Trading Team Reports
@ -638,7 +628,7 @@ def display_complete_report(final_state):
title="III. Trading Team Plan", title="III. Trading Team Plan",
border_style="yellow", border_style="yellow",
padding=(1, 2), padding=(1, 2),
) ),
) )
# IV. Risk Management Team Reports # IV. Risk Management Team Reports
@ -654,7 +644,7 @@ def display_complete_report(final_state):
title="Aggressive Analyst", title="Aggressive Analyst",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
# Conservative (Safe) Analyst Analysis # Conservative (Safe) Analyst Analysis
@ -665,7 +655,7 @@ def display_complete_report(final_state):
title="Conservative Analyst", title="Conservative Analyst",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
# Neutral Analyst Analysis # Neutral Analyst Analysis
@ -676,7 +666,7 @@ def display_complete_report(final_state):
title="Neutral Analyst", title="Neutral Analyst",
border_style="blue", border_style="blue",
padding=(1, 2), padding=(1, 2),
) ),
) )
if risk_reports: if risk_reports:
@ -686,7 +676,7 @@ def display_complete_report(final_state):
title="IV. Risk Management Team Decision", title="IV. Risk Management Team Decision",
border_style="red", border_style="red",
padding=(1, 2), padding=(1, 2),
) ),
) )
# V. Portfolio Manager Decision # V. Portfolio Manager Decision
@ -702,7 +692,7 @@ def display_complete_report(final_state):
title="V. Portfolio Manager Decision", title="V. Portfolio Manager Decision",
border_style="green", border_style="green",
padding=(1, 2), padding=(1, 2),
) ),
) )
@ -717,7 +707,7 @@ def extract_content_string(content):
"""Extract string content from various message formats.""" """Extract string content from various message formats."""
if isinstance(content, str): if isinstance(content, str):
return content return content
elif isinstance(content, list): if isinstance(content, list):
# Handle Anthropic's list format # Handle Anthropic's list format
text_parts = [] text_parts = []
for item in content: for item in content:
@ -729,8 +719,7 @@ def extract_content_string(content):
else: else:
text_parts.append(str(item)) text_parts.append(str(item))
return " ".join(text_parts) return " ".join(text_parts)
else: return str(content)
return str(content)
def run_analysis(): def run_analysis():
@ -748,7 +737,7 @@ def run_analysis():
# Initialize the graph # Initialize the graph
graph = TradingAgentsGraph( graph = TradingAgentsGraph(
[analyst.value for analyst in selections["analysts"]], config=config, debug=True [analyst.value for analyst in selections["analysts"]], config=config, debug=True,
) )
# Create result directory # Create result directory
@ -807,23 +796,23 @@ def run_analysis():
message_buffer.add_message = save_message_decorator(message_buffer, "add_message") message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
message_buffer.add_tool_call = save_tool_call_decorator( message_buffer.add_tool_call = save_tool_call_decorator(
message_buffer, "add_tool_call" message_buffer, "add_tool_call",
) )
message_buffer.update_report_section = save_report_section_decorator( message_buffer.update_report_section = save_report_section_decorator(
message_buffer, "update_report_section" message_buffer, "update_report_section",
) )
# Now start the display layout # Now start the display layout
layout = create_layout() layout = create_layout()
with Live(layout, refresh_per_second=4) as live: with Live(layout, refresh_per_second=4):
# Initial display # Initial display
update_display(layout) update_display(layout)
# Add initial messages # Add initial messages
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}") message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
message_buffer.add_message( message_buffer.add_message(
"System", f"Analysis date: {selections['analysis_date']}" "System", f"Analysis date: {selections['analysis_date']}",
) )
message_buffer.add_message( message_buffer.add_message(
"System", "System",
@ -854,7 +843,7 @@ def run_analysis():
# Initialize state and get graph args # Initialize state and get graph args
init_agent_state = graph.propagator.create_initial_state( init_agent_state = graph.propagator.create_initial_state(
selections["ticker"], selections["analysis_date"] selections["ticker"], selections["analysis_date"],
) )
args = graph.propagator.get_graph_args() args = graph.propagator.get_graph_args()
@ -868,7 +857,7 @@ def run_analysis():
# Extract message content and type # Extract message content and type
if hasattr(last_message, "content"): if hasattr(last_message, "content"):
content = extract_content_string( content = extract_content_string(
last_message.content last_message.content,
) # Use the helper function ) # Use the helper function
msg_type = "Reasoning" msg_type = "Reasoning"
else: else:
@ -884,65 +873,64 @@ def run_analysis():
# Handle both dictionary and object tool calls # Handle both dictionary and object tool calls
if isinstance(tool_call, dict): if isinstance(tool_call, dict):
message_buffer.add_tool_call( message_buffer.add_tool_call(
tool_call["name"], tool_call["args"] tool_call["name"], tool_call["args"],
) )
else: else:
message_buffer.add_tool_call(tool_call.name, tool_call.args) message_buffer.add_tool_call(tool_call.name, tool_call.args)
# Update reports and agent status based on chunk content # Update reports and agent status based on chunk content
# Analyst Team Reports # Analyst Team Reports
if "market_report" in chunk and chunk["market_report"]: if chunk.get("market_report"):
message_buffer.update_report_section( message_buffer.update_report_section(
"market_report", chunk["market_report"] "market_report", chunk["market_report"],
) )
message_buffer.update_agent_status("Market Analyst", "completed") message_buffer.update_agent_status("Market Analyst", "completed")
# Set next analyst to in_progress # Set next analyst to in_progress
if "social" in selections["analysts"]: if "social" in selections["analysts"]:
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Social Analyst", "in_progress" "Social Analyst", "in_progress",
) )
if "sentiment_report" in chunk and chunk["sentiment_report"]: if chunk.get("sentiment_report"):
message_buffer.update_report_section( message_buffer.update_report_section(
"sentiment_report", chunk["sentiment_report"] "sentiment_report", chunk["sentiment_report"],
) )
message_buffer.update_agent_status("Social Analyst", "completed") message_buffer.update_agent_status("Social Analyst", "completed")
# Set next analyst to in_progress # Set next analyst to in_progress
if "news" in selections["analysts"]: if "news" in selections["analysts"]:
message_buffer.update_agent_status( message_buffer.update_agent_status(
"News Analyst", "in_progress" "News Analyst", "in_progress",
) )
if "news_report" in chunk and chunk["news_report"]: if chunk.get("news_report"):
message_buffer.update_report_section( message_buffer.update_report_section(
"news_report", chunk["news_report"] "news_report", chunk["news_report"],
) )
message_buffer.update_agent_status("News Analyst", "completed") message_buffer.update_agent_status("News Analyst", "completed")
# Set next analyst to in_progress # Set next analyst to in_progress
if "fundamentals" in selections["analysts"]: if "fundamentals" in selections["analysts"]:
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Fundamentals Analyst", "in_progress" "Fundamentals Analyst", "in_progress",
) )
if "fundamentals_report" in chunk and chunk["fundamentals_report"]: if chunk.get("fundamentals_report"):
message_buffer.update_report_section( message_buffer.update_report_section(
"fundamentals_report", chunk["fundamentals_report"] "fundamentals_report", chunk["fundamentals_report"],
) )
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Fundamentals Analyst", "completed" "Fundamentals Analyst", "completed",
) )
# Set all research team members to in_progress # Set all research team members to in_progress
update_research_team_status("in_progress") update_research_team_status("in_progress")
# Research Team - Handle Investment Debate State # Research Team - Handle Investment Debate State
if ( if (
"investment_debate_state" in chunk chunk.get("investment_debate_state")
and chunk["investment_debate_state"]
): ):
debate_state = chunk["investment_debate_state"] debate_state = chunk["investment_debate_state"]
# Update Bull Researcher status and report # Update Bull Researcher status and report
if "bull_history" in debate_state and debate_state["bull_history"]: if debate_state.get("bull_history"):
# Keep all research team members in progress # Keep all research team members in progress
update_research_team_status("in_progress") update_research_team_status("in_progress")
# Extract latest bull response # Extract latest bull response
@ -957,7 +945,7 @@ def run_analysis():
) )
# Update Bear Researcher status and report # Update Bear Researcher status and report
if "bear_history" in debate_state and debate_state["bear_history"]: if debate_state.get("bear_history"):
# Keep all research team members in progress # Keep all research team members in progress
update_research_team_status("in_progress") update_research_team_status("in_progress")
# Extract latest bear response # Extract latest bear response
@ -973,8 +961,7 @@ def run_analysis():
# Update Research Manager status and final decision # Update Research Manager status and final decision
if ( if (
"judge_decision" in debate_state debate_state.get("judge_decision")
and debate_state["judge_decision"]
): ):
# Keep all research team members in progress until final decision # Keep all research team members in progress until final decision
update_research_team_status("in_progress") update_research_team_status("in_progress")
@ -991,31 +978,29 @@ def run_analysis():
update_research_team_status("completed") update_research_team_status("completed")
# Set first risk analyst to in_progress # Set first risk analyst to in_progress
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Risky Analyst", "in_progress" "Risky Analyst", "in_progress",
) )
# Trading Team # Trading Team
if ( if (
"trader_investment_plan" in chunk chunk.get("trader_investment_plan")
and chunk["trader_investment_plan"]
): ):
message_buffer.update_report_section( message_buffer.update_report_section(
"trader_investment_plan", chunk["trader_investment_plan"] "trader_investment_plan", chunk["trader_investment_plan"],
) )
# Set first risk analyst to in_progress # Set first risk analyst to in_progress
message_buffer.update_agent_status("Risky Analyst", "in_progress") message_buffer.update_agent_status("Risky Analyst", "in_progress")
# Risk Management Team - Handle Risk Debate State # Risk Management Team - Handle Risk Debate State
if "risk_debate_state" in chunk and chunk["risk_debate_state"]: if chunk.get("risk_debate_state"):
risk_state = chunk["risk_debate_state"] risk_state = chunk["risk_debate_state"]
# Update Risky Analyst status and report # Update Risky Analyst status and report
if ( if (
"current_risky_response" in risk_state risk_state.get("current_risky_response")
and risk_state["current_risky_response"]
): ):
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Risky Analyst", "in_progress" "Risky Analyst", "in_progress",
) )
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
@ -1029,11 +1014,10 @@ def run_analysis():
# Update Safe Analyst status and report # Update Safe Analyst status and report
if ( if (
"current_safe_response" in risk_state risk_state.get("current_safe_response")
and risk_state["current_safe_response"]
): ):
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Safe Analyst", "in_progress" "Safe Analyst", "in_progress",
) )
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
@ -1047,11 +1031,10 @@ def run_analysis():
# Update Neutral Analyst status and report # Update Neutral Analyst status and report
if ( if (
"current_neutral_response" in risk_state risk_state.get("current_neutral_response")
and risk_state["current_neutral_response"]
): ):
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Neutral Analyst", "in_progress" "Neutral Analyst", "in_progress",
) )
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
@ -1064,9 +1047,9 @@ def run_analysis():
) )
# Update Portfolio Manager status and final decision # Update Portfolio Manager status and final decision
if "judge_decision" in risk_state and risk_state["judge_decision"]: if risk_state.get("judge_decision"):
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Portfolio Manager", "in_progress" "Portfolio Manager", "in_progress",
) )
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
@ -1081,10 +1064,10 @@ def run_analysis():
message_buffer.update_agent_status("Risky Analyst", "completed") message_buffer.update_agent_status("Risky Analyst", "completed")
message_buffer.update_agent_status("Safe Analyst", "completed") message_buffer.update_agent_status("Safe Analyst", "completed")
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Neutral Analyst", "completed" "Neutral Analyst", "completed",
) )
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Portfolio Manager", "completed" "Portfolio Manager", "completed",
) )
# Update the display # Update the display
@ -1094,18 +1077,18 @@ def run_analysis():
# Get final state and decision # Get final state and decision
final_state = trace[-1] final_state = trace[-1]
decision = graph.process_signal(final_state["final_trade_decision"]) graph.process_signal(final_state["final_trade_decision"])
# Update all agent statuses to completed # Update all agent statuses to completed
for agent in message_buffer.agent_status: for agent in message_buffer.agent_status:
message_buffer.update_agent_status(agent, "completed") message_buffer.update_agent_status(agent, "completed")
message_buffer.add_message( message_buffer.add_message(
"Analysis", f"Completed analysis for {selections['analysis_date']}" "Analysis", f"Completed analysis for {selections['analysis_date']}",
) )
# Update final report sections # Update final report sections
for section in message_buffer.report_sections.keys(): for section in message_buffer.report_sections:
if section in final_state: if section in final_state:
message_buffer.update_report_section(section, final_state[section]) message_buffer.update_report_section(section, final_state[section])

View File

@ -1,5 +1,7 @@
import sys
import questionary import questionary
from typing import List
from rich.console import Console from rich.console import Console
from cli.models import AnalystType from cli.models import AnalystType
@ -23,13 +25,13 @@ def get_ticker() -> str:
[ [
("text", "fg:green"), ("text", "fg:green"),
("highlighted", "noinherit"), ("highlighted", "noinherit"),
] ],
), ),
).ask() ).ask()
if not ticker: if not ticker:
console.print("\n[red]No ticker symbol provided. Exiting...[/red]") console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
exit(1) sys.exit(1)
return ticker.strip().upper() return ticker.strip().upper()
@ -56,18 +58,18 @@ def get_analysis_date() -> str:
[ [
("text", "fg:green"), ("text", "fg:green"),
("highlighted", "noinherit"), ("highlighted", "noinherit"),
] ],
), ),
).ask() ).ask()
if not date: if not date:
console.print("\n[red]No date provided. Exiting...[/red]") console.print("\n[red]No date provided. Exiting...[/red]")
exit(1) sys.exit(1)
return date.strip() return date.strip()
def select_analysts() -> List[AnalystType]: def select_analysts() -> list[AnalystType]:
"""Select analysts using an interactive checkbox.""" """Select analysts using an interactive checkbox."""
choices = questionary.checkbox( choices = questionary.checkbox(
"Select Your [Analysts Team]:", "Select Your [Analysts Team]:",
@ -82,13 +84,13 @@ def select_analysts() -> List[AnalystType]:
("selected", "fg:green noinherit"), ("selected", "fg:green noinherit"),
("highlighted", "noinherit"), ("highlighted", "noinherit"),
("pointer", "noinherit"), ("pointer", "noinherit"),
] ],
), ),
).ask() ).ask()
if not choices: if not choices:
console.print("\n[red]No analysts selected. Exiting...[/red]") console.print("\n[red]No analysts selected. Exiting...[/red]")
exit(1) sys.exit(1)
return choices return choices
@ -114,13 +116,13 @@ def select_research_depth() -> int:
("selected", "fg:yellow noinherit"), ("selected", "fg:yellow noinherit"),
("highlighted", "fg:yellow noinherit"), ("highlighted", "fg:yellow noinherit"),
("pointer", "fg:yellow noinherit"), ("pointer", "fg:yellow noinherit"),
] ],
), ),
).ask() ).ask()
if choice is None: if choice is None:
console.print("\n[red]No research depth selected. Exiting...[/red]") console.print("\n[red]No research depth selected. Exiting...[/red]")
exit(1) sys.exit(1)
return choice return choice
@ -200,15 +202,15 @@ def select_shallow_thinking_agent(provider) -> str:
("selected", "fg:magenta noinherit"), ("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"), ("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"), ("pointer", "fg:magenta noinherit"),
] ],
), ),
).ask() ).ask()
if choice is None: if choice is None:
console.print( console.print(
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]" "\n[red]No shallow thinking llm engine selected. Exiting...[/red]",
) )
exit(1) sys.exit(1)
return choice return choice
@ -292,13 +294,13 @@ def select_deep_thinking_agent(provider) -> str:
("selected", "fg:magenta noinherit"), ("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"), ("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"), ("pointer", "fg:magenta noinherit"),
] ],
), ),
).ask() ).ask()
if choice is None: if choice is None:
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]") console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
exit(1) sys.exit(1)
return choice return choice
@ -326,15 +328,14 @@ def select_llm_provider() -> tuple[str, str]:
("selected", "fg:magenta noinherit"), ("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"), ("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"), ("pointer", "fg:magenta noinherit"),
] ],
), ),
).ask() ).ask()
if choice is None: if choice is None:
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
exit(1) sys.exit(1)
display_name, url = choice display_name, url = choice
print(f"You selected: {display_name}\tURL: {url}")
return display_name, url return display_name, url

View File

@ -1,10 +1,7 @@
import os,sys # Multiple imports on one line (Ruff will fix)
from typing import List,Dict # Multiple imports on one line
def poorly_formatted_function(x,y,z): # Missing type hints def poorly_formatted_function(x,y,z): # Missing type hints
"""This function has formatting issues.""" """This function has formatting issues."""
result=x+y*z # Missing spaces around operators result=x+y*z # Missing spaces around operators
unused_variable = 42 # Unused variable (Ruff will detect)
if result>100: # Missing spaces if result>100: # Missing spaces
print( "Result is large" ) # Extra spaces in parentheses print( "Result is large" ) # Extra spaces in parentheses
return result return result

View File

@ -23,11 +23,11 @@ def run_command(cmd, description=""):
print(f"{description or 'Command completed successfully'}") print(f"{description or 'Command completed successfully'}")
return True return True
else: else:
print(f"❌ Command failed:") print("❌ Command failed:")
print(result.stderr) print(result.stderr)
return False return False
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
print(f"⏱️ Command timed out") print("⏱️ Command timed out")
return False return False
except Exception as e: except Exception as e:
print(f"❌ Error running command: {e}") print(f"❌ Error running command: {e}")
@ -73,7 +73,7 @@ def main():
# Summary # Summary
print("\n" + "=" * 50) print("\n" + "=" * 50)
print(f"📊 Test Setup Verification Results:") print("📊 Test Setup Verification Results:")
print(f"✅ Successful: {success_count}/{total_tests}") print(f"✅ Successful: {success_count}/{total_tests}")
print(f"❌ Failed: {total_tests - success_count}/{total_tests}") print(f"❌ Failed: {total_tests - success_count}/{total_tests}")

View File

@ -1,11 +1,10 @@
"""Pytest configuration and shared fixtures for TradingAgents tests.""" """Pytest configuration and shared fixtures for TradingAgents tests."""
import os import os
import pytest
import tempfile import tempfile
from unittest.mock import Mock, MagicMock from unittest.mock import Mock
from datetime import date, datetime
from typing import Dict, Any import pytest
from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.default_config import DEFAULT_CONFIG
@ -22,7 +21,7 @@ def sample_config():
"deep_think_llm": "gpt-4o-mini", "deep_think_llm": "gpt-4o-mini",
"quick_think_llm": "gpt-4o-mini", "quick_think_llm": "gpt-4o-mini",
"project_dir": "/tmp/test_tradingagents", "project_dir": "/tmp/test_tradingagents",
} },
) )
return config return config
@ -174,7 +173,7 @@ def mock_memory():
def pytest_configure(config): def pytest_configure(config):
"""Configure pytest with custom markers.""" """Configure pytest with custom markers."""
config.addinivalue_line( config.addinivalue_line(
"markers", "integration: mark test as integration test (slow)" "markers", "integration: mark test as integration test (slow)",
) )
config.addinivalue_line("markers", "unit: mark test as unit test (fast)") config.addinivalue_line("markers", "unit: mark test as unit test (fast)")
config.addinivalue_line("markers", "api: mark test as requiring API access") config.addinivalue_line("markers", "api: mark test as requiring API access")

View File

@ -2,14 +2,14 @@
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, List, Any from typing import Any
class SampleDataFactory: class SampleDataFactory:
"""Factory class for creating sample test data.""" """Factory class for creating sample test data."""
@staticmethod @staticmethod
def create_market_data(ticker: str = "AAPL", days: int = 30) -> Dict[str, Any]: def create_market_data(ticker: str = "AAPL", days: int = 30) -> dict[str, Any]:
"""Create sample market data for testing.""" """Create sample market data for testing."""
base_date = datetime(2024, 5, 1) base_date = datetime(2024, 5, 1)
data = {} data = {}
@ -36,8 +36,8 @@ class SampleDataFactory:
@staticmethod @staticmethod
def create_finnhub_news_data( def create_finnhub_news_data(
ticker: str = "AAPL", count: int = 10 ticker: str = "AAPL", count: int = 10,
) -> Dict[str, List[Dict[str, Any]]]: ) -> dict[str, list[dict[str, Any]]]:
"""Create sample FinnHub news data for testing.""" """Create sample FinnHub news data for testing."""
base_date = datetime(2024, 5, 10) base_date = datetime(2024, 5, 10)
data = {} data = {}
@ -93,7 +93,7 @@ class SampleDataFactory:
@staticmethod @staticmethod
def create_insider_transactions_data( def create_insider_transactions_data(
ticker: str = "AAPL", ticker: str = "AAPL",
) -> Dict[str, List[Dict[str, Any]]]: ) -> dict[str, list[dict[str, Any]]]:
"""Create sample insider transactions data for testing.""" """Create sample insider transactions data for testing."""
base_date = datetime(2024, 5, 5) base_date = datetime(2024, 5, 5)
data = {} data = {}
@ -129,15 +129,15 @@ class SampleDataFactory:
"transactionValue": transaction["shares"] * transaction["price"], "transactionValue": transaction["shares"] * transaction["price"],
"reportingName": transaction["person"], "reportingName": transaction["person"],
"typeOfOwner": "officer", "typeOfOwner": "officer",
} },
] ]
return data return data
@staticmethod @staticmethod
def create_financial_statements_data( def create_financial_statements_data(
ticker: str = "AAPL", period: str = "annual" ticker: str = "AAPL", period: str = "annual",
) -> Dict[str, List[Dict[str, Any]]]: ) -> dict[str, list[dict[str, Any]]]:
"""Create sample financial statements data for testing.""" """Create sample financial statements data for testing."""
if period == "annual": if period == "annual":
dates = ["2023-12-31", "2022-12-31", "2021-12-31"] dates = ["2023-12-31", "2022-12-31", "2021-12-31"]
@ -174,7 +174,7 @@ class SampleDataFactory:
@staticmethod @staticmethod
def create_social_sentiment_data( def create_social_sentiment_data(
ticker: str = "AAPL", ticker: str = "AAPL",
) -> Dict[str, List[Dict[str, Any]]]: ) -> dict[str, list[dict[str, Any]]]:
"""Create sample social media sentiment data for testing.""" """Create sample social media sentiment data for testing."""
base_date = datetime(2024, 5, 8) base_date = datetime(2024, 5, 8)
data = {} data = {}
@ -226,7 +226,7 @@ class SampleDataFactory:
"subreddit": "stocks" if j % 2 else "investing", "subreddit": "stocks" if j % 2 else "investing",
"upvotes": 10 + (j * 5), "upvotes": 10 + (j * 5),
"comments": 3 + j, "comments": 3 + j,
} },
) )
data[date_str] = daily_posts data[date_str] = daily_posts
@ -234,7 +234,7 @@ class SampleDataFactory:
return data return data
@staticmethod @staticmethod
def create_technical_indicators_data(ticker: str = "AAPL") -> Dict[str, Any]: def create_technical_indicators_data(ticker: str = "AAPL") -> dict[str, Any]:
"""Create sample technical indicators data for testing.""" """Create sample technical indicators data for testing."""
return { return {
"symbol": ticker, "symbol": ticker,
@ -262,23 +262,23 @@ class SampleDataFactory:
} }
@staticmethod @staticmethod
def create_complete_test_dataset(ticker: str = "AAPL") -> Dict[str, Dict[str, Any]]: def create_complete_test_dataset(ticker: str = "AAPL") -> dict[str, dict[str, Any]]:
"""Create a complete dataset for comprehensive testing.""" """Create a complete dataset for comprehensive testing."""
return { return {
"market_data": SampleDataFactory.create_market_data(ticker), "market_data": SampleDataFactory.create_market_data(ticker),
"news_data": SampleDataFactory.create_finnhub_news_data(ticker), "news_data": SampleDataFactory.create_finnhub_news_data(ticker),
"insider_transactions": SampleDataFactory.create_insider_transactions_data( "insider_transactions": SampleDataFactory.create_insider_transactions_data(
ticker ticker,
), ),
"financial_annual": SampleDataFactory.create_financial_statements_data( "financial_annual": SampleDataFactory.create_financial_statements_data(
ticker, "annual" ticker, "annual",
), ),
"financial_quarterly": SampleDataFactory.create_financial_statements_data( "financial_quarterly": SampleDataFactory.create_financial_statements_data(
ticker, "quarterly" ticker, "quarterly",
), ),
"social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker), "social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker),
"technical_indicators": SampleDataFactory.create_technical_indicators_data( "technical_indicators": SampleDataFactory.create_technical_indicators_data(
ticker ticker,
), ),
} }
@ -343,7 +343,7 @@ def save_sample_data_to_files(base_path: str, ticker: str = "AAPL") -> None:
# Save quarterly data separately # Save quarterly data separately
quarterly_path = os.path.join( quarterly_path = os.path.join(
finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json" finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json",
) )
with open(quarterly_path, "w") as f: with open(quarterly_path, "w") as f:
json.dump(dataset["financial_quarterly"], f, indent=2) json.dump(dataset["financial_quarterly"], f, indent=2)

View File

@ -1,13 +1,11 @@
"""Integration tests for the full TradingAgents workflow.""" """Integration tests for the full TradingAgents workflow."""
import pytest from unittest.mock import Mock, patch
import os
import tempfile import pytest
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
@pytest.mark.integration @pytest.mark.integration
@ -26,14 +24,14 @@ class TestFullWorkflowIntegration:
"deep_think_llm": "gpt-4o-mini", "deep_think_llm": "gpt-4o-mini",
"quick_think_llm": "gpt-4o-mini", "quick_think_llm": "gpt-4o-mini",
"project_dir": temp_data_dir, "project_dir": temp_data_dir,
} },
) )
return config return config
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_end_to_end_trading_workflow( def test_end_to_end_trading_workflow(
self, mock_toolkit, mock_chat_openai, integration_config self, mock_toolkit, mock_chat_openai, integration_config,
): ):
"""Test complete end-to-end trading workflow.""" """Test complete end-to-end trading workflow."""
# Setup mocks # Setup mocks
@ -69,15 +67,14 @@ class TestFullWorkflowIntegration:
"company_of_interest": "AAPL", "company_of_interest": "AAPL",
"trade_date": "2024-05-10", "trade_date": "2024-05-10",
"messages": [], "messages": [],
} },
) )
trading_graph.propagator.get_graph_args = Mock(return_value={}) trading_graph.propagator.get_graph_args = Mock(return_value={})
trading_graph.signal_processor.process_signal = Mock(return_value="BUY") trading_graph.signal_processor.process_signal = Mock(return_value="BUY")
# Execute the full workflow # Execute the full workflow
with patch("builtins.open", create=True): with patch("builtins.open", create=True), patch("json.dump"):
with patch("json.dump"): final_state, decision = trading_graph.propagate("AAPL", "2024-05-10")
final_state, decision = trading_graph.propagate("AAPL", "2024-05-10")
# Verify the workflow completed successfully # Verify the workflow completed successfully
assert final_state is not None assert final_state is not None
@ -89,7 +86,7 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_analysts_integration( def test_multiple_analysts_integration(
self, mock_toolkit, mock_chat_openai, integration_config self, mock_toolkit, mock_chat_openai, integration_config,
): ):
"""Test integration with different analyst combinations.""" """Test integration with different analyst combinations."""
analyst_combinations = [ analyst_combinations = [
@ -117,7 +114,7 @@ class TestFullWorkflowIntegration:
with patch("tradingagents.graph.trading_graph.set_config"): with patch("tradingagents.graph.trading_graph.set_config"):
# Test each analyst combination # Test each analyst combination
trading_graph = TradingAgentsGraph( trading_graph = TradingAgentsGraph(
selected_analysts=analysts, config=integration_config selected_analysts=analysts, config=integration_config,
) )
trading_graph.graph = mock_graph trading_graph.graph = mock_graph
@ -127,19 +124,18 @@ class TestFullWorkflowIntegration:
"company_of_interest": "TSLA", "company_of_interest": "TSLA",
"trade_date": "2024-05-15", "trade_date": "2024-05-15",
"messages": [], "messages": [],
} },
) )
trading_graph.propagator.get_graph_args = Mock(return_value={}) trading_graph.propagator.get_graph_args = Mock(return_value={})
trading_graph.signal_processor.process_signal = Mock( trading_graph.signal_processor.process_signal = Mock(
return_value="HOLD" return_value="HOLD",
) )
# Execute # Execute
with patch("builtins.open", create=True): with patch("builtins.open", create=True), patch("json.dump"):
with patch("json.dump"): final_state, decision = trading_graph.propagate(
final_state, decision = trading_graph.propagate( "TSLA", "2024-05-15",
"TSLA", "2024-05-15" )
)
# Verify # Verify
assert final_state is not None assert final_state is not None
@ -148,7 +144,7 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_memory_and_reflection_integration( def test_memory_and_reflection_integration(
self, mock_toolkit, mock_chat_openai, integration_config self, mock_toolkit, mock_chat_openai, integration_config,
): ):
"""Test integration of memory and reflection components.""" """Test integration of memory and reflection components."""
# Setup # Setup
@ -165,7 +161,7 @@ class TestFullWorkflowIntegration:
mock_graph.invoke.return_value = mock_final_state mock_graph.invoke.return_value = mock_final_state
with patch( with patch(
"tradingagents.graph.trading_graph.FinancialSituationMemory" "tradingagents.graph.trading_graph.FinancialSituationMemory",
) as mock_memory: ) as mock_memory:
mock_memory_instance = Mock() mock_memory_instance = Mock()
mock_memory.return_value = mock_memory_instance mock_memory.return_value = mock_memory_instance
@ -180,11 +176,11 @@ class TestFullWorkflowIntegration:
"company_of_interest": "NVDA", "company_of_interest": "NVDA",
"trade_date": "2024-05-20", "trade_date": "2024-05-20",
"messages": [], "messages": [],
} },
) )
trading_graph.propagator.get_graph_args = Mock(return_value={}) trading_graph.propagator.get_graph_args = Mock(return_value={})
trading_graph.signal_processor.process_signal = Mock( trading_graph.signal_processor.process_signal = Mock(
return_value="SELL" return_value="SELL",
) )
# Mock reflection methods # Mock reflection methods
@ -195,9 +191,8 @@ class TestFullWorkflowIntegration:
trading_graph.reflector.reflect_risk_manager = Mock() trading_graph.reflector.reflect_risk_manager = Mock()
# Execute workflow # Execute workflow
with patch("builtins.open", create=True): with patch("builtins.open", create=True), patch("json.dump"):
with patch("json.dump"): final_state, decision = trading_graph.propagate("NVDA", "2024-05-20")
final_state, decision = trading_graph.propagate("NVDA", "2024-05-20")
# Test reflection and memory update # Test reflection and memory update
returns_losses = {"return": -0.03, "loss": -0.08} returns_losses = {"return": -0.03, "loss": -0.08}
@ -213,7 +208,7 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_debug_mode_integration( def test_debug_mode_integration(
self, mock_toolkit, mock_chat_openai, integration_config self, mock_toolkit, mock_chat_openai, integration_config,
): ):
"""Test integration in debug mode.""" """Test integration in debug mode."""
# Setup # Setup
@ -233,7 +228,7 @@ class TestFullWorkflowIntegration:
self._create_mock_final_state(), # Final chunk self._create_mock_final_state(), # Final chunk
] ]
for chunk in mock_chunks: for chunk in mock_chunks:
if "messages" in chunk and chunk["messages"]: if chunk.get("messages"):
for msg in chunk["messages"]: for msg in chunk["messages"]:
if hasattr(msg, "pretty_print"): if hasattr(msg, "pretty_print"):
msg.pretty_print = Mock() msg.pretty_print = Mock()
@ -245,7 +240,7 @@ class TestFullWorkflowIntegration:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"): with patch("tradingagents.graph.trading_graph.set_config"):
trading_graph = TradingAgentsGraph( trading_graph = TradingAgentsGraph(
debug=True, config=integration_config debug=True, config=integration_config,
) )
trading_graph.graph = mock_graph trading_graph.graph = mock_graph
@ -255,15 +250,14 @@ class TestFullWorkflowIntegration:
"company_of_interest": "AMZN", "company_of_interest": "AMZN",
"trade_date": "2024-05-25", "trade_date": "2024-05-25",
"messages": [], "messages": [],
} },
) )
trading_graph.propagator.get_graph_args = Mock(return_value={}) trading_graph.propagator.get_graph_args = Mock(return_value={})
trading_graph.signal_processor.process_signal = Mock(return_value="BUY") trading_graph.signal_processor.process_signal = Mock(return_value="BUY")
# Execute in debug mode # Execute in debug mode
with patch("builtins.open", create=True): with patch("builtins.open", create=True), patch("json.dump"):
with patch("json.dump"): final_state, decision = trading_graph.propagate("AMZN", "2024-05-25")
final_state, decision = trading_graph.propagate("AMZN", "2024-05-25")
# Verify debug mode was used # Verify debug mode was used
mock_graph.stream.assert_called_once() mock_graph.stream.assert_called_once()
@ -271,7 +265,7 @@ class TestFullWorkflowIntegration:
assert decision == "BUY" assert decision == "BUY"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ticker,date", ("ticker", "date"),
[ [
("AAPL", "2024-01-15"), ("AAPL", "2024-01-15"),
("TSLA", "2024-02-20"), ("TSLA", "2024-02-20"),
@ -282,7 +276,7 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_stocks_integration( def test_multiple_stocks_integration(
self, mock_toolkit, mock_chat_openai, ticker, date, integration_config self, mock_toolkit, mock_chat_openai, ticker, date, integration_config,
): ):
"""Test integration with different stocks and dates.""" """Test integration with different stocks and dates."""
# Setup # Setup
@ -309,17 +303,16 @@ class TestFullWorkflowIntegration:
"company_of_interest": ticker, "company_of_interest": ticker,
"trade_date": date, "trade_date": date,
"messages": [], "messages": [],
} },
) )
trading_graph.propagator.get_graph_args = Mock(return_value={}) trading_graph.propagator.get_graph_args = Mock(return_value={})
trading_graph.signal_processor.process_signal = Mock( trading_graph.signal_processor.process_signal = Mock(
return_value="HOLD" return_value="HOLD",
) )
# Execute # Execute
with patch("builtins.open", create=True): with patch("builtins.open", create=True), patch("json.dump"):
with patch("json.dump"): final_state, decision = trading_graph.propagate(ticker, date)
final_state, decision = trading_graph.propagate(ticker, date)
# Verify # Verify
assert final_state["company_of_interest"] == ticker assert final_state["company_of_interest"] == ticker
@ -389,7 +382,7 @@ class TestPerformanceIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_consecutive_runs( def test_multiple_consecutive_runs(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
): ):
"""Test multiple consecutive trading decisions.""" """Test multiple consecutive trading decisions."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -455,9 +448,8 @@ class TestPerformanceIntegration:
mock_final_state["final_trade_decision"] mock_final_state["final_trade_decision"]
) )
with patch("builtins.open", create=True): with patch("builtins.open", create=True), patch("json.dump"):
with patch("json.dump"): final_state, decision = trading_graph.propagate(ticker, date)
final_state, decision = trading_graph.propagate(ticker, date)
decisions.append(decision) decisions.append(decision)

View File

@ -1,8 +1,9 @@
"""Unit tests for market analyst agent.""" """Unit tests for market analyst agent."""
from unittest.mock import Mock
import pytest import pytest
from unittest.mock import Mock, patch, MagicMock from langchain_core.messages import HumanMessage
from langchain_core.messages import HumanMessage, AIMessage
from tradingagents.agents.analysts.market_analyst import create_market_analyst from tradingagents.agents.analysts.market_analyst import create_market_analyst
@ -16,7 +17,7 @@ class TestMarketAnalyst:
assert callable(analyst_node) assert callable(analyst_node)
def test_market_analyst_node_basic_execution( def test_market_analyst_node_basic_execution(
self, mock_llm, mock_toolkit, sample_agent_state self, mock_llm, mock_toolkit, sample_agent_state,
): ):
"""Test basic execution of market analyst node.""" """Test basic execution of market analyst node."""
# Setup # Setup
@ -38,7 +39,7 @@ class TestMarketAnalyst:
assert result["market_report"] == "Market analysis complete" assert result["market_report"] == "Market analysis complete"
def test_market_analyst_uses_online_tools_when_configured( def test_market_analyst_uses_online_tools_when_configured(
self, mock_llm, mock_toolkit, sample_agent_state self, mock_llm, mock_toolkit, sample_agent_state,
): ):
"""Test that analyst uses online tools when configured.""" """Test that analyst uses online tools when configured."""
# Setup # Setup
@ -54,7 +55,7 @@ class TestMarketAnalyst:
analyst_node = create_market_analyst(mock_llm, mock_toolkit) analyst_node = create_market_analyst(mock_llm, mock_toolkit)
# Execute # Execute
result = analyst_node(sample_agent_state) analyst_node(sample_agent_state)
# Verify tools were bound correctly # Verify tools were bound correctly
mock_llm.bind_tools.assert_called_once() mock_llm.bind_tools.assert_called_once()
@ -63,7 +64,7 @@ class TestMarketAnalyst:
assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2 assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2
def test_market_analyst_uses_offline_tools_when_configured( def test_market_analyst_uses_offline_tools_when_configured(
self, mock_llm, mock_toolkit, sample_agent_state self, mock_llm, mock_toolkit, sample_agent_state,
): ):
"""Test that analyst uses offline tools when configured.""" """Test that analyst uses offline tools when configured."""
# Setup # Setup
@ -79,7 +80,7 @@ class TestMarketAnalyst:
analyst_node = create_market_analyst(mock_llm, mock_toolkit) analyst_node = create_market_analyst(mock_llm, mock_toolkit)
# Execute # Execute
result = analyst_node(sample_agent_state) analyst_node(sample_agent_state)
# Verify tools were bound correctly # Verify tools were bound correctly
mock_llm.bind_tools.assert_called_once() mock_llm.bind_tools.assert_called_once()
@ -87,7 +88,7 @@ class TestMarketAnalyst:
assert len(bound_tools) == 2 # Should have 2 offline tools assert len(bound_tools) == 2 # Should have 2 offline tools
def test_market_analyst_processes_state_variables( def test_market_analyst_processes_state_variables(
self, mock_llm, mock_toolkit, sample_agent_state self, mock_llm, mock_toolkit, sample_agent_state,
): ):
"""Test that market analyst correctly processes state variables.""" """Test that market analyst correctly processes state variables."""
# Setup # Setup
@ -111,7 +112,7 @@ class TestMarketAnalyst:
assert result["market_report"] == "Analysis for AAPL on 2024-05-10" assert result["market_report"] == "Analysis for AAPL on 2024-05-10"
def test_market_analyst_handles_empty_tool_calls( def test_market_analyst_handles_empty_tool_calls(
self, mock_llm, mock_toolkit, sample_agent_state self, mock_llm, mock_toolkit, sample_agent_state,
): ):
"""Test handling when no tool calls are made.""" """Test handling when no tool calls are made."""
# Setup # Setup
@ -131,7 +132,7 @@ class TestMarketAnalyst:
assert result["messages"] == [mock_result] assert result["messages"] == [mock_result]
def test_market_analyst_with_tool_calls( def test_market_analyst_with_tool_calls(
self, mock_llm, mock_toolkit, sample_agent_state self, mock_llm, mock_toolkit, sample_agent_state,
): ):
"""Test handling when tool calls are present.""" """Test handling when tool calls are present."""
# Setup # Setup
@ -152,7 +153,7 @@ class TestMarketAnalyst:
@pytest.mark.parametrize("online_tools", [True, False]) @pytest.mark.parametrize("online_tools", [True, False])
def test_market_analyst_tool_configuration( def test_market_analyst_tool_configuration(
self, mock_llm, mock_toolkit, sample_agent_state, online_tools self, mock_llm, mock_toolkit, sample_agent_state, online_tools,
): ):
"""Test tool configuration for both online and offline modes.""" """Test tool configuration for both online and offline modes."""
# Setup # Setup
@ -194,15 +195,15 @@ class TestMarketAnalystIntegration:
mock_result = Mock() mock_result = Mock()
mock_result.content = """ mock_result.content = """
# Market Analysis for TSLA (2024-05-15) # Market Analysis for TSLA (2024-05-15)
## Technical Analysis ## Technical Analysis
- RSI: 65 (slightly overbought) - RSI: 65 (slightly overbought)
- MACD: Bullish crossover - MACD: Bullish crossover
- 50-day SMA: Trending upward - 50-day SMA: Trending upward
## Volume Analysis ## Volume Analysis
- Above average volume suggests strong interest - Above average volume suggests strong interest
| Indicator | Value | Signal | | Indicator | Value | Signal |
|-----------|-------|--------| |-----------|-------|--------|
| RSI | 65 | Neutral | | RSI | 65 | Neutral |

View File

@ -1,10 +1,9 @@
"""Unit tests for FinnHub utilities.""" """Unit tests for FinnHub utilities."""
import pytest
import json import json
import os import os
import tempfile
from unittest.mock import patch, mock_open, Mock import pytest
from tradingagents.dataflows.finnhub_utils import get_data_in_range from tradingagents.dataflows.finnhub_utils import get_data_in_range
@ -191,7 +190,7 @@ class TestFinnhubUtils:
# Test without period # Test without period
expected_path_no_period = os.path.join( expected_path_no_period = os.path.join(
temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json" temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
) )
# Test with period # Test with period
@ -238,7 +237,7 @@ class TestFinnhubUtils:
assert len(result2) == 1 assert len(result2) == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"data_type,period", ("data_type", "period"),
[ [
("news_data", None), ("news_data", None),
("insider_trans", None), ("insider_trans", None),
@ -249,7 +248,7 @@ class TestFinnhubUtils:
], ],
) )
def test_get_data_in_range_various_data_types( def test_get_data_in_range_various_data_types(
self, temp_data_dir, data_type, period self, temp_data_dir, data_type, period,
): ):
"""Test get_data_in_range with various data types.""" """Test get_data_in_range with various data types."""
ticker = "TEST" ticker = "TEST"

View File

@ -1,12 +1,10 @@
"""Unit tests for TradingAgentsGraph.""" """Unit tests for TradingAgentsGraph."""
from unittest.mock import Mock, mock_open, patch
import pytest import pytest
import os
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
class TestTradingAgentsGraph: class TestTradingAgentsGraph:
@ -47,7 +45,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_debug( def test_init_with_debug(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
): ):
"""Test initialization with debug mode enabled.""" """Test initialization with debug mode enabled."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -65,7 +63,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatAnthropic") @patch("tradingagents.graph.trading_graph.ChatAnthropic")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_anthropic( def test_init_with_anthropic(
self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir,
): ):
"""Test initialization with Anthropic LLM provider.""" """Test initialization with Anthropic LLM provider."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -77,14 +75,14 @@ class TestTradingAgentsGraph:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"): with patch("tradingagents.graph.trading_graph.set_config"):
graph = TradingAgentsGraph(config=sample_config) TradingAgentsGraph(config=sample_config)
assert mock_chat_anthropic.call_count == 2 assert mock_chat_anthropic.call_count == 2
@patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI") @patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_google( def test_init_with_google(
self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir,
): ):
"""Test initialization with Google LLM provider.""" """Test initialization with Google LLM provider."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -96,13 +94,13 @@ class TestTradingAgentsGraph:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"): with patch("tradingagents.graph.trading_graph.set_config"):
graph = TradingAgentsGraph(config=sample_config) TradingAgentsGraph(config=sample_config)
assert mock_chat_google.call_count == 2 assert mock_chat_google.call_count == 2
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_unsupported_llm_provider( def test_init_unsupported_llm_provider(
self, mock_toolkit, sample_config, temp_data_dir self, mock_toolkit, sample_config, temp_data_dir,
): ):
"""Test initialization with unsupported LLM provider raises error.""" """Test initialization with unsupported LLM provider raises error."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -117,7 +115,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_create_tool_nodes( def test_create_tool_nodes(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
): ):
"""Test creation of tool nodes.""" """Test creation of tool nodes."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -145,7 +143,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_propagate_basic( def test_propagate_basic(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
): ):
"""Test basic propagate functionality.""" """Test basic propagate functionality."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -190,15 +188,14 @@ class TestTradingAgentsGraph:
# Mock the propagator and signal processor # Mock the propagator and signal processor
graph.propagator.create_initial_state = Mock( graph.propagator.create_initial_state = Mock(
return_value={"test": "state"} return_value={"test": "state"},
) )
graph.propagator.get_graph_args = Mock(return_value={}) graph.propagator.get_graph_args = Mock(return_value={})
graph.signal_processor.process_signal = Mock(return_value="HOLD") graph.signal_processor.process_signal = Mock(return_value="HOLD")
# Execute # Execute
with patch("builtins.open", create=True): with patch("builtins.open", create=True), patch("json.dump"):
with patch("json.dump"): final_state, decision = graph.propagate("AAPL", "2024-05-10")
final_state, decision = graph.propagate("AAPL", "2024-05-10")
# Verify # Verify
assert final_state == mock_final_state assert final_state == mock_final_state
@ -209,7 +206,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_propagate_debug_mode( def test_propagate_debug_mode(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
): ):
"""Test propagate in debug mode.""" """Test propagate in debug mode."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -231,15 +228,14 @@ class TestTradingAgentsGraph:
# Mock other components # Mock other components
graph.propagator.create_initial_state = Mock( graph.propagator.create_initial_state = Mock(
return_value={"test": "state"} return_value={"test": "state"},
) )
graph.propagator.get_graph_args = Mock(return_value={}) graph.propagator.get_graph_args = Mock(return_value={})
graph.signal_processor.process_signal = Mock(return_value="BUY") graph.signal_processor.process_signal = Mock(return_value="BUY")
# Execute # Execute
with patch("builtins.open", create=True): with patch("builtins.open", create=True), patch("json.dump"):
with patch("json.dump"): final_state, decision = graph.propagate("TSLA", "2024-05-15")
final_state, decision = graph.propagate("TSLA", "2024-05-15")
# Verify debug mode was used # Verify debug mode was used
mock_graph.stream.assert_called_once() mock_graph.stream.assert_called_once()
@ -249,7 +245,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_log_state( def test_log_state(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
): ):
"""Test state logging functionality.""" """Test state logging functionality."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -291,10 +287,9 @@ class TestTradingAgentsGraph:
} }
# Mock file operations # Mock file operations
with patch("pathlib.Path.mkdir"): with patch("pathlib.Path.mkdir"), patch("builtins.open", mock_open()):
with patch("builtins.open", mock_open()) as mock_file: with patch("json.dump"):
with patch("json.dump") as mock_json_dump: graph._log_state("2024-05-20", final_state)
graph._log_state("2024-05-20", final_state)
# Verify logging occurred # Verify logging occurred
assert "2024-05-20" in graph.log_states_dict assert "2024-05-20" in graph.log_states_dict
@ -305,7 +300,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_reflect_and_remember( def test_reflect_and_remember(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
): ):
"""Test reflection and memory update functionality.""" """Test reflection and memory update functionality."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -315,20 +310,19 @@ class TestTradingAgentsGraph:
mock_toolkit.return_value = mock_toolkit_instance mock_toolkit.return_value = mock_toolkit_instance
with patch( with patch(
"tradingagents.graph.trading_graph.FinancialSituationMemory" "tradingagents.graph.trading_graph.FinancialSituationMemory",
) as mock_memory: ), patch("tradingagents.graph.trading_graph.set_config"):
with patch("tradingagents.graph.trading_graph.set_config"): graph = TradingAgentsGraph(config=sample_config)
graph = TradingAgentsGraph(config=sample_config)
# Set up current state # Set up current state
graph.curr_state = {"test": "state"} graph.curr_state = {"test": "state"}
# Mock reflector methods # Mock reflector methods
graph.reflector.reflect_bull_researcher = Mock() graph.reflector.reflect_bull_researcher = Mock()
graph.reflector.reflect_bear_researcher = Mock() graph.reflector.reflect_bear_researcher = Mock()
graph.reflector.reflect_trader = Mock() graph.reflector.reflect_trader = Mock()
graph.reflector.reflect_invest_judge = Mock() graph.reflector.reflect_invest_judge = Mock()
graph.reflector.reflect_risk_manager = Mock() graph.reflector.reflect_risk_manager = Mock()
returns_losses = {"return": 0.05, "loss": -0.02} returns_losses = {"return": 0.05, "loss": -0.02}
@ -345,7 +339,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_process_signal( def test_process_signal(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
): ):
"""Test signal processing functionality.""" """Test signal processing functionality."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -393,8 +387,8 @@ class TestTradingAgentsGraph:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"): with patch("tradingagents.graph.trading_graph.set_config"):
graph = TradingAgentsGraph( TradingAgentsGraph(
selected_analysts=selected_analysts, config=sample_config selected_analysts=selected_analysts, config=sample_config,
) )
# Verify graph was set up with selected analysts # Verify graph was set up with selected analysts
@ -415,14 +409,14 @@ class TestTradingAgentsGraphErrorHandling:
# This should still work as the class should use defaults for missing keys # This should still work as the class should use defaults for missing keys
with patch("tradingagents.graph.trading_graph.set_config"): with patch("tradingagents.graph.trading_graph.set_config"):
with pytest.raises( with pytest.raises(
KeyError KeyError,
): # Should fail when trying to access missing config keys ): # Should fail when trying to access missing config keys
TradingAgentsGraph(config=invalid_config) TradingAgentsGraph(config=invalid_config)
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_directory_creation_failure( def test_directory_creation_failure(
self, mock_toolkit, mock_chat_openai, sample_config self, mock_toolkit, mock_chat_openai, sample_config,
): ):
"""Test handling when directory creation fails.""" """Test handling when directory creation fails."""
sample_config["project_dir"] = "/invalid/path/that/cannot/be/created" sample_config["project_dir"] = "/invalid/path/that/cannot/be/created"

View File

@ -1,40 +1,35 @@
from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.memory import FinancialSituationMemory
from .analysts.fundamentals_analyst import create_fundamentals_analyst from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst from .analysts.social_media_analyst import create_social_media_analyst
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .researchers.bear_researcher import create_bear_researcher from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher from .researchers.bull_researcher import create_bull_researcher
from .risk_mgmt.aggresive_debator import create_risky_debator from .risk_mgmt.aggresive_debator import create_risky_debator
from .risk_mgmt.conservative_debator import create_safe_debator from .risk_mgmt.conservative_debator import create_safe_debator
from .risk_mgmt.neutral_debator import create_neutral_debator from .risk_mgmt.neutral_debator import create_neutral_debator
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .trader.trader import create_trader from .trader.trader import create_trader
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.memory import FinancialSituationMemory
__all__ = [ __all__ = [
"FinancialSituationMemory",
"Toolkit",
"AgentState", "AgentState",
"create_msg_delete", "FinancialSituationMemory",
"InvestDebateState", "InvestDebateState",
"RiskDebateState", "RiskDebateState",
"Toolkit",
"create_bear_researcher", "create_bear_researcher",
"create_bull_researcher", "create_bull_researcher",
"create_research_manager",
"create_fundamentals_analyst", "create_fundamentals_analyst",
"create_market_analyst", "create_market_analyst",
"create_msg_delete",
"create_neutral_debator", "create_neutral_debator",
"create_news_analyst", "create_news_analyst",
"create_risky_debator", "create_research_manager",
"create_risk_manager", "create_risk_manager",
"create_risky_debator",
"create_safe_debator", "create_safe_debator",
"create_social_media_analyst", "create_social_media_analyst",
"create_trader", "create_trader",

View File

@ -5,7 +5,7 @@ def create_fundamentals_analyst(llm, toolkit):
def fundamentals_analyst_node(state): def fundamentals_analyst_node(state):
current_date = state["trade_date"] current_date = state["trade_date"]
ticker = state["company_of_interest"] ticker = state["company_of_interest"]
company_name = state["company_of_interest"] state["company_of_interest"]
if toolkit.config["online_tools"]: if toolkit.config["online_tools"]:
tools = [toolkit.get_fundamentals_openai] tools = [toolkit.get_fundamentals_openai]
@ -20,7 +20,7 @@ def create_fundamentals_analyst(llm, toolkit):
system_message = ( system_message = (
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." "You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.", " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
@ -37,7 +37,7 @@ def create_fundamentals_analyst(llm, toolkit):
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}", "For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
), ),
MessagesPlaceholder(variable_name="messages"), MessagesPlaceholder(variable_name="messages"),
] ],
) )
prompt = prompt.partial(system_message=system_message) prompt = prompt.partial(system_message=system_message)

View File

@ -6,7 +6,7 @@ def create_market_analyst(llm, toolkit):
def market_analyst_node(state): def market_analyst_node(state):
current_date = state["trade_date"] current_date = state["trade_date"]
ticker = state["company_of_interest"] ticker = state["company_of_interest"]
company_name = state["company_of_interest"] state["company_of_interest"]
if toolkit.config["online_tools"]: if toolkit.config["online_tools"]:
tools = [ tools = [
@ -45,7 +45,7 @@ Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses. - vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions.""" - Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
@ -62,7 +62,7 @@ Volume-Based Indicators:
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}", "For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
), ),
MessagesPlaceholder(variable_name="messages"), MessagesPlaceholder(variable_name="messages"),
] ],
) )
prompt = prompt.partial(system_message=system_message) prompt = prompt.partial(system_message=system_message)

View File

@ -17,7 +17,7 @@ def create_news_analyst(llm, toolkit):
system_message = ( system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." "You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""" """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
@ -34,7 +34,7 @@ def create_news_analyst(llm, toolkit):
"For your reference, the current date is {current_date}. We are looking at the company {ticker}", "For your reference, the current date is {current_date}. We are looking at the company {ticker}",
), ),
MessagesPlaceholder(variable_name="messages"), MessagesPlaceholder(variable_name="messages"),
] ],
) )
prompt = prompt.partial(system_message=system_message) prompt = prompt.partial(system_message=system_message)

View File

@ -5,7 +5,7 @@ def create_social_media_analyst(llm, toolkit):
def social_media_analyst_node(state): def social_media_analyst_node(state):
current_date = state["trade_date"] current_date = state["trade_date"]
ticker = state["company_of_interest"] ticker = state["company_of_interest"]
company_name = state["company_of_interest"] state["company_of_interest"]
if toolkit.config["online_tools"]: if toolkit.config["online_tools"]:
tools = [toolkit.get_stock_news_openai] tools = [toolkit.get_stock_news_openai]
@ -16,7 +16,7 @@ def create_social_media_analyst(llm, toolkit):
system_message = ( system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""", """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
@ -33,7 +33,7 @@ def create_social_media_analyst(llm, toolkit):
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}", "For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
), ),
MessagesPlaceholder(variable_name="messages"), MessagesPlaceholder(variable_name="messages"),
] ],
) )
prompt = prompt.partial(system_message=system_message) prompt = prompt.partial(system_message=system_message)

View File

@ -12,7 +12,7 @@ def create_research_manager(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = "" past_memory_str = ""
for i, rec in enumerate(past_memories, 1): for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n" past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented. prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
@ -24,7 +24,7 @@ Additionally, develop a detailed investment plan for the trader. This should inc
Your Recommendation: A decisive stance supported by the most convincing arguments. Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion. Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation. Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting. Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes: Here are your past reflections on mistakes:
\"{past_memory_str}\" \"{past_memory_str}\"

View File

@ -1,7 +1,7 @@
def create_risk_manager(llm, memory): def create_risk_manager(llm, memory):
def risk_manager_node(state) -> dict: def risk_manager_node(state) -> dict:
company_name = state["company_of_interest"] state["company_of_interest"]
history = state["risk_debate_state"]["history"] history = state["risk_debate_state"]["history"]
risk_debate_state = state["risk_debate_state"] risk_debate_state = state["risk_debate_state"]
@ -15,7 +15,7 @@ def create_risk_manager(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = "" past_memory_str = ""
for i, rec in enumerate(past_memories, 1): for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n" past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness. prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
@ -32,7 +32,7 @@ Deliverables:
--- ---
**Analysts Debate History:** **Analysts Debate History:**
{history} {history}
--- ---

View File

@ -14,7 +14,7 @@ def create_bear_researcher(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = "" past_memory_str = ""
for i, rec in enumerate(past_memories, 1): for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n" past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively. prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.

View File

@ -14,7 +14,7 @@ def create_bull_researcher(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = "" past_memory_str = ""
for i, rec in enumerate(past_memories, 1): for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n" past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively. prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.

View File

@ -41,7 +41,7 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
"current_risky_response": argument, "current_risky_response": argument,
"current_safe_response": risk_debate_state.get("current_safe_response", ""), "current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": risk_debate_state.get( "current_neutral_response": risk_debate_state.get(
"current_neutral_response", "" "current_neutral_response", "",
), ),
"count": risk_debate_state["count"] + 1, "count": risk_debate_state["count"] + 1,
} }

View File

@ -39,11 +39,11 @@ Engage by questioning their optimism and emphasizing the potential downsides the
"neutral_history": risk_debate_state.get("neutral_history", ""), "neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Safe", "latest_speaker": "Safe",
"current_risky_response": risk_debate_state.get( "current_risky_response": risk_debate_state.get(
"current_risky_response", "" "current_risky_response", "",
), ),
"current_safe_response": argument, "current_safe_response": argument,
"current_neutral_response": risk_debate_state.get( "current_neutral_response": risk_debate_state.get(
"current_neutral_response", "" "current_neutral_response", "",
), ),
"count": risk_debate_state["count"] + 1, "count": risk_debate_state["count"] + 1,
} }

View File

@ -39,7 +39,7 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
"neutral_history": neutral_history + "\n" + argument, "neutral_history": neutral_history + "\n" + argument,
"latest_speaker": "Neutral", "latest_speaker": "Neutral",
"current_risky_response": risk_debate_state.get( "current_risky_response": risk_debate_state.get(
"current_risky_response", "" "current_risky_response", "",
), ),
"current_safe_response": risk_debate_state.get("current_safe_response", ""), "current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": argument, "current_neutral_response": argument,

View File

@ -15,7 +15,7 @@ def create_trader(llm, memory):
past_memory_str = "" past_memory_str = ""
if past_memories: if past_memories:
for i, rec in enumerate(past_memories, 1): for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n" past_memory_str += rec["recommendation"] + "\n\n"
else: else:
past_memory_str = "No past memories found." past_memory_str = "No past memories found."

View File

@ -1,16 +1,18 @@
from typing import Annotated from typing import Annotated
from typing_extensions import TypedDict
from tradingagents.agents import *
from langgraph.graph import MessagesState from langgraph.graph import MessagesState
from typing_extensions import TypedDict
# Import specific agent classes as needed
# Researcher team state # Researcher team state
class InvestDebateState(TypedDict): class InvestDebateState(TypedDict):
bull_history: Annotated[ bull_history: Annotated[
str, "Bullish Conversation history" str, "Bullish Conversation history",
] # Bullish Conversation history ] # Bullish Conversation history
bear_history: Annotated[ bear_history: Annotated[
str, "Bearish Conversation history" str, "Bearish Conversation history",
] # Bullish Conversation history ] # Bullish Conversation history
history: Annotated[str, "Conversation history"] # Conversation history history: Annotated[str, "Conversation history"] # Conversation history
current_response: Annotated[str, "Latest response"] # Last response current_response: Annotated[str, "Latest response"] # Last response
@ -21,24 +23,24 @@ class InvestDebateState(TypedDict):
# Risk management team state # Risk management team state
class RiskDebateState(TypedDict): class RiskDebateState(TypedDict):
risky_history: Annotated[ risky_history: Annotated[
str, "Risky Agent's Conversation history" str, "Risky Agent's Conversation history",
] # Conversation history ] # Conversation history
safe_history: Annotated[ safe_history: Annotated[
str, "Safe Agent's Conversation history" str, "Safe Agent's Conversation history",
] # Conversation history ] # Conversation history
neutral_history: Annotated[ neutral_history: Annotated[
str, "Neutral Agent's Conversation history" str, "Neutral Agent's Conversation history",
] # Conversation history ] # Conversation history
history: Annotated[str, "Conversation history"] # Conversation history history: Annotated[str, "Conversation history"] # Conversation history
latest_speaker: Annotated[str, "Analyst that spoke last"] latest_speaker: Annotated[str, "Analyst that spoke last"]
current_risky_response: Annotated[ current_risky_response: Annotated[
str, "Latest response by the risky analyst" str, "Latest response by the risky analyst",
] # Last response ] # Last response
current_safe_response: Annotated[ current_safe_response: Annotated[
str, "Latest response by the safe analyst" str, "Latest response by the safe analyst",
] # Last response ] # Last response
current_neutral_response: Annotated[ current_neutral_response: Annotated[
str, "Latest response by the neutral analyst" str, "Latest response by the neutral analyst",
] # Last response ] # Last response
judge_decision: Annotated[str, "Judge's decision"] judge_decision: Annotated[str, "Judge's decision"]
count: Annotated[int, "Length of the current conversation"] # Conversation length count: Annotated[int, "Length of the current conversation"] # Conversation length
@ -54,13 +56,13 @@ class AgentState(MessagesState):
market_report: Annotated[str, "Report from the Market Analyst"] market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"] sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
news_report: Annotated[ news_report: Annotated[
str, "Report from the News Researcher of current world affairs" str, "Report from the News Researcher of current world affairs",
] ]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"] fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
# researcher team discussion step # researcher team discussion step
investment_debate_state: Annotated[ investment_debate_state: Annotated[
InvestDebateState, "Current state of the debate on if to invest or not" InvestDebateState, "Current state of the debate on if to invest or not",
] ]
investment_plan: Annotated[str, "Plan generated by the Analyst"] investment_plan: Annotated[str, "Plan generated by the Analyst"]
@ -68,6 +70,6 @@ class AgentState(MessagesState):
# risk management team discussion step # risk management team discussion step
risk_debate_state: Annotated[ risk_debate_state: Annotated[
RiskDebateState, "Current state of the debate on evaluating risk" RiskDebateState, "Current state of the debate on evaluating risk",
] ]
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]

View File

@ -1,9 +1,10 @@
from langchain_core.messages import HumanMessage
from typing import Annotated
from langchain_core.messages import RemoveMessage
from langchain_core.tools import tool
from datetime import datetime from datetime import datetime
import tradingagents.dataflows.interface as interface from typing import Annotated
from langchain_core.messages import HumanMessage, RemoveMessage
from langchain_core.tools import tool
from tradingagents.dataflows import interface
from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.default_config import DEFAULT_CONFIG
@ -18,7 +19,7 @@ def create_msg_delete():
# Add a minimal placeholder message # Add a minimal placeholder message
placeholder = HumanMessage(content="Continue") placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]} return {"messages": [*removal_operations, placeholder]}
return delete_messages return delete_messages
@ -53,9 +54,8 @@ class Toolkit:
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame. str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
""" """
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5) return interface.get_reddit_global_news(curr_date, 7, 5)
return global_news_result
@staticmethod @staticmethod
@tool @tool
@ -83,11 +83,10 @@ class Toolkit:
start_date = datetime.strptime(start_date, "%Y-%m-%d") start_date = datetime.strptime(start_date, "%Y-%m-%d")
look_back_days = (end_date - start_date).days look_back_days = (end_date - start_date).days
finnhub_news_result = interface.get_finnhub_news( return interface.get_finnhub_news(
ticker, end_date_str, look_back_days ticker, end_date_str, look_back_days,
) )
return finnhub_news_result
@staticmethod @staticmethod
@tool @tool
@ -107,9 +106,8 @@ class Toolkit:
str: A formatted dataframe containing the latest news about the company on the given date str: A formatted dataframe containing the latest news about the company on the given date
""" """
stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5) return interface.get_reddit_company_news(ticker, curr_date, 7, 5)
return stock_news_results
@staticmethod @staticmethod
@tool @tool
@ -128,9 +126,8 @@ class Toolkit:
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range. str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
""" """
result_data = interface.get_YFin_data(symbol, start_date, end_date) return interface.get_YFin_data(symbol, start_date, end_date)
return result_data
@staticmethod @staticmethod
@tool @tool
@ -149,19 +146,18 @@ class Toolkit:
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range. str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
""" """
result_data = interface.get_YFin_data_online(symbol, start_date, end_date) return interface.get_YFin_data_online(symbol, start_date, end_date)
return result_data
@staticmethod @staticmethod
@tool @tool
def get_stockstats_indicators_report( def get_stockstats_indicators_report(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[ indicator: Annotated[
str, "technical indicator to get the analysis and report of" str, "technical indicator to get the analysis and report of",
], ],
curr_date: Annotated[ curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd" str, "The current trading date you are trading on, YYYY-mm-dd",
], ],
look_back_days: Annotated[int, "how many days to look back"] = 30, look_back_days: Annotated[int, "how many days to look back"] = 30,
) -> str: ) -> str:
@ -176,21 +172,20 @@ class Toolkit:
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator. str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
""" """
result_stockstats = interface.get_stock_stats_indicators_window( return interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, False symbol, indicator, curr_date, look_back_days, False,
) )
return result_stockstats
@staticmethod @staticmethod
@tool @tool
def get_stockstats_indicators_report_online( def get_stockstats_indicators_report_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[ indicator: Annotated[
str, "technical indicator to get the analysis and report of" str, "technical indicator to get the analysis and report of",
], ],
curr_date: Annotated[ curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd" str, "The current trading date you are trading on, YYYY-mm-dd",
], ],
look_back_days: Annotated[int, "how many days to look back"] = 30, look_back_days: Annotated[int, "how many days to look back"] = 30,
) -> str: ) -> str:
@ -205,11 +200,10 @@ class Toolkit:
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator. str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
""" """
result_stockstats = interface.get_stock_stats_indicators_window( return interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, True symbol, indicator, curr_date, look_back_days, True,
) )
return result_stockstats
@staticmethod @staticmethod
@tool @tool
@ -229,11 +223,10 @@ class Toolkit:
str: a report of the sentiment in the past 30 days starting at curr_date str: a report of the sentiment in the past 30 days starting at curr_date
""" """
data_sentiment = interface.get_finnhub_company_insider_sentiment( return interface.get_finnhub_company_insider_sentiment(
ticker, curr_date, 30 ticker, curr_date, 30,
) )
return data_sentiment
@staticmethod @staticmethod
@tool @tool
@ -253,11 +246,10 @@ class Toolkit:
str: a report of the company's insider transactions/trading information in the past 30 days str: a report of the company's insider transactions/trading information in the past 30 days
""" """
data_trans = interface.get_finnhub_company_insider_transactions( return interface.get_finnhub_company_insider_transactions(
ticker, curr_date, 30 ticker, curr_date, 30,
) )
return data_trans
@staticmethod @staticmethod
@tool @tool
@ -279,9 +271,8 @@ class Toolkit:
str: a report of the company's most recent balance sheet str: a report of the company's most recent balance sheet
""" """
data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date) return interface.get_simfin_balance_sheet(ticker, freq, curr_date)
return data_balance_sheet
@staticmethod @staticmethod
@tool @tool
@ -303,9 +294,8 @@ class Toolkit:
str: a report of the company's most recent cash flow statement str: a report of the company's most recent cash flow statement
""" """
data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date) return interface.get_simfin_cashflow(ticker, freq, curr_date)
return data_cashflow
@staticmethod @staticmethod
@tool @tool
@ -327,11 +317,10 @@ class Toolkit:
str: a report of the company's most recent income statement str: a report of the company's most recent income statement
""" """
data_income_stmt = interface.get_simfin_income_statements( return interface.get_simfin_income_statements(
ticker, freq, curr_date ticker, freq, curr_date,
) )
return data_income_stmt
@staticmethod @staticmethod
@tool @tool
@ -349,9 +338,8 @@ class Toolkit:
str: A formatted string containing the latest news from Google News based on the query and date range. str: A formatted string containing the latest news from Google News based on the query and date range.
""" """
google_news_results = interface.get_google_news(query, curr_date, 7) return interface.get_google_news(query, curr_date, 7)
return google_news_results
@staticmethod @staticmethod
@tool @tool
@ -368,9 +356,8 @@ class Toolkit:
str: A formatted string containing the latest news about the company on the given date. str: A formatted string containing the latest news about the company on the given date.
""" """
openai_news_results = interface.get_stock_news_openai(ticker, curr_date) return interface.get_stock_news_openai(ticker, curr_date)
return openai_news_results
@staticmethod @staticmethod
@tool @tool
@ -385,9 +372,8 @@ class Toolkit:
str: A formatted string containing the latest macroeconomic news on the given date. str: A formatted string containing the latest macroeconomic news on the given date.
""" """
openai_news_results = interface.get_global_news_openai(curr_date) return interface.get_global_news_openai(curr_date)
return openai_news_results
@staticmethod @staticmethod
@tool @tool
@ -404,8 +390,7 @@ class Toolkit:
str: A formatted string containing the latest fundamental information about the company on the given date. str: A formatted string containing the latest fundamental information about the company on the given date.
""" """
openai_fundamentals_results = interface.get_fundamentals_openai( return interface.get_fundamentals_openai(
ticker, curr_date ticker, curr_date,
) )
return openai_fundamentals_results

View File

@ -59,7 +59,7 @@ class FinancialSituationMemory:
"matched_situation": results["documents"][0][i], "matched_situation": results["documents"][0][i],
"recommendation": results["metadatas"][0][i]["recommendation"], "recommendation": results["metadatas"][0][i]["recommendation"],
"similarity_score": 1 - results["distances"][0][i], "similarity_score": 1 - results["distances"][0][i],
} },
) )
return matched_results return matched_results
@ -94,18 +94,15 @@ if __name__ == "__main__":
# Example query # Example query
current_situation = """ current_situation = """
Market showing increased volatility in tech sector, with institutional investors Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations reducing positions and rising interest rates affecting growth stock valuations
""" """
try: try:
recommendations = matcher.get_memories(current_situation, n_matches=2) recommendations = matcher.get_memories(current_situation, n_matches=2)
for i, rec in enumerate(recommendations, 1): for _i, _rec in enumerate(recommendations, 1):
print(f"\nMatch {i}:") pass
print(f"Similarity Score: {rec['similarity_score']:.2f}")
print(f"Matched Situation: {rec['matched_situation']}")
print(f"Recommendation: {rec['recommendation']}")
except Exception as e: except Exception:
print(f"Error during recommendation: {str(e)}") pass

View File

@ -5,6 +5,7 @@ Loads configuration from environment variables and .env file.
import os import os
from pathlib import Path from pathlib import Path
from dotenv import load_dotenv from dotenv import load_dotenv
# Load .env file from project root # Load .env file from project root
@ -20,10 +21,10 @@ def get_config():
"project_dir": str(project_root / "tradingagents"), "project_dir": str(project_root / "tradingagents"),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": os.getenv( "data_dir": os.getenv(
"TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data" "TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data",
), ),
"data_cache_dir": str( "data_cache_dir": str(
project_root / "tradingagents" / "dataflows" / "data_cache" project_root / "tradingagents" / "dataflows" / "data_cache",
), ),
# LLM settings # LLM settings
"llm_provider": os.getenv("LLM_PROVIDER", "openai"), "llm_provider": os.getenv("LLM_PROVIDER", "openai"),
@ -47,16 +48,17 @@ def get_config():
# Validate required API keys based on provider # Validate required API keys based on provider
if config["llm_provider"] == "openai" and not config["openai_api_key"]: if config["llm_provider"] == "openai" and not config["openai_api_key"]:
raise ValueError("OPENAI_API_KEY is required when using OpenAI provider") msg = "OPENAI_API_KEY is required when using OpenAI provider"
elif config["llm_provider"] == "anthropic" and not config["anthropic_api_key"]: raise ValueError(msg)
raise ValueError("ANTHROPIC_API_KEY is required when using Anthropic provider") if config["llm_provider"] == "anthropic" and not config["anthropic_api_key"]:
elif config["llm_provider"] == "google" and not config["google_api_key"]: msg = "ANTHROPIC_API_KEY is required when using Anthropic provider"
raise ValueError("GOOGLE_API_KEY is required when using Google provider") raise ValueError(msg)
if config["llm_provider"] == "google" and not config["google_api_key"]:
msg = "GOOGLE_API_KEY is required when using Google provider"
raise ValueError(msg)
if not config["finnhub_api_key"]: if not config["finnhub_api_key"]:
print( pass
"Warning: FINNHUB_API_KEY not set. Some financial data features may be limited."
)
return config return config

View File

@ -1,17 +1,13 @@
from .finnhub_utils import get_data_in_range from .finnhub_utils import get_data_in_range
from .googlenews_utils import getNewsData from .googlenews_utils import getNewsData
from .yfin_utils import YFinanceUtils
from .reddit_utils import fetch_top_from_category
from .stockstats_utils import StockstatsUtils
from .interface import ( from .interface import (
# News and sentiment functions
get_finnhub_news,
get_finnhub_company_insider_sentiment, get_finnhub_company_insider_sentiment,
get_finnhub_company_insider_transactions, get_finnhub_company_insider_transactions,
# News and sentiment functions
get_finnhub_news,
get_google_news, get_google_news,
get_reddit_global_news,
get_reddit_company_news, get_reddit_company_news,
get_reddit_global_news,
# Financial statements functions # Financial statements functions
get_simfin_balance_sheet, get_simfin_balance_sheet,
get_simfin_cashflow, get_simfin_cashflow,
@ -19,19 +15,25 @@ from .interface import (
# Technical analysis functions # Technical analysis functions
get_stock_stats_indicators_window, get_stock_stats_indicators_window,
get_stockstats_indicator, get_stockstats_indicator,
get_YFin_data,
# Market data functions # Market data functions
get_YFin_data_window, get_YFin_data_window,
get_YFin_data,
) )
from .reddit_utils import fetch_top_from_category
from .stockstats_utils import StockstatsUtils
from .yfin_utils import YFinanceUtils
__all__ = [ __all__ = [
# News and sentiment functions "get_YFin_data",
"get_finnhub_news", # Market data functions
"get_YFin_data_window",
"get_finnhub_company_insider_sentiment", "get_finnhub_company_insider_sentiment",
"get_finnhub_company_insider_transactions", "get_finnhub_company_insider_transactions",
# News and sentiment functions
"get_finnhub_news",
"get_google_news", "get_google_news",
"get_reddit_global_news",
"get_reddit_company_news", "get_reddit_company_news",
"get_reddit_global_news",
# Financial statements functions # Financial statements functions
"get_simfin_balance_sheet", "get_simfin_balance_sheet",
"get_simfin_cashflow", "get_simfin_cashflow",
@ -39,7 +41,10 @@ __all__ = [
# Technical analysis functions # Technical analysis functions
"get_stock_stats_indicators_window", "get_stock_stats_indicators_window",
"get_stockstats_indicator", "get_stockstats_indicator",
# Market data functions # Utilities and classes
"get_YFin_data_window", "get_data_in_range",
"get_YFin_data", "getNewsData",
"YFinanceUtils",
"fetch_top_from_category",
"StockstatsUtils",
] ]

View File

@ -1,9 +1,9 @@
import tradingagents.default_config as default_config
from typing import Dict, Optional from tradingagents import default_config
# Use default config but allow it to be overridden # Use default config but allow it to be overridden
_config: Optional[Dict] = None _config: dict | None = None
DATA_DIR: Optional[str] = None DATA_DIR: str | None = None
def initialize_config(): def initialize_config():
@ -14,7 +14,7 @@ def initialize_config():
DATA_DIR = _config["data_dir"] DATA_DIR = _config["data_dir"]
def set_config(config: Dict): def set_config(config: dict):
"""Update the configuration with custom values.""" """Update the configuration with custom values."""
global _config, DATA_DIR global _config, DATA_DIR
if _config is None: if _config is None:
@ -23,7 +23,7 @@ def set_config(config: Dict):
DATA_DIR = _config["data_dir"] DATA_DIR = _config["data_dir"]
def get_config() -> Dict: def get_config() -> dict:
"""Get the current configuration.""" """Get the current configuration."""
if _config is None: if _config is None:
initialize_config() initialize_config()

View File

@ -22,10 +22,10 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
) )
else: else:
data_path = os.path.join( data_path = os.path.join(
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json" data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
) )
data = open(data_path, "r") data = open(data_path)
data = json.load(data) data = json.load(data)
# filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD) # filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD)

View File

@ -1,20 +1,20 @@
import requests
from bs4 import BeautifulSoup
from datetime import datetime
import time
import random
import logging import logging
import random
import time
from datetime import datetime
from urllib.parse import quote_plus from urllib.parse import quote_plus
logger = logging.getLogger(__name__) import requests
from bs4 import BeautifulSoup
from tenacity import ( from tenacity import (
retry, retry,
retry_if_result,
stop_after_attempt, stop_after_attempt,
wait_exponential, wait_exponential,
retry_if_result,
) )
logger = logging.getLogger(__name__)
def is_rate_limited(response): def is_rate_limited(response):
"""Check if the response indicates we should back off (rate-limited or temporarily unavailable).""" """Check if the response indicates we should back off (rate-limited or temporarily unavailable)."""
@ -35,8 +35,7 @@ def _add_jitter(retry_state):
def make_request(url, headers): def make_request(url, headers):
"""Make a request with retry logic for rate limiting""" """Make a request with retry logic for rate limiting"""
# The retry decorator already applies exponential backoff with jitter # The retry decorator already applies exponential backoff with jitter
response = requests.get(url, headers=headers, timeout=(5, 20)) return requests.get(url, headers=headers, timeout=(5, 20))
return response
def getNewsData(query, start_date, end_date): def getNewsData(query, start_date, end_date):
@ -58,7 +57,7 @@ def getNewsData(query, start_date, end_date):
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) " "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) " "AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/101.0.4951.54 Safari/537.36" "Chrome/101.0.4951.54 Safari/537.36"
) ),
} }
news_results = [] news_results = []
@ -103,7 +102,7 @@ def getNewsData(query, start_date, end_date):
"source": ( "source": (
source_el.get_text(strip=True) if source_el else "" source_el.get_text(strip=True) if source_el else ""
), ),
} },
) )
except Exception as e: except Exception as e:
logger.warning("Error processing result: %s", e) logger.warning("Error processing result: %s", e)
@ -120,7 +119,7 @@ def getNewsData(query, start_date, end_date):
page += 1 page += 1
except Exception as e: except Exception as e:
logger.error("Failed after multiple retries: %s", e) logger.exception("Failed after multiple retries: %s", e)
break break
return news_results return news_results

View File

@ -1,17 +1,18 @@
from typing import Annotated
from .reddit_utils import fetch_top_from_category
from .yfin_utils import *
from .stockstats_utils import *
from .googlenews_utils import *
from .finnhub_utils import get_data_in_range
from dateutil.relativedelta import relativedelta
from datetime import datetime
import os import os
from datetime import datetime
from typing import Annotated
import pandas as pd import pandas as pd
from tqdm import tqdm
import yfinance as yf import yfinance as yf
from dateutil.relativedelta import relativedelta
from openai import OpenAI from openai import OpenAI
from .config import get_config, DATA_DIR from tqdm import tqdm
from .config import DATA_DIR, get_config
from .finnhub_utils import get_data_in_range
from .googlenews_utils import getNewsData
from .reddit_utils import fetch_top_from_category
from .stockstats_utils import StockstatsUtils
def get_finnhub_news( def get_finnhub_news(
@ -84,7 +85,7 @@ def get_finnhub_company_insider_sentiment(
result_str = "" result_str = ""
seen_dicts = [] seen_dicts = []
for date, senti_list in data.items(): for senti_list in data.values():
for entry in senti_list: for entry in senti_list:
if entry not in seen_dicts: if entry not in seen_dicts:
result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n" result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n"
@ -126,7 +127,7 @@ def get_finnhub_company_insider_transactions(
result_str = "" result_str = ""
seen_dicts = [] seen_dicts = []
for date, senti_list in data.items(): for senti_list in data.values():
for entry in senti_list: for entry in senti_list:
if entry not in seen_dicts: if entry not in seen_dicts:
result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n" result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n"
@ -170,7 +171,6 @@ def get_simfin_balance_sheet(
# Check if there are any available reports; if not, return a notification # Check if there are any available reports; if not, return a notification
if filtered_df.empty: if filtered_df.empty:
print("No balance sheet available before the given current date.")
return "" return ""
# Get the most recent balance sheet by selecting the row with the latest Publish Date # Get the most recent balance sheet by selecting the row with the latest Publish Date
@ -217,7 +217,6 @@ def get_simfin_cashflow(
# Check if there are any available reports; if not, return a notification # Check if there are any available reports; if not, return a notification
if filtered_df.empty: if filtered_df.empty:
print("No cash flow statement available before the given current date.")
return "" return ""
# Get the most recent cash flow statement by selecting the row with the latest Publish Date # Get the most recent cash flow statement by selecting the row with the latest Publish Date
@ -264,7 +263,6 @@ def get_simfin_income_statements(
# Check if there are any available reports; if not, return a notification # Check if there are any available reports; if not, return a notification
if filtered_df.empty: if filtered_df.empty:
print("No income statement available before the given current date.")
return "" return ""
# Get the most recent income statement by selecting the row with the latest Publish Date # Get the most recent income statement by selecting the row with the latest Publish Date
@ -421,7 +419,7 @@ def get_stock_stats_indicators_window(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"], indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[ curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd" str, "The current trading date you are trading on, YYYY-mm-dd",
], ],
look_back_days: Annotated[int, "how many days to look back"], look_back_days: Annotated[int, "how many days to look back"],
online: Annotated[bool, "to fetch data online or offline"], online: Annotated[bool, "to fetch data online or offline"],
@ -501,8 +499,9 @@ def get_stock_stats_indicators_window(
} }
if indicator not in best_ind_params: if indicator not in best_ind_params:
msg = f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}"
raise ValueError( raise ValueError(
f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}" msg,
) )
end_date = curr_date end_date = curr_date
@ -515,7 +514,7 @@ def get_stock_stats_indicators_window(
os.path.join( os.path.join(
DATA_DIR, DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv", f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
) ),
) )
data["Date"] = pd.to_datetime(data["Date"], utc=True) data["Date"] = pd.to_datetime(data["Date"], utc=True)
dates_in_df = data["Date"].astype(str).str[:10] dates_in_df = data["Date"].astype(str).str[:10]
@ -525,7 +524,7 @@ def get_stock_stats_indicators_window(
# only do the trading dates # only do the trading dates
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values: if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
indicator_value = get_stockstats_indicator( indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
) )
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n" ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
@ -536,28 +535,27 @@ def get_stock_stats_indicators_window(
ind_string = "" ind_string = ""
while curr_date >= before: while curr_date >= before:
indicator_value = get_stockstats_indicator( indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
) )
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n" ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
curr_date = curr_date - relativedelta(days=1) curr_date = curr_date - relativedelta(days=1)
result_str = ( return (
f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n" f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n"
+ ind_string + ind_string
+ "\n\n" + "\n\n"
+ best_ind_params.get(indicator, "No description available.") + best_ind_params.get(indicator, "No description available.")
) )
return result_str
def get_stockstats_indicator( def get_stockstats_indicator(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"], indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[ curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd" str, "The current trading date you are trading on, YYYY-mm-dd",
], ],
online: Annotated[bool, "to fetch data online or offline"], online: Annotated[bool, "to fetch data online or offline"],
) -> str: ) -> str:
@ -573,10 +571,7 @@ def get_stockstats_indicator(
os.path.join(DATA_DIR, "market_data", "price_data"), os.path.join(DATA_DIR, "market_data", "price_data"),
online=online, online=online,
) )
except Exception as e: except Exception:
print(
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
)
return "" return ""
return str(indicator_value) return str(indicator_value)
@ -597,7 +592,7 @@ def get_YFin_data_window(
os.path.join( os.path.join(
DATA_DIR, DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv", f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
) ),
) )
# Extract just the date part for comparison # Extract just the date part for comparison
@ -613,7 +608,7 @@ def get_YFin_data_window(
# Set pandas display options to show the full DataFrame # Set pandas display options to show the full DataFrame
with pd.option_context( with pd.option_context(
"display.max_rows", None, "display.max_columns", None, "display.width", None "display.max_rows", None, "display.max_columns", None, "display.width", None,
): ):
df_string = filtered_data.to_string() df_string = filtered_data.to_string()
@ -675,12 +670,13 @@ def get_YFin_data(
os.path.join( os.path.join(
DATA_DIR, DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv", f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
) ),
) )
if end_date > "2025-03-25": if end_date > "2025-03-25":
msg = f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
raise Exception( raise Exception(
f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25" msg,
) )
# Extract just the date part for comparison # Extract just the date part for comparison
@ -695,9 +691,8 @@ def get_YFin_data(
filtered_data = filtered_data.drop("DateOnly", axis=1) filtered_data = filtered_data.drop("DateOnly", axis=1)
# remove the index from the dataframe # remove the index from the dataframe
filtered_data = filtered_data.reset_index(drop=True) return filtered_data.reset_index(drop=True)
return filtered_data
def get_stock_news_openai(ticker, curr_date): def get_stock_news_openai(ticker, curr_date):
@ -713,9 +708,9 @@ def get_stock_news_openai(ticker, curr_date):
{ {
"type": "input_text", "type": "input_text",
"text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.", "text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
} },
], ],
} },
], ],
text={"format": {"type": "text"}}, text={"format": {"type": "text"}},
reasoning={}, reasoning={},
@ -724,7 +719,7 @@ def get_stock_news_openai(ticker, curr_date):
"type": "web_search_preview", "type": "web_search_preview",
"user_location": {"type": "approximate"}, "user_location": {"type": "approximate"},
"search_context_size": "low", "search_context_size": "low",
} },
], ],
temperature=1, temperature=1,
max_output_tokens=4096, max_output_tokens=4096,
@ -748,9 +743,9 @@ def get_global_news_openai(curr_date):
{ {
"type": "input_text", "type": "input_text",
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.", "text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
} },
], ],
} },
], ],
text={"format": {"type": "text"}}, text={"format": {"type": "text"}},
reasoning={}, reasoning={},
@ -759,7 +754,7 @@ def get_global_news_openai(curr_date):
"type": "web_search_preview", "type": "web_search_preview",
"user_location": {"type": "approximate"}, "user_location": {"type": "approximate"},
"search_context_size": "low", "search_context_size": "low",
} },
], ],
temperature=1, temperature=1,
max_output_tokens=4096, max_output_tokens=4096,
@ -783,9 +778,9 @@ def get_fundamentals_openai(ticker, curr_date):
{ {
"type": "input_text", "type": "input_text",
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc", "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
} },
], ],
} },
], ],
text={"format": {"type": "text"}}, text={"format": {"type": "text"}},
reasoning={}, reasoning={},
@ -794,7 +789,7 @@ def get_fundamentals_openai(ticker, curr_date):
"type": "web_search_preview", "type": "web_search_preview",
"user_location": {"type": "approximate"}, "user_location": {"type": "approximate"},
"search_context_size": "low", "search_context_size": "low",
} },
], ],
temperature=1, temperature=1,
max_output_tokens=4096, max_output_tokens=4096,

View File

@ -1,8 +1,8 @@
import json import json
from datetime import datetime
from typing import Annotated
import os import os
import re import re
from datetime import datetime
from typing import Annotated
ticker_to_company = { ticker_to_company = {
"AAPL": "Apple", "AAPL": "Apple",
@ -48,11 +48,11 @@ ticker_to_company = {
def fetch_top_from_category( def fetch_top_from_category(
category: Annotated[ category: Annotated[
str, "Category to fetch top post from. Collection of subreddits." str, "Category to fetch top post from. Collection of subreddits.",
], ],
date: Annotated[str, "Date to fetch top posts from."], date: Annotated[str, "Date to fetch top posts from."],
max_limit: Annotated[int, "Maximum number of posts to fetch."], max_limit: Annotated[int, "Maximum number of posts to fetch."],
query: Annotated[str, "Optional query to search for in the subreddit."] = None, query: Annotated[str | None, "Optional query to search for in the subreddit."] = None,
data_path: Annotated[ data_path: Annotated[
str, str,
"Path to the data folder. Default is 'reddit_data'.", "Path to the data folder. Default is 'reddit_data'.",
@ -63,12 +63,13 @@ def fetch_top_from_category(
all_content = [] all_content = []
if max_limit < len(os.listdir(os.path.join(base_path, category))): if max_limit < len(os.listdir(os.path.join(base_path, category))):
msg = "REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
raise ValueError( raise ValueError(
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts" msg,
) )
limit_per_subreddit = max_limit // len( limit_per_subreddit = max_limit // len(
os.listdir(os.path.join(base_path, category)) os.listdir(os.path.join(base_path, category)),
) )
for data_file in os.listdir(os.path.join(base_path, category)): for data_file in os.listdir(os.path.join(base_path, category)):
@ -79,7 +80,7 @@ def fetch_top_from_category(
all_content_curr_subreddit = [] all_content_curr_subreddit = []
with open(os.path.join(base_path, category, data_file), "rb") as f: with open(os.path.join(base_path, category, data_file), "rb") as f:
for i, line in enumerate(f): for _i, line in enumerate(f):
# skip empty lines # skip empty lines
if not line.strip(): if not line.strip():
continue continue
@ -88,7 +89,7 @@ def fetch_top_from_category(
# select only lines that are from the date # select only lines that are from the date
post_date = datetime.utcfromtimestamp( post_date = datetime.utcfromtimestamp(
parsed_line["created_utc"] parsed_line["created_utc"],
).strftime("%Y-%m-%d") ).strftime("%Y-%m-%d")
if post_date != date: if post_date != date:
continue continue
@ -106,7 +107,7 @@ def fetch_top_from_category(
found = False found = False
for term in search_terms: for term in search_terms:
if re.search( if re.search(
term, parsed_line["title"], re.IGNORECASE term, parsed_line["title"], re.IGNORECASE,
) or re.search(term, parsed_line["selftext"], re.IGNORECASE): ) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
found = True found = True
break break

View File

@ -1,8 +1,10 @@
import os
from typing import Annotated
import pandas as pd import pandas as pd
import yfinance as yf import yfinance as yf
from stockstats import wrap from stockstats import wrap
from typing import Annotated
import os
from .config import get_config from .config import get_config
@ -11,10 +13,10 @@ class StockstatsUtils:
def get_stock_stats( def get_stock_stats(
symbol: Annotated[str, "ticker symbol for the company"], symbol: Annotated[str, "ticker symbol for the company"],
indicator: Annotated[ indicator: Annotated[
str, "quantitative indicators based off of the stock data for the company" str, "quantitative indicators based off of the stock data for the company",
], ],
curr_date: Annotated[ curr_date: Annotated[
str, "curr date for retrieving stock price data, YYYY-mm-dd" str, "curr date for retrieving stock price data, YYYY-mm-dd",
], ],
data_dir: Annotated[ data_dir: Annotated[
str, str,
@ -34,11 +36,12 @@ class StockstatsUtils:
os.path.join( os.path.join(
data_dir, data_dir,
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
) ),
) )
df = wrap(data) df = wrap(data)
except FileNotFoundError: except FileNotFoundError:
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!") msg = "Stockstats fail: Yahoo Finance data not fetched yet!"
raise Exception(msg)
else: else:
# Get today's date as YYYY-mm-dd to add to cache # Get today's date as YYYY-mm-dd to add to cache
today_date = pd.Timestamp.today() today_date = pd.Timestamp.today()
@ -81,7 +84,5 @@ class StockstatsUtils:
matching_rows = df[df["Date"].str.startswith(curr_date)] matching_rows = df[df["Date"].str.startswith(curr_date)]
if not matching_rows.empty: if not matching_rows.empty:
indicator_value = matching_rows[indicator].values[0] return matching_rows[indicator].values[0]
return indicator_value return "N/A: Not a trading day (weekend or holiday)"
else:
return "N/A: Not a trading day (weekend or holiday)"

View File

@ -1,14 +1,14 @@
import pandas as pd from datetime import date, datetime, timedelta
from datetime import date, timedelta, datetime
from typing import Annotated from typing import Annotated
import pandas as pd
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."] SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None: def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
if save_path: if save_path:
data.to_csv(save_path) data.to_csv(save_path)
print(f"{tag} saved to {save_path}")
def get_current_date(): def get_current_date():
@ -32,7 +32,5 @@ def get_next_weekday(date):
if date.weekday() >= 5: if date.weekday() >= 5:
days_to_add = 7 - date.weekday() days_to_add = 7 - date.weekday()
next_weekday = date + timedelta(days=days_to_add) return date + timedelta(days=days_to_add)
return next_weekday return date
else:
return date

View File

@ -1,10 +1,12 @@
# gets data/stats # gets data/stats
import yfinance as yf from collections.abc import Callable
from typing import Annotated, Callable, Any, Optional
from pandas import DataFrame
import pandas as pd
from functools import wraps from functools import wraps
from typing import Annotated, Any
import pandas as pd
import yfinance as yf
from pandas import DataFrame
from .utils import SavePathType, decorate_all_methods from .utils import SavePathType, decorate_all_methods
@ -24,38 +26,36 @@ def init_ticker(func: Callable) -> Callable:
class YFinanceUtils: class YFinanceUtils:
def get_stock_data( def get_stock_data(
symbol: Annotated[str, "ticker symbol"], self: Annotated[str, "ticker symbol"],
start_date: Annotated[ start_date: Annotated[
str, "start date for retrieving stock price data, YYYY-mm-dd" str, "start date for retrieving stock price data, YYYY-mm-dd",
], ],
end_date: Annotated[ end_date: Annotated[
str, "end date for retrieving stock price data, YYYY-mm-dd" str, "end date for retrieving stock price data, YYYY-mm-dd",
], ],
save_path: SavePathType = None, save_path: SavePathType = None,
) -> DataFrame: ) -> DataFrame:
"""retrieve stock price data for designated ticker symbol""" """retrieve stock price data for designated ticker symbol"""
ticker = symbol ticker = self
# add one day to the end_date so that the data range is inclusive # add one day to the end_date so that the data range is inclusive
end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1) end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
end_date = end_date.strftime("%Y-%m-%d") end_date = end_date.strftime("%Y-%m-%d")
stock_data = ticker.history(start=start_date, end=end_date) return ticker.history(start=start_date, end=end_date)
# save_output(stock_data, f"Stock data for {ticker.ticker}", save_path) # save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
return stock_data
def get_stock_info( def get_stock_info(
symbol: Annotated[str, "ticker symbol"], self: Annotated[str, "ticker symbol"],
) -> dict: ) -> dict:
"""Fetches and returns latest stock information.""" """Fetches and returns latest stock information."""
ticker = symbol ticker = self
stock_info = ticker.info return ticker.info
return stock_info
def get_company_info( def get_company_info(
symbol: Annotated[str, "ticker symbol"], self: Annotated[str, "ticker symbol"],
save_path: Optional[str] = None, save_path: str | None = None,
) -> DataFrame: ) -> DataFrame:
"""Fetches and returns company information as a DataFrame.""" """Fetches and returns company information as a DataFrame."""
ticker = symbol ticker = self
info = ticker.info info = ticker.info
company_info = { company_info = {
"Company Name": info.get("shortName", "N/A"), "Company Name": info.get("shortName", "N/A"),
@ -67,42 +67,37 @@ class YFinanceUtils:
company_info_df = DataFrame([company_info]) company_info_df = DataFrame([company_info])
if save_path: if save_path:
company_info_df.to_csv(save_path) company_info_df.to_csv(save_path)
print(f"Company info for {ticker.ticker} saved to {save_path}")
return company_info_df return company_info_df
def get_stock_dividends( def get_stock_dividends(
symbol: Annotated[str, "ticker symbol"], self: Annotated[str, "ticker symbol"],
save_path: Optional[str] = None, save_path: str | None = None,
) -> DataFrame: ) -> DataFrame:
"""Fetches and returns the latest dividends data as a DataFrame.""" """Fetches and returns the latest dividends data as a DataFrame."""
ticker = symbol ticker = self
dividends = ticker.dividends dividends = ticker.dividends
if save_path: if save_path:
dividends.to_csv(save_path) dividends.to_csv(save_path)
print(f"Dividends for {ticker.ticker} saved to {save_path}")
return dividends return dividends
def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: def get_income_stmt(self: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest income statement of the company as a DataFrame.""" """Fetches and returns the latest income statement of the company as a DataFrame."""
ticker = symbol ticker = self
income_stmt = ticker.financials return ticker.financials
return income_stmt
def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: def get_balance_sheet(self: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest balance sheet of the company as a DataFrame.""" """Fetches and returns the latest balance sheet of the company as a DataFrame."""
ticker = symbol ticker = self
balance_sheet = ticker.balance_sheet return ticker.balance_sheet
return balance_sheet
def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: def get_cash_flow(self: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest cash flow statement of the company as a DataFrame.""" """Fetches and returns the latest cash flow statement of the company as a DataFrame."""
ticker = symbol ticker = self
cash_flow = ticker.cashflow return ticker.cashflow
return cash_flow
def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple: def get_analyst_recommendations(self: Annotated[str, "ticker symbol"]) -> tuple:
"""Fetches the latest analyst recommendations and returns the most common recommendation and its count.""" """Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
ticker = symbol ticker = self
recommendations = ticker.recommendations recommendations = ticker.recommendations
if recommendations.empty: if recommendations.empty:
return None, 0 # No recommendations available return None, 0 # No recommendations available

View File

@ -1,17 +1,17 @@
# TradingAgents/graph/__init__.py # TradingAgents/graph/__init__.py
from .trading_graph import TradingAgentsGraph
from .conditional_logic import ConditionalLogic from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator from .propagation import Propagator
from .reflection import Reflector from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor from .signal_processing import SignalProcessor
from .trading_graph import TradingAgentsGraph
__all__ = [ __all__ = [
"TradingAgentsGraph",
"ConditionalLogic", "ConditionalLogic",
"GraphSetup", "GraphSetup",
"Propagator", "Propagator",
"Reflector", "Reflector",
"SignalProcessor", "SignalProcessor",
"TradingAgentsGraph",
] ]

View File

@ -1,6 +1,7 @@
# TradingAgents/graph/propagation.py # TradingAgents/graph/propagation.py
from typing import Dict, Any from typing import Any
from tradingagents.agents.utils.agent_states import ( from tradingagents.agents.utils.agent_states import (
InvestDebateState, InvestDebateState,
RiskDebateState, RiskDebateState,
@ -15,15 +16,15 @@ class Propagator:
self.max_recur_limit = max_recur_limit self.max_recur_limit = max_recur_limit
def create_initial_state( def create_initial_state(
self, company_name: str, trade_date: str self, company_name: str, trade_date: str,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Create the initial state for the agent graph.""" """Create the initial state for the agent graph."""
return { return {
"messages": [("human", company_name)], "messages": [("human", company_name)],
"company_of_interest": company_name, "company_of_interest": company_name,
"trade_date": str(trade_date), "trade_date": str(trade_date),
"investment_debate_state": InvestDebateState( "investment_debate_state": InvestDebateState(
{"history": "", "current_response": "", "count": 0} {"history": "", "current_response": "", "count": 0},
), ),
"risk_debate_state": RiskDebateState( "risk_debate_state": RiskDebateState(
{ {
@ -32,7 +33,7 @@ class Propagator:
"current_safe_response": "", "current_safe_response": "",
"current_neutral_response": "", "current_neutral_response": "",
"count": 0, "count": 0,
} },
), ),
"market_report": "", "market_report": "",
"fundamentals_report": "", "fundamentals_report": "",
@ -40,7 +41,7 @@ class Propagator:
"news_report": "", "news_report": "",
} }
def get_graph_args(self) -> Dict[str, Any]: def get_graph_args(self) -> dict[str, Any]:
"""Get arguments for the graph invocation.""" """Get arguments for the graph invocation."""
return { return {
"stream_mode": "values", "stream_mode": "values",

View File

@ -1,6 +1,7 @@
# TradingAgents/graph/reflection.py # TradingAgents/graph/reflection.py
from typing import Dict, Any from typing import Any
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
@ -15,7 +16,7 @@ class Reflector:
def _get_reflection_prompt(self) -> str: def _get_reflection_prompt(self) -> str:
"""Get the system prompt for reflection.""" """Get the system prompt for reflection."""
return """ return """
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis. You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines: Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
1. Reasoning: 1. Reasoning:
@ -25,7 +26,7 @@ Your goal is to deliver detailed insights into investment decisions and highligh
- Technical indicators. - Technical indicators.
- Technical signals. - Technical signals.
- Price movement analysis. - Price movement analysis.
- Overall market data analysis - Overall market data analysis
- News analysis. - News analysis.
- Social media and sentiment analysis. - Social media and sentiment analysis.
- Fundamental data analysis. - Fundamental data analysis.
@ -46,7 +47,7 @@ Your goal is to deliver detailed insights into investment decisions and highligh
Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis. Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis.
""" """
def _extract_current_situation(self, current_state: Dict[str, Any]) -> str: def _extract_current_situation(self, current_state: dict[str, Any]) -> str:
"""Extract the current market situation from the state.""" """Extract the current market situation from the state."""
curr_market_report = current_state["market_report"] curr_market_report = current_state["market_report"]
curr_sentiment_report = current_state["sentiment_report"] curr_sentiment_report = current_state["sentiment_report"]
@ -56,7 +57,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}" return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
def _reflect_on_component( def _reflect_on_component(
self, component_type: str, report: str, situation: str, returns_losses self, component_type: str, report: str, situation: str, returns_losses,
) -> str: ) -> str:
"""Generate reflection for a component.""" """Generate reflection for a component."""
messages = [ messages = [
@ -67,8 +68,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
), ),
] ]
result = self.quick_thinking_llm.invoke(messages).content return self.quick_thinking_llm.invoke(messages).content
return result
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory): def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
"""Reflect on bull researcher's analysis and update memory.""" """Reflect on bull researcher's analysis and update memory."""
@ -76,7 +76,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
bull_debate_history = current_state["investment_debate_state"]["bull_history"] bull_debate_history = current_state["investment_debate_state"]["bull_history"]
result = self._reflect_on_component( result = self._reflect_on_component(
"BULL", bull_debate_history, situation, returns_losses "BULL", bull_debate_history, situation, returns_losses,
) )
bull_memory.add_situations([(situation, result)]) bull_memory.add_situations([(situation, result)])
@ -86,7 +86,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
bear_debate_history = current_state["investment_debate_state"]["bear_history"] bear_debate_history = current_state["investment_debate_state"]["bear_history"]
result = self._reflect_on_component( result = self._reflect_on_component(
"BEAR", bear_debate_history, situation, returns_losses "BEAR", bear_debate_history, situation, returns_losses,
) )
bear_memory.add_situations([(situation, result)]) bear_memory.add_situations([(situation, result)])
@ -96,7 +96,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
trader_decision = current_state["trader_investment_plan"] trader_decision = current_state["trader_investment_plan"]
result = self._reflect_on_component( result = self._reflect_on_component(
"TRADER", trader_decision, situation, returns_losses "TRADER", trader_decision, situation, returns_losses,
) )
trader_memory.add_situations([(situation, result)]) trader_memory.add_situations([(situation, result)])
@ -106,7 +106,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
judge_decision = current_state["investment_debate_state"]["judge_decision"] judge_decision = current_state["investment_debate_state"]["judge_decision"]
result = self._reflect_on_component( result = self._reflect_on_component(
"INVEST JUDGE", judge_decision, situation, returns_losses "INVEST JUDGE", judge_decision, situation, returns_losses,
) )
invest_judge_memory.add_situations([(situation, result)]) invest_judge_memory.add_situations([(situation, result)])
@ -116,6 +116,6 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
judge_decision = current_state["risk_debate_state"]["judge_decision"] judge_decision = current_state["risk_debate_state"]["judge_decision"]
result = self._reflect_on_component( result = self._reflect_on_component(
"RISK JUDGE", judge_decision, situation, returns_losses "RISK JUDGE", judge_decision, situation, returns_losses,
) )
risk_manager_memory.add_situations([(situation, result)]) risk_manager_memory.add_situations([(situation, result)])

View File

@ -1,26 +1,26 @@
# TradingAgents/graph/setup.py # TradingAgents/graph/setup.py
from typing import Dict
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode from langgraph.prebuilt import ToolNode
from tradingagents.agents import ( from tradingagents.agents import (
create_market_analyst,
create_social_media_analyst,
create_news_analyst,
create_fundamentals_analyst,
create_bull_researcher,
create_bear_researcher,
create_research_manager,
create_trader,
create_risky_debator,
create_neutral_debator,
create_safe_debator,
create_risk_manager,
create_msg_delete,
AgentState, AgentState,
Toolkit, Toolkit,
create_bear_researcher,
create_bull_researcher,
create_fundamentals_analyst,
create_market_analyst,
create_msg_delete,
create_neutral_debator,
create_news_analyst,
create_research_manager,
create_risk_manager,
create_risky_debator,
create_safe_debator,
create_social_media_analyst,
create_trader,
) )
from .conditional_logic import ConditionalLogic from .conditional_logic import ConditionalLogic
@ -34,7 +34,7 @@ class GraphSetup:
quick_thinking_llm: ChatOpenAI, quick_thinking_llm: ChatOpenAI,
deep_thinking_llm: ChatOpenAI, deep_thinking_llm: ChatOpenAI,
toolkit: Toolkit, toolkit: Toolkit,
tool_nodes: Dict[str, ToolNode], tool_nodes: dict[str, ToolNode],
bull_memory, bull_memory,
bear_memory, bear_memory,
trader_memory, trader_memory,
@ -55,7 +55,7 @@ class GraphSetup:
self.conditional_logic = conditional_logic self.conditional_logic = conditional_logic
def setup_graph( def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"] self, selected_analysts=None,
): ):
"""Set up and compile the agent workflow graph. """Set up and compile the agent workflow graph.
@ -66,8 +66,11 @@ class GraphSetup:
- "news": News analyst - "news": News analyst
- "fundamentals": Fundamentals analyst - "fundamentals": Fundamentals analyst
""" """
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals"]
if len(selected_analysts) == 0: if len(selected_analysts) == 0:
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!") msg = "Trading Agents Graph Setup Error: no analysts selected!"
raise ValueError(msg)
# Create analyst nodes # Create analyst nodes
analyst_nodes = {} analyst_nodes = {}
@ -76,41 +79,41 @@ class GraphSetup:
if "market" in selected_analysts: if "market" in selected_analysts:
analyst_nodes["market"] = create_market_analyst( analyst_nodes["market"] = create_market_analyst(
self.quick_thinking_llm, self.toolkit self.quick_thinking_llm, self.toolkit,
) )
delete_nodes["market"] = create_msg_delete() delete_nodes["market"] = create_msg_delete()
tool_nodes["market"] = self.tool_nodes["market"] tool_nodes["market"] = self.tool_nodes["market"]
if "social" in selected_analysts: if "social" in selected_analysts:
analyst_nodes["social"] = create_social_media_analyst( analyst_nodes["social"] = create_social_media_analyst(
self.quick_thinking_llm, self.toolkit self.quick_thinking_llm, self.toolkit,
) )
delete_nodes["social"] = create_msg_delete() delete_nodes["social"] = create_msg_delete()
tool_nodes["social"] = self.tool_nodes["social"] tool_nodes["social"] = self.tool_nodes["social"]
if "news" in selected_analysts: if "news" in selected_analysts:
analyst_nodes["news"] = create_news_analyst( analyst_nodes["news"] = create_news_analyst(
self.quick_thinking_llm, self.toolkit self.quick_thinking_llm, self.toolkit,
) )
delete_nodes["news"] = create_msg_delete() delete_nodes["news"] = create_msg_delete()
tool_nodes["news"] = self.tool_nodes["news"] tool_nodes["news"] = self.tool_nodes["news"]
if "fundamentals" in selected_analysts: if "fundamentals" in selected_analysts:
analyst_nodes["fundamentals"] = create_fundamentals_analyst( analyst_nodes["fundamentals"] = create_fundamentals_analyst(
self.quick_thinking_llm, self.toolkit self.quick_thinking_llm, self.toolkit,
) )
delete_nodes["fundamentals"] = create_msg_delete() delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
# Create researcher and manager nodes # Create researcher and manager nodes
bull_researcher_node = create_bull_researcher( bull_researcher_node = create_bull_researcher(
self.quick_thinking_llm, self.bull_memory self.quick_thinking_llm, self.bull_memory,
) )
bear_researcher_node = create_bear_researcher( bear_researcher_node = create_bear_researcher(
self.quick_thinking_llm, self.bear_memory self.quick_thinking_llm, self.bear_memory,
) )
research_manager_node = create_research_manager( research_manager_node = create_research_manager(
self.deep_thinking_llm, self.invest_judge_memory self.deep_thinking_llm, self.invest_judge_memory,
) )
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
@ -119,7 +122,7 @@ class GraphSetup:
neutral_analyst = create_neutral_debator(self.quick_thinking_llm) neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
safe_analyst = create_safe_debator(self.quick_thinking_llm) safe_analyst = create_safe_debator(self.quick_thinking_llm)
risk_manager_node = create_risk_manager( risk_manager_node = create_risk_manager(
self.deep_thinking_llm, self.risk_manager_memory self.deep_thinking_llm, self.risk_manager_memory,
) )
# Create workflow # Create workflow
@ -129,7 +132,7 @@ class GraphSetup:
for analyst_type, node in analyst_nodes.items(): for analyst_type, node in analyst_nodes.items():
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
workflow.add_node( workflow.add_node(
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type] f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type],
) )
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])

View File

@ -1,25 +1,24 @@
# TradingAgents/graph/trading_graph.py # TradingAgents/graph/trading_graph.py
import json
import os import os
from pathlib import Path from pathlib import Path
import json from typing import Any
from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode from langgraph.prebuilt import ToolNode
from tradingagents.agents import Toolkit from tradingagents.agents import Toolkit
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.dataflows.interface import set_config from tradingagents.dataflows.interface import set_config
from tradingagents.default_config import DEFAULT_CONFIG
from .conditional_logic import ConditionalLogic from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator from .propagation import Propagator
from .reflection import Reflector from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor from .signal_processing import SignalProcessor
@ -28,9 +27,9 @@ class TradingAgentsGraph:
def __init__( def __init__(
self, self,
selected_analysts=["market", "social", "news", "fundamentals"], selected_analysts=None,
debug=False, debug=False,
config: Dict[str, Any] = None, config: dict[str, Any] | None = None,
): ):
"""Initialize the trading agents graph and components. """Initialize the trading agents graph and components.
@ -39,6 +38,8 @@ class TradingAgentsGraph:
debug: Whether to run in debug mode debug: Whether to run in debug mode
config: Configuration dictionary. If None, uses default config config: Configuration dictionary. If None, uses default config
""" """
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals"]
self.debug = debug self.debug = debug
self.config = config or DEFAULT_CONFIG self.config = config or DEFAULT_CONFIG
@ -58,7 +59,7 @@ class TradingAgentsGraph:
or self.config["llm_provider"] == "openrouter" or self.config["llm_provider"] == "openrouter"
): ):
self.deep_thinking_llm = ChatOpenAI( self.deep_thinking_llm = ChatOpenAI(
model=self.config["deep_think_llm"], base_url=self.config["backend_url"] model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
) )
self.quick_thinking_llm = ChatOpenAI( self.quick_thinking_llm = ChatOpenAI(
model=self.config["quick_think_llm"], model=self.config["quick_think_llm"],
@ -66,7 +67,7 @@ class TradingAgentsGraph:
) )
elif self.config["llm_provider"].lower() == "anthropic": elif self.config["llm_provider"].lower() == "anthropic":
self.deep_thinking_llm = ChatAnthropic( self.deep_thinking_llm = ChatAnthropic(
model=self.config["deep_think_llm"], base_url=self.config["backend_url"] model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
) )
self.quick_thinking_llm = ChatAnthropic( self.quick_thinking_llm = ChatAnthropic(
model=self.config["quick_think_llm"], model=self.config["quick_think_llm"],
@ -74,13 +75,14 @@ class TradingAgentsGraph:
) )
elif self.config["llm_provider"].lower() == "google": elif self.config["llm_provider"].lower() == "google":
self.deep_thinking_llm = ChatGoogleGenerativeAI( self.deep_thinking_llm = ChatGoogleGenerativeAI(
model=self.config["deep_think_llm"] model=self.config["deep_think_llm"],
) )
self.quick_thinking_llm = ChatGoogleGenerativeAI( self.quick_thinking_llm = ChatGoogleGenerativeAI(
model=self.config["quick_think_llm"] model=self.config["quick_think_llm"],
) )
else: else:
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}") msg = f"Unsupported LLM provider: {self.config['llm_provider']}"
raise ValueError(msg)
self.toolkit = Toolkit(config=self.config) self.toolkit = Toolkit(config=self.config)
@ -89,10 +91,10 @@ class TradingAgentsGraph:
self.bear_memory = FinancialSituationMemory("bear_memory", self.config) self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config) self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory( self.invest_judge_memory = FinancialSituationMemory(
"invest_judge_memory", self.config "invest_judge_memory", self.config,
) )
self.risk_manager_memory = FinancialSituationMemory( self.risk_manager_memory = FinancialSituationMemory(
"risk_manager_memory", self.config "risk_manager_memory", self.config,
) )
# Create tool nodes # Create tool nodes
@ -125,7 +127,7 @@ class TradingAgentsGraph:
# Set up the graph # Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts) self.graph = self.graph_setup.setup_graph(selected_analysts)
def _create_tool_nodes(self) -> Dict[str, ToolNode]: def _create_tool_nodes(self) -> dict[str, ToolNode]:
"""Create tool nodes for different data sources.""" """Create tool nodes for different data sources."""
return { return {
"market": ToolNode( "market": ToolNode(
@ -136,7 +138,7 @@ class TradingAgentsGraph:
# offline tools # offline tools
self.toolkit.get_YFin_data, self.toolkit.get_YFin_data,
self.toolkit.get_stockstats_indicators_report, self.toolkit.get_stockstats_indicators_report,
] ],
), ),
"social": ToolNode( "social": ToolNode(
[ [
@ -144,7 +146,7 @@ class TradingAgentsGraph:
self.toolkit.get_stock_news_openai, self.toolkit.get_stock_news_openai,
# offline tools # offline tools
self.toolkit.get_reddit_stock_info, self.toolkit.get_reddit_stock_info,
] ],
), ),
"news": ToolNode( "news": ToolNode(
[ [
@ -154,7 +156,7 @@ class TradingAgentsGraph:
# offline tools # offline tools
self.toolkit.get_finnhub_news, self.toolkit.get_finnhub_news,
self.toolkit.get_reddit_news, self.toolkit.get_reddit_news,
] ],
), ),
"fundamentals": ToolNode( "fundamentals": ToolNode(
[ [
@ -166,7 +168,7 @@ class TradingAgentsGraph:
self.toolkit.get_simfin_balance_sheet, self.toolkit.get_simfin_balance_sheet,
self.toolkit.get_simfin_cashflow, self.toolkit.get_simfin_cashflow,
self.toolkit.get_simfin_income_stmt, self.toolkit.get_simfin_income_stmt,
] ],
), ),
} }
@ -177,7 +179,7 @@ class TradingAgentsGraph:
# Initialize state # Initialize state
init_agent_state = self.propagator.create_initial_state( init_agent_state = self.propagator.create_initial_state(
company_name, trade_date company_name, trade_date,
) )
args = self.propagator.get_graph_args() args = self.propagator.get_graph_args()
@ -250,19 +252,19 @@ class TradingAgentsGraph:
def reflect_and_remember(self, returns_losses): def reflect_and_remember(self, returns_losses):
"""Reflect on decisions and update memory based on returns.""" """Reflect on decisions and update memory based on returns."""
self.reflector.reflect_bull_researcher( self.reflector.reflect_bull_researcher(
self.curr_state, returns_losses, self.bull_memory self.curr_state, returns_losses, self.bull_memory,
) )
self.reflector.reflect_bear_researcher( self.reflector.reflect_bear_researcher(
self.curr_state, returns_losses, self.bear_memory self.curr_state, returns_losses, self.bear_memory,
) )
self.reflector.reflect_trader( self.reflector.reflect_trader(
self.curr_state, returns_losses, self.trader_memory self.curr_state, returns_losses, self.trader_memory,
) )
self.reflector.reflect_invest_judge( self.reflector.reflect_invest_judge(
self.curr_state, returns_losses, self.invest_judge_memory self.curr_state, returns_losses, self.invest_judge_memory,
) )
self.reflector.reflect_risk_manager( self.reflector.reflect_risk_manager(
self.curr_state, returns_losses, self.risk_manager_memory self.curr_state, returns_losses, self.risk_manager_memory,
) )
def process_signal(self, full_signal): def process_signal(self, full_signal):