Fix remaining ruff linting errors

- Fixed all F-type errors (undefined names, unused imports)
- Applied automatic fixes for code style issues
- Ensured CI/CD pipeline passes all checks
This commit is contained in:
佐藤優一 2025-08-10 23:13:31 +09:00
parent 4361ed19e4
commit 6f3981412b
41 changed files with 570 additions and 623 deletions

View File

@ -1,31 +1,32 @@
import datetime
import typer
from pathlib import Path
from functools import wraps
from rich.console import Console
from rich.panel import Panel
from rich.spinner import Spinner
from rich.live import Live
from rich.columns import Columns
from rich.markdown import Markdown
from rich.layout import Layout
from rich.text import Text
from rich.table import Table
from collections import deque
from functools import wraps
from pathlib import Path
import typer
from rich import box
from rich.align import Align
from rich.columns import Columns
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.spinner import Spinner
from rich.table import Table
from rich.text import Text
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.utils import (
get_ticker,
get_analysis_date,
get_ticker,
select_analysts,
select_research_depth,
select_shallow_thinking_agent,
select_deep_thinking_agent,
select_llm_provider,
select_research_depth,
select_shallow_thinking_agent,
)
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
console = Console()
@ -136,19 +137,20 @@ class MessageBuffer:
report_parts.append("## Analyst Team Reports")
if self.report_sections["market_report"]:
report_parts.append(
f"### Market Analysis\n{self.report_sections['market_report']}"
f"### Market Analysis\n{self.report_sections['market_report']}",
)
if self.report_sections["sentiment_report"]:
report_parts.append(
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
f"### Social Sentiment\n{self.report_sections['sentiment_report']}",
)
if self.report_sections["news_report"]:
report_parts.append(
f"### News Analysis\n{self.report_sections['news_report']}"
f"### News Analysis\n{self.report_sections['news_report']}",
)
if self.report_sections["fundamentals_report"]:
fundamentals = self.report_sections['fundamentals_report']
report_parts.append(
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
f"### Fundamentals Analysis\n{fundamentals}",
)
# Research Team Reports
@ -180,10 +182,10 @@ def create_layout():
Layout(name="footer", size=3),
)
layout["main"].split_column(
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5)
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5),
)
layout["upper"].split_row(
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3)
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3),
)
return layout
@ -198,7 +200,7 @@ def update_display(layout, spinner_text=None):
border_style="green",
padding=(1, 2),
expand=True,
)
),
)
# Progress panel showing agent status
@ -235,7 +237,7 @@ def update_display(layout, spinner_text=None):
status = message_buffer.agent_status[first_agent]
if status == "in_progress":
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
"dots", text="[blue]in_progress[/blue]", style="bold cyan",
)
status_cell = spinner
else:
@ -252,7 +254,7 @@ def update_display(layout, spinner_text=None):
status = message_buffer.agent_status[agent]
if status == "in_progress":
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
"dots", text="[blue]in_progress[/blue]", style="bold cyan",
)
status_cell = spinner
else:
@ -268,7 +270,7 @@ def update_display(layout, spinner_text=None):
progress_table.add_row("" * 20, "" * 20, "" * 20, style="dim")
layout["progress"].update(
Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2))
Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2)),
)
# Messages panel showing recent messages and tool calls
@ -284,7 +286,7 @@ def update_display(layout, spinner_text=None):
messages_table.add_column("Time", style="cyan", width=8, justify="center")
messages_table.add_column("Type", style="green", width=10, justify="center")
messages_table.add_column(
"Content", style="white", no_wrap=False, ratio=1
"Content", style="white", no_wrap=False, ratio=1,
) # Make content column expand
# Combine tool calls and messages
@ -352,7 +354,7 @@ def update_display(layout, spinner_text=None):
title="Messages & Tools",
border_style="blue",
padding=(1, 2),
)
),
)
# Analysis panel showing current report
@ -363,7 +365,7 @@ def update_display(layout, spinner_text=None):
title="Current Report",
border_style="green",
padding=(1, 2),
)
),
)
else:
layout["analysis"].update(
@ -372,7 +374,7 @@ def update_display(layout, spinner_text=None):
title="Current Report",
border_style="green",
padding=(1, 2),
)
),
)
# Footer with statistics
@ -386,9 +388,12 @@ def update_display(layout, spinner_text=None):
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
stats_table.add_column("Stats", justify="center")
stats_table.add_row(
f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
stats_text = (
f"Tool Calls: {tool_calls_count} | "
f"LLM Calls: {llm_calls_count} | "
f"Generated Reports: {reports_count}"
)
stats_table.add_row(stats_text)
layout["footer"].update(Panel(stats_table, border_style="grey50"))
@ -396,14 +401,20 @@ def update_display(layout, spinner_text=None):
def get_user_selections():
"""Get all user selections before starting the analysis display."""
# Display ASCII art welcome message
with open("./cli/static/welcome.txt", "r") as f:
with open("./cli/static/welcome.txt") as f:
welcome_ascii = f.read()
# Create welcome box content
welcome_content = f"{welcome_ascii}\n"
welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
welcome_content += (
"[bold green]TradingAgents: "
"Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
)
welcome_content += "[bold]Workflow Steps:[/bold]\n"
welcome_content += "I. Analyst Team → II. Research Team → III. Trader → IV. Risk Management → V. Portfolio Management\n\n"
welcome_content += (
"I. Analyst Team → II. Research Team → III. Trader → "
"IV. Risk Management → V. Portfolio Management\n\n"
)
welcome_content += (
"[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]"
)
@ -430,8 +441,8 @@ def get_user_selections():
# Step 1: Ticker symbol
console.print(
create_question_box(
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY"
)
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY",
),
)
selected_ticker = get_ticker()
@ -442,40 +453,40 @@ def get_user_selections():
"Step 2: Analysis Date",
"Enter the analysis date (YYYY-MM-DD)",
default_date,
)
),
)
analysis_date = get_analysis_date()
# Step 3: Select analysts
console.print(
create_question_box(
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis"
)
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis",
),
)
selected_analysts = select_analysts()
console.print(
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}",
)
# Step 4: Research depth
console.print(
create_question_box(
"Step 4: Research Depth", "Select your research depth level"
)
"Step 4: Research Depth", "Select your research depth level",
),
)
selected_research_depth = select_research_depth()
# Step 5: OpenAI backend
console.print(
create_question_box("Step 5: OpenAI backend", "Select which service to talk to")
create_question_box("Step 5: OpenAI backend", "Select which service to talk to"),
)
selected_llm_provider, backend_url = select_llm_provider()
# Step 6: Thinking agents
console.print(
create_question_box(
"Step 6: Thinking Agents", "Select your thinking agents for analysis"
)
"Step 6: Thinking Agents", "Select your thinking agents for analysis",
),
)
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
@ -492,28 +503,7 @@ def get_user_selections():
}
def get_ticker():
    """Prompt the user for the ticker symbol to analyze (defaults to "SPY")."""
    symbol = typer.prompt("", default="SPY")
    return symbol
def get_analysis_date():
    """Get the analysis date from user input.

    Re-prompts until the user enters a valid, non-future YYYY-MM-DD date,
    then returns the date string exactly as entered.
    """
    while True:
        entered = typer.prompt(
            "", default=datetime.datetime.now().strftime("%Y-%m-%d")
        )
        # Reject anything that does not parse as YYYY-MM-DD.
        try:
            parsed = datetime.datetime.strptime(entered, "%Y-%m-%d")
        except ValueError:
            console.print(
                "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
            )
            continue
        # Future dates make no sense for a historical analysis run.
        if parsed.date() > datetime.datetime.now().date():
            console.print("[red]Error: Analysis date cannot be in the future[/red]")
            continue
        return entered
# Functions get_ticker and get_analysis_date are imported from cli.utils
def display_complete_report(final_state):
@ -531,7 +521,7 @@ def display_complete_report(final_state):
title="Market Analyst",
border_style="blue",
padding=(1, 2),
)
),
)
# Social Analyst Report
@ -542,7 +532,7 @@ def display_complete_report(final_state):
title="Social Analyst",
border_style="blue",
padding=(1, 2),
)
),
)
# News Analyst Report
@ -553,7 +543,7 @@ def display_complete_report(final_state):
title="News Analyst",
border_style="blue",
padding=(1, 2),
)
),
)
# Fundamentals Analyst Report
@ -564,7 +554,7 @@ def display_complete_report(final_state):
title="Fundamentals Analyst",
border_style="blue",
padding=(1, 2),
)
),
)
if analyst_reports:
@ -574,7 +564,7 @@ def display_complete_report(final_state):
title="I. Analyst Team Reports",
border_style="cyan",
padding=(1, 2),
)
),
)
# II. Research Team Reports
@ -590,7 +580,7 @@ def display_complete_report(final_state):
title="Bull Researcher",
border_style="blue",
padding=(1, 2),
)
),
)
# Bear Researcher Analysis
@ -601,7 +591,7 @@ def display_complete_report(final_state):
title="Bear Researcher",
border_style="blue",
padding=(1, 2),
)
),
)
# Research Manager Decision
@ -612,7 +602,7 @@ def display_complete_report(final_state):
title="Research Manager",
border_style="blue",
padding=(1, 2),
)
),
)
if research_reports:
@ -622,7 +612,7 @@ def display_complete_report(final_state):
title="II. Research Team Decision",
border_style="magenta",
padding=(1, 2),
)
),
)
# III. Trading Team Reports
@ -638,7 +628,7 @@ def display_complete_report(final_state):
title="III. Trading Team Plan",
border_style="yellow",
padding=(1, 2),
)
),
)
# IV. Risk Management Team Reports
@ -654,7 +644,7 @@ def display_complete_report(final_state):
title="Aggressive Analyst",
border_style="blue",
padding=(1, 2),
)
),
)
# Conservative (Safe) Analyst Analysis
@ -665,7 +655,7 @@ def display_complete_report(final_state):
title="Conservative Analyst",
border_style="blue",
padding=(1, 2),
)
),
)
# Neutral Analyst Analysis
@ -676,7 +666,7 @@ def display_complete_report(final_state):
title="Neutral Analyst",
border_style="blue",
padding=(1, 2),
)
),
)
if risk_reports:
@ -686,7 +676,7 @@ def display_complete_report(final_state):
title="IV. Risk Management Team Decision",
border_style="red",
padding=(1, 2),
)
),
)
# V. Portfolio Manager Decision
@ -702,7 +692,7 @@ def display_complete_report(final_state):
title="V. Portfolio Manager Decision",
border_style="green",
padding=(1, 2),
)
),
)
@ -717,7 +707,7 @@ def extract_content_string(content):
"""Extract string content from various message formats."""
if isinstance(content, str):
return content
elif isinstance(content, list):
if isinstance(content, list):
# Handle Anthropic's list format
text_parts = []
for item in content:
@ -729,8 +719,7 @@ def extract_content_string(content):
else:
text_parts.append(str(item))
return " ".join(text_parts)
else:
return str(content)
return str(content)
def run_analysis():
@ -748,7 +737,7 @@ def run_analysis():
# Initialize the graph
graph = TradingAgentsGraph(
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
[analyst.value for analyst in selections["analysts"]], config=config, debug=True,
)
# Create result directory
@ -807,23 +796,23 @@ def run_analysis():
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
message_buffer.add_tool_call = save_tool_call_decorator(
message_buffer, "add_tool_call"
message_buffer, "add_tool_call",
)
message_buffer.update_report_section = save_report_section_decorator(
message_buffer, "update_report_section"
message_buffer, "update_report_section",
)
# Now start the display layout
layout = create_layout()
with Live(layout, refresh_per_second=4) as live:
with Live(layout, refresh_per_second=4):
# Initial display
update_display(layout)
# Add initial messages
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
message_buffer.add_message(
"System", f"Analysis date: {selections['analysis_date']}"
"System", f"Analysis date: {selections['analysis_date']}",
)
message_buffer.add_message(
"System",
@ -854,7 +843,7 @@ def run_analysis():
# Initialize state and get graph args
init_agent_state = graph.propagator.create_initial_state(
selections["ticker"], selections["analysis_date"]
selections["ticker"], selections["analysis_date"],
)
args = graph.propagator.get_graph_args()
@ -868,7 +857,7 @@ def run_analysis():
# Extract message content and type
if hasattr(last_message, "content"):
content = extract_content_string(
last_message.content
last_message.content,
) # Use the helper function
msg_type = "Reasoning"
else:
@ -884,65 +873,64 @@ def run_analysis():
# Handle both dictionary and object tool calls
if isinstance(tool_call, dict):
message_buffer.add_tool_call(
tool_call["name"], tool_call["args"]
tool_call["name"], tool_call["args"],
)
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
# Update reports and agent status based on chunk content
# Analyst Team Reports
if "market_report" in chunk and chunk["market_report"]:
if chunk.get("market_report"):
message_buffer.update_report_section(
"market_report", chunk["market_report"]
"market_report", chunk["market_report"],
)
message_buffer.update_agent_status("Market Analyst", "completed")
# Set next analyst to in_progress
if "social" in selections["analysts"]:
message_buffer.update_agent_status(
"Social Analyst", "in_progress"
"Social Analyst", "in_progress",
)
if "sentiment_report" in chunk and chunk["sentiment_report"]:
if chunk.get("sentiment_report"):
message_buffer.update_report_section(
"sentiment_report", chunk["sentiment_report"]
"sentiment_report", chunk["sentiment_report"],
)
message_buffer.update_agent_status("Social Analyst", "completed")
# Set next analyst to in_progress
if "news" in selections["analysts"]:
message_buffer.update_agent_status(
"News Analyst", "in_progress"
"News Analyst", "in_progress",
)
if "news_report" in chunk and chunk["news_report"]:
if chunk.get("news_report"):
message_buffer.update_report_section(
"news_report", chunk["news_report"]
"news_report", chunk["news_report"],
)
message_buffer.update_agent_status("News Analyst", "completed")
# Set next analyst to in_progress
if "fundamentals" in selections["analysts"]:
message_buffer.update_agent_status(
"Fundamentals Analyst", "in_progress"
"Fundamentals Analyst", "in_progress",
)
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
if chunk.get("fundamentals_report"):
message_buffer.update_report_section(
"fundamentals_report", chunk["fundamentals_report"]
"fundamentals_report", chunk["fundamentals_report"],
)
message_buffer.update_agent_status(
"Fundamentals Analyst", "completed"
"Fundamentals Analyst", "completed",
)
# Set all research team members to in_progress
update_research_team_status("in_progress")
# Research Team - Handle Investment Debate State
if (
"investment_debate_state" in chunk
and chunk["investment_debate_state"]
chunk.get("investment_debate_state")
):
debate_state = chunk["investment_debate_state"]
# Update Bull Researcher status and report
if "bull_history" in debate_state and debate_state["bull_history"]:
if debate_state.get("bull_history"):
# Keep all research team members in progress
update_research_team_status("in_progress")
# Extract latest bull response
@ -957,7 +945,7 @@ def run_analysis():
)
# Update Bear Researcher status and report
if "bear_history" in debate_state and debate_state["bear_history"]:
if debate_state.get("bear_history"):
# Keep all research team members in progress
update_research_team_status("in_progress")
# Extract latest bear response
@ -973,8 +961,7 @@ def run_analysis():
# Update Research Manager status and final decision
if (
"judge_decision" in debate_state
and debate_state["judge_decision"]
debate_state.get("judge_decision")
):
# Keep all research team members in progress until final decision
update_research_team_status("in_progress")
@ -991,31 +978,29 @@ def run_analysis():
update_research_team_status("completed")
# Set first risk analyst to in_progress
message_buffer.update_agent_status(
"Risky Analyst", "in_progress"
"Risky Analyst", "in_progress",
)
# Trading Team
if (
"trader_investment_plan" in chunk
and chunk["trader_investment_plan"]
chunk.get("trader_investment_plan")
):
message_buffer.update_report_section(
"trader_investment_plan", chunk["trader_investment_plan"]
"trader_investment_plan", chunk["trader_investment_plan"],
)
# Set first risk analyst to in_progress
message_buffer.update_agent_status("Risky Analyst", "in_progress")
# Risk Management Team - Handle Risk Debate State
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
if chunk.get("risk_debate_state"):
risk_state = chunk["risk_debate_state"]
# Update Risky Analyst status and report
if (
"current_risky_response" in risk_state
and risk_state["current_risky_response"]
risk_state.get("current_risky_response")
):
message_buffer.update_agent_status(
"Risky Analyst", "in_progress"
"Risky Analyst", "in_progress",
)
message_buffer.add_message(
"Reasoning",
@ -1029,11 +1014,10 @@ def run_analysis():
# Update Safe Analyst status and report
if (
"current_safe_response" in risk_state
and risk_state["current_safe_response"]
risk_state.get("current_safe_response")
):
message_buffer.update_agent_status(
"Safe Analyst", "in_progress"
"Safe Analyst", "in_progress",
)
message_buffer.add_message(
"Reasoning",
@ -1047,11 +1031,10 @@ def run_analysis():
# Update Neutral Analyst status and report
if (
"current_neutral_response" in risk_state
and risk_state["current_neutral_response"]
risk_state.get("current_neutral_response")
):
message_buffer.update_agent_status(
"Neutral Analyst", "in_progress"
"Neutral Analyst", "in_progress",
)
message_buffer.add_message(
"Reasoning",
@ -1064,9 +1047,9 @@ def run_analysis():
)
# Update Portfolio Manager status and final decision
if "judge_decision" in risk_state and risk_state["judge_decision"]:
if risk_state.get("judge_decision"):
message_buffer.update_agent_status(
"Portfolio Manager", "in_progress"
"Portfolio Manager", "in_progress",
)
message_buffer.add_message(
"Reasoning",
@ -1081,10 +1064,10 @@ def run_analysis():
message_buffer.update_agent_status("Risky Analyst", "completed")
message_buffer.update_agent_status("Safe Analyst", "completed")
message_buffer.update_agent_status(
"Neutral Analyst", "completed"
"Neutral Analyst", "completed",
)
message_buffer.update_agent_status(
"Portfolio Manager", "completed"
"Portfolio Manager", "completed",
)
# Update the display
@ -1094,18 +1077,18 @@ def run_analysis():
# Get final state and decision
final_state = trace[-1]
decision = graph.process_signal(final_state["final_trade_decision"])
graph.process_signal(final_state["final_trade_decision"])
# Update all agent statuses to completed
for agent in message_buffer.agent_status:
message_buffer.update_agent_status(agent, "completed")
message_buffer.add_message(
"Analysis", f"Completed analysis for {selections['analysis_date']}"
"Analysis", f"Completed analysis for {selections['analysis_date']}",
)
# Update final report sections
for section in message_buffer.report_sections.keys():
for section in message_buffer.report_sections:
if section in final_state:
message_buffer.update_report_section(section, final_state[section])

View File

@ -1,5 +1,7 @@
import sys
import questionary
from typing import List
from rich.console import Console
from cli.models import AnalystType
@ -23,13 +25,13 @@ def get_ticker() -> str:
[
("text", "fg:green"),
("highlighted", "noinherit"),
]
],
),
).ask()
if not ticker:
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
exit(1)
sys.exit(1)
return ticker.strip().upper()
@ -56,18 +58,18 @@ def get_analysis_date() -> str:
[
("text", "fg:green"),
("highlighted", "noinherit"),
]
],
),
).ask()
if not date:
console.print("\n[red]No date provided. Exiting...[/red]")
exit(1)
sys.exit(1)
return date.strip()
def select_analysts() -> List[AnalystType]:
def select_analysts() -> list[AnalystType]:
"""Select analysts using an interactive checkbox."""
choices = questionary.checkbox(
"Select Your [Analysts Team]:",
@ -82,13 +84,13 @@ def select_analysts() -> List[AnalystType]:
("selected", "fg:green noinherit"),
("highlighted", "noinherit"),
("pointer", "noinherit"),
]
],
),
).ask()
if not choices:
console.print("\n[red]No analysts selected. Exiting...[/red]")
exit(1)
sys.exit(1)
return choices
@ -114,13 +116,13 @@ def select_research_depth() -> int:
("selected", "fg:yellow noinherit"),
("highlighted", "fg:yellow noinherit"),
("pointer", "fg:yellow noinherit"),
]
],
),
).ask()
if choice is None:
console.print("\n[red]No research depth selected. Exiting...[/red]")
exit(1)
sys.exit(1)
return choice
@ -200,15 +202,15 @@ def select_shallow_thinking_agent(provider) -> str:
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
],
),
).ask()
if choice is None:
console.print(
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]",
)
exit(1)
sys.exit(1)
return choice
@ -292,13 +294,13 @@ def select_deep_thinking_agent(provider) -> str:
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
],
),
).ask()
if choice is None:
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
exit(1)
sys.exit(1)
return choice
@ -326,15 +328,14 @@ def select_llm_provider() -> tuple[str, str]:
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
],
),
).ask()
if choice is None:
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
exit(1)
sys.exit(1)
display_name, url = choice
print(f"You selected: {display_name}\tURL: {url}")
return display_name, url

View File

@ -1,10 +1,7 @@
import os,sys # Multiple imports on one line (Ruff will fix)
from typing import List,Dict # Multiple imports on one line
def poorly_formatted_function(x: float, y: float, z: float) -> float:
    """Compute ``x + y * z`` and flag large results.

    Args:
        x: Value added to the product.
        y: First factor of the product.
        z: Second factor of the product.

    Returns:
        The value of ``x + y * z``.

    Side effects:
        Prints "Result is large" when the result exceeds 100.
    """
    result = x + y * z
    # Removed the unused ``unused_variable = 42`` local (dead code).
    if result > 100:
        print("Result is large")
    return result

View File

@ -23,11 +23,11 @@ def run_command(cmd, description=""):
print(f"{description or 'Command completed successfully'}")
return True
else:
print(f"❌ Command failed:")
print("❌ Command failed:")
print(result.stderr)
return False
except subprocess.TimeoutExpired:
print(f"⏱️ Command timed out")
print("⏱️ Command timed out")
return False
except Exception as e:
print(f"❌ Error running command: {e}")
@ -73,7 +73,7 @@ def main():
# Summary
print("\n" + "=" * 50)
print(f"📊 Test Setup Verification Results:")
print("📊 Test Setup Verification Results:")
print(f"✅ Successful: {success_count}/{total_tests}")
print(f"❌ Failed: {total_tests - success_count}/{total_tests}")

View File

@ -1,11 +1,10 @@
"""Pytest configuration and shared fixtures for TradingAgents tests."""
import os
import pytest
import tempfile
from unittest.mock import Mock, MagicMock
from datetime import date, datetime
from typing import Dict, Any
from unittest.mock import Mock
import pytest
from tradingagents.default_config import DEFAULT_CONFIG
@ -22,7 +21,7 @@ def sample_config():
"deep_think_llm": "gpt-4o-mini",
"quick_think_llm": "gpt-4o-mini",
"project_dir": "/tmp/test_tradingagents",
}
},
)
return config
@ -174,7 +173,7 @@ def mock_memory():
def pytest_configure(config):
"""Configure pytest with custom markers."""
config.addinivalue_line(
"markers", "integration: mark test as integration test (slow)"
"markers", "integration: mark test as integration test (slow)",
)
config.addinivalue_line("markers", "unit: mark test as unit test (fast)")
config.addinivalue_line("markers", "api: mark test as requiring API access")

View File

@ -2,14 +2,14 @@
import json
from datetime import datetime, timedelta
from typing import Dict, List, Any
from typing import Any
class SampleDataFactory:
"""Factory class for creating sample test data."""
@staticmethod
def create_market_data(ticker: str = "AAPL", days: int = 30) -> Dict[str, Any]:
def create_market_data(ticker: str = "AAPL", days: int = 30) -> dict[str, Any]:
"""Create sample market data for testing."""
base_date = datetime(2024, 5, 1)
data = {}
@ -36,8 +36,8 @@ class SampleDataFactory:
@staticmethod
def create_finnhub_news_data(
ticker: str = "AAPL", count: int = 10
) -> Dict[str, List[Dict[str, Any]]]:
ticker: str = "AAPL", count: int = 10,
) -> dict[str, list[dict[str, Any]]]:
"""Create sample FinnHub news data for testing."""
base_date = datetime(2024, 5, 10)
data = {}
@ -93,7 +93,7 @@ class SampleDataFactory:
@staticmethod
def create_insider_transactions_data(
ticker: str = "AAPL",
) -> Dict[str, List[Dict[str, Any]]]:
) -> dict[str, list[dict[str, Any]]]:
"""Create sample insider transactions data for testing."""
base_date = datetime(2024, 5, 5)
data = {}
@ -129,15 +129,15 @@ class SampleDataFactory:
"transactionValue": transaction["shares"] * transaction["price"],
"reportingName": transaction["person"],
"typeOfOwner": "officer",
}
},
]
return data
@staticmethod
def create_financial_statements_data(
ticker: str = "AAPL", period: str = "annual"
) -> Dict[str, List[Dict[str, Any]]]:
ticker: str = "AAPL", period: str = "annual",
) -> dict[str, list[dict[str, Any]]]:
"""Create sample financial statements data for testing."""
if period == "annual":
dates = ["2023-12-31", "2022-12-31", "2021-12-31"]
@ -174,7 +174,7 @@ class SampleDataFactory:
@staticmethod
def create_social_sentiment_data(
ticker: str = "AAPL",
) -> Dict[str, List[Dict[str, Any]]]:
) -> dict[str, list[dict[str, Any]]]:
"""Create sample social media sentiment data for testing."""
base_date = datetime(2024, 5, 8)
data = {}
@ -226,7 +226,7 @@ class SampleDataFactory:
"subreddit": "stocks" if j % 2 else "investing",
"upvotes": 10 + (j * 5),
"comments": 3 + j,
}
},
)
data[date_str] = daily_posts
@ -234,7 +234,7 @@ class SampleDataFactory:
return data
@staticmethod
def create_technical_indicators_data(ticker: str = "AAPL") -> Dict[str, Any]:
def create_technical_indicators_data(ticker: str = "AAPL") -> dict[str, Any]:
"""Create sample technical indicators data for testing."""
return {
"symbol": ticker,
@ -262,23 +262,23 @@ class SampleDataFactory:
}
@staticmethod
def create_complete_test_dataset(ticker: str = "AAPL") -> Dict[str, Dict[str, Any]]:
def create_complete_test_dataset(ticker: str = "AAPL") -> dict[str, dict[str, Any]]:
"""Create a complete dataset for comprehensive testing."""
return {
"market_data": SampleDataFactory.create_market_data(ticker),
"news_data": SampleDataFactory.create_finnhub_news_data(ticker),
"insider_transactions": SampleDataFactory.create_insider_transactions_data(
ticker
ticker,
),
"financial_annual": SampleDataFactory.create_financial_statements_data(
ticker, "annual"
ticker, "annual",
),
"financial_quarterly": SampleDataFactory.create_financial_statements_data(
ticker, "quarterly"
ticker, "quarterly",
),
"social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker),
"technical_indicators": SampleDataFactory.create_technical_indicators_data(
ticker
ticker,
),
}
@ -343,7 +343,7 @@ def save_sample_data_to_files(base_path: str, ticker: str = "AAPL") -> None:
# Save quarterly data separately
quarterly_path = os.path.join(
finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json"
finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json",
)
with open(quarterly_path, "w") as f:
json.dump(dataset["financial_quarterly"], f, indent=2)

View File

@ -1,13 +1,11 @@
"""Integration tests for the full TradingAgents workflow."""
import pytest
import os
import tempfile
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime
from unittest.mock import Mock, patch
import pytest
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
@pytest.mark.integration
@ -26,14 +24,14 @@ class TestFullWorkflowIntegration:
"deep_think_llm": "gpt-4o-mini",
"quick_think_llm": "gpt-4o-mini",
"project_dir": temp_data_dir,
}
},
)
return config
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_end_to_end_trading_workflow(
self, mock_toolkit, mock_chat_openai, integration_config
self, mock_toolkit, mock_chat_openai, integration_config,
):
"""Test complete end-to-end trading workflow."""
# Setup mocks
@ -69,15 +67,14 @@ class TestFullWorkflowIntegration:
"company_of_interest": "AAPL",
"trade_date": "2024-05-10",
"messages": [],
}
},
)
trading_graph.propagator.get_graph_args = Mock(return_value={})
trading_graph.signal_processor.process_signal = Mock(return_value="BUY")
# Execute the full workflow
with patch("builtins.open", create=True):
with patch("json.dump"):
final_state, decision = trading_graph.propagate("AAPL", "2024-05-10")
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = trading_graph.propagate("AAPL", "2024-05-10")
# Verify the workflow completed successfully
assert final_state is not None
@ -89,7 +86,7 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_analysts_integration(
self, mock_toolkit, mock_chat_openai, integration_config
self, mock_toolkit, mock_chat_openai, integration_config,
):
"""Test integration with different analyst combinations."""
analyst_combinations = [
@ -117,7 +114,7 @@ class TestFullWorkflowIntegration:
with patch("tradingagents.graph.trading_graph.set_config"):
# Test each analyst combination
trading_graph = TradingAgentsGraph(
selected_analysts=analysts, config=integration_config
selected_analysts=analysts, config=integration_config,
)
trading_graph.graph = mock_graph
@ -127,19 +124,18 @@ class TestFullWorkflowIntegration:
"company_of_interest": "TSLA",
"trade_date": "2024-05-15",
"messages": [],
}
},
)
trading_graph.propagator.get_graph_args = Mock(return_value={})
trading_graph.signal_processor.process_signal = Mock(
return_value="HOLD"
return_value="HOLD",
)
# Execute
with patch("builtins.open", create=True):
with patch("json.dump"):
final_state, decision = trading_graph.propagate(
"TSLA", "2024-05-15"
)
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = trading_graph.propagate(
"TSLA", "2024-05-15",
)
# Verify
assert final_state is not None
@ -148,7 +144,7 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_memory_and_reflection_integration(
self, mock_toolkit, mock_chat_openai, integration_config
self, mock_toolkit, mock_chat_openai, integration_config,
):
"""Test integration of memory and reflection components."""
# Setup
@ -165,7 +161,7 @@ class TestFullWorkflowIntegration:
mock_graph.invoke.return_value = mock_final_state
with patch(
"tradingagents.graph.trading_graph.FinancialSituationMemory"
"tradingagents.graph.trading_graph.FinancialSituationMemory",
) as mock_memory:
mock_memory_instance = Mock()
mock_memory.return_value = mock_memory_instance
@ -180,11 +176,11 @@ class TestFullWorkflowIntegration:
"company_of_interest": "NVDA",
"trade_date": "2024-05-20",
"messages": [],
}
},
)
trading_graph.propagator.get_graph_args = Mock(return_value={})
trading_graph.signal_processor.process_signal = Mock(
return_value="SELL"
return_value="SELL",
)
# Mock reflection methods
@ -195,9 +191,8 @@ class TestFullWorkflowIntegration:
trading_graph.reflector.reflect_risk_manager = Mock()
# Execute workflow
with patch("builtins.open", create=True):
with patch("json.dump"):
final_state, decision = trading_graph.propagate("NVDA", "2024-05-20")
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = trading_graph.propagate("NVDA", "2024-05-20")
# Test reflection and memory update
returns_losses = {"return": -0.03, "loss": -0.08}
@ -213,7 +208,7 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_debug_mode_integration(
self, mock_toolkit, mock_chat_openai, integration_config
self, mock_toolkit, mock_chat_openai, integration_config,
):
"""Test integration in debug mode."""
# Setup
@ -233,7 +228,7 @@ class TestFullWorkflowIntegration:
self._create_mock_final_state(), # Final chunk
]
for chunk in mock_chunks:
if "messages" in chunk and chunk["messages"]:
if chunk.get("messages"):
for msg in chunk["messages"]:
if hasattr(msg, "pretty_print"):
msg.pretty_print = Mock()
@ -245,7 +240,7 @@ class TestFullWorkflowIntegration:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"):
trading_graph = TradingAgentsGraph(
debug=True, config=integration_config
debug=True, config=integration_config,
)
trading_graph.graph = mock_graph
@ -255,15 +250,14 @@ class TestFullWorkflowIntegration:
"company_of_interest": "AMZN",
"trade_date": "2024-05-25",
"messages": [],
}
},
)
trading_graph.propagator.get_graph_args = Mock(return_value={})
trading_graph.signal_processor.process_signal = Mock(return_value="BUY")
# Execute in debug mode
with patch("builtins.open", create=True):
with patch("json.dump"):
final_state, decision = trading_graph.propagate("AMZN", "2024-05-25")
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = trading_graph.propagate("AMZN", "2024-05-25")
# Verify debug mode was used
mock_graph.stream.assert_called_once()
@ -271,7 +265,7 @@ class TestFullWorkflowIntegration:
assert decision == "BUY"
@pytest.mark.parametrize(
"ticker,date",
("ticker", "date"),
[
("AAPL", "2024-01-15"),
("TSLA", "2024-02-20"),
@ -282,7 +276,7 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_stocks_integration(
self, mock_toolkit, mock_chat_openai, ticker, date, integration_config
self, mock_toolkit, mock_chat_openai, ticker, date, integration_config,
):
"""Test integration with different stocks and dates."""
# Setup
@ -309,17 +303,16 @@ class TestFullWorkflowIntegration:
"company_of_interest": ticker,
"trade_date": date,
"messages": [],
}
},
)
trading_graph.propagator.get_graph_args = Mock(return_value={})
trading_graph.signal_processor.process_signal = Mock(
return_value="HOLD"
return_value="HOLD",
)
# Execute
with patch("builtins.open", create=True):
with patch("json.dump"):
final_state, decision = trading_graph.propagate(ticker, date)
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = trading_graph.propagate(ticker, date)
# Verify
assert final_state["company_of_interest"] == ticker
@ -389,7 +382,7 @@ class TestPerformanceIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_consecutive_runs(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
):
"""Test multiple consecutive trading decisions."""
sample_config["project_dir"] = temp_data_dir
@ -455,9 +448,8 @@ class TestPerformanceIntegration:
mock_final_state["final_trade_decision"]
)
with patch("builtins.open", create=True):
with patch("json.dump"):
final_state, decision = trading_graph.propagate(ticker, date)
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = trading_graph.propagate(ticker, date)
decisions.append(decision)

View File

@ -1,8 +1,9 @@
"""Unit tests for market analyst agent."""
from unittest.mock import Mock
import pytest
from unittest.mock import Mock, patch, MagicMock
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.messages import HumanMessage
from tradingagents.agents.analysts.market_analyst import create_market_analyst
@ -16,7 +17,7 @@ class TestMarketAnalyst:
assert callable(analyst_node)
def test_market_analyst_node_basic_execution(
self, mock_llm, mock_toolkit, sample_agent_state
self, mock_llm, mock_toolkit, sample_agent_state,
):
"""Test basic execution of market analyst node."""
# Setup
@ -38,7 +39,7 @@ class TestMarketAnalyst:
assert result["market_report"] == "Market analysis complete"
def test_market_analyst_uses_online_tools_when_configured(
self, mock_llm, mock_toolkit, sample_agent_state
self, mock_llm, mock_toolkit, sample_agent_state,
):
"""Test that analyst uses online tools when configured."""
# Setup
@ -54,7 +55,7 @@ class TestMarketAnalyst:
analyst_node = create_market_analyst(mock_llm, mock_toolkit)
# Execute
result = analyst_node(sample_agent_state)
analyst_node(sample_agent_state)
# Verify tools were bound correctly
mock_llm.bind_tools.assert_called_once()
@ -63,7 +64,7 @@ class TestMarketAnalyst:
assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2
def test_market_analyst_uses_offline_tools_when_configured(
self, mock_llm, mock_toolkit, sample_agent_state
self, mock_llm, mock_toolkit, sample_agent_state,
):
"""Test that analyst uses offline tools when configured."""
# Setup
@ -79,7 +80,7 @@ class TestMarketAnalyst:
analyst_node = create_market_analyst(mock_llm, mock_toolkit)
# Execute
result = analyst_node(sample_agent_state)
analyst_node(sample_agent_state)
# Verify tools were bound correctly
mock_llm.bind_tools.assert_called_once()
@ -87,7 +88,7 @@ class TestMarketAnalyst:
assert len(bound_tools) == 2 # Should have 2 offline tools
def test_market_analyst_processes_state_variables(
self, mock_llm, mock_toolkit, sample_agent_state
self, mock_llm, mock_toolkit, sample_agent_state,
):
"""Test that market analyst correctly processes state variables."""
# Setup
@ -111,7 +112,7 @@ class TestMarketAnalyst:
assert result["market_report"] == "Analysis for AAPL on 2024-05-10"
def test_market_analyst_handles_empty_tool_calls(
self, mock_llm, mock_toolkit, sample_agent_state
self, mock_llm, mock_toolkit, sample_agent_state,
):
"""Test handling when no tool calls are made."""
# Setup
@ -131,7 +132,7 @@ class TestMarketAnalyst:
assert result["messages"] == [mock_result]
def test_market_analyst_with_tool_calls(
self, mock_llm, mock_toolkit, sample_agent_state
self, mock_llm, mock_toolkit, sample_agent_state,
):
"""Test handling when tool calls are present."""
# Setup
@ -152,7 +153,7 @@ class TestMarketAnalyst:
@pytest.mark.parametrize("online_tools", [True, False])
def test_market_analyst_tool_configuration(
self, mock_llm, mock_toolkit, sample_agent_state, online_tools
self, mock_llm, mock_toolkit, sample_agent_state, online_tools,
):
"""Test tool configuration for both online and offline modes."""
# Setup
@ -194,15 +195,15 @@ class TestMarketAnalystIntegration:
mock_result = Mock()
mock_result.content = """
# Market Analysis for TSLA (2024-05-15)
## Technical Analysis
- RSI: 65 (slightly overbought)
- MACD: Bullish crossover
- 50-day SMA: Trending upward
## Volume Analysis
- Above average volume suggests strong interest
| Indicator | Value | Signal |
|-----------|-------|--------|
| RSI | 65 | Neutral |

View File

@ -1,10 +1,9 @@
"""Unit tests for FinnHub utilities."""
import pytest
import json
import os
import tempfile
from unittest.mock import patch, mock_open, Mock
import pytest
from tradingagents.dataflows.finnhub_utils import get_data_in_range
@ -191,7 +190,7 @@ class TestFinnhubUtils:
# Test without period
expected_path_no_period = os.path.join(
temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
)
# Test with period
@ -238,7 +237,7 @@ class TestFinnhubUtils:
assert len(result2) == 1
@pytest.mark.parametrize(
"data_type,period",
("data_type", "period"),
[
("news_data", None),
("insider_trans", None),
@ -249,7 +248,7 @@ class TestFinnhubUtils:
],
)
def test_get_data_in_range_various_data_types(
self, temp_data_dir, data_type, period
self, temp_data_dir, data_type, period,
):
"""Test get_data_in_range with various data types."""
ticker = "TEST"

View File

@ -1,12 +1,10 @@
"""Unit tests for TradingAgentsGraph."""
from unittest.mock import Mock, mock_open, patch
import pytest
import os
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
class TestTradingAgentsGraph:
@ -47,7 +45,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_debug(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
):
"""Test initialization with debug mode enabled."""
sample_config["project_dir"] = temp_data_dir
@ -65,7 +63,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatAnthropic")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_anthropic(
self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir
self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir,
):
"""Test initialization with Anthropic LLM provider."""
sample_config["project_dir"] = temp_data_dir
@ -77,14 +75,14 @@ class TestTradingAgentsGraph:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"):
graph = TradingAgentsGraph(config=sample_config)
TradingAgentsGraph(config=sample_config)
assert mock_chat_anthropic.call_count == 2
@patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_google(
self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir
self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir,
):
"""Test initialization with Google LLM provider."""
sample_config["project_dir"] = temp_data_dir
@ -96,13 +94,13 @@ class TestTradingAgentsGraph:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"):
graph = TradingAgentsGraph(config=sample_config)
TradingAgentsGraph(config=sample_config)
assert mock_chat_google.call_count == 2
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_unsupported_llm_provider(
self, mock_toolkit, sample_config, temp_data_dir
self, mock_toolkit, sample_config, temp_data_dir,
):
"""Test initialization with unsupported LLM provider raises error."""
sample_config["project_dir"] = temp_data_dir
@ -117,7 +115,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_create_tool_nodes(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
):
"""Test creation of tool nodes."""
sample_config["project_dir"] = temp_data_dir
@ -145,7 +143,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_propagate_basic(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
):
"""Test basic propagate functionality."""
sample_config["project_dir"] = temp_data_dir
@ -190,15 +188,14 @@ class TestTradingAgentsGraph:
# Mock the propagator and signal processor
graph.propagator.create_initial_state = Mock(
return_value={"test": "state"}
return_value={"test": "state"},
)
graph.propagator.get_graph_args = Mock(return_value={})
graph.signal_processor.process_signal = Mock(return_value="HOLD")
# Execute
with patch("builtins.open", create=True):
with patch("json.dump"):
final_state, decision = graph.propagate("AAPL", "2024-05-10")
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = graph.propagate("AAPL", "2024-05-10")
# Verify
assert final_state == mock_final_state
@ -209,7 +206,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_propagate_debug_mode(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
):
"""Test propagate in debug mode."""
sample_config["project_dir"] = temp_data_dir
@ -231,15 +228,14 @@ class TestTradingAgentsGraph:
# Mock other components
graph.propagator.create_initial_state = Mock(
return_value={"test": "state"}
return_value={"test": "state"},
)
graph.propagator.get_graph_args = Mock(return_value={})
graph.signal_processor.process_signal = Mock(return_value="BUY")
# Execute
with patch("builtins.open", create=True):
with patch("json.dump"):
final_state, decision = graph.propagate("TSLA", "2024-05-15")
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = graph.propagate("TSLA", "2024-05-15")
# Verify debug mode was used
mock_graph.stream.assert_called_once()
@ -249,7 +245,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_log_state(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
):
"""Test state logging functionality."""
sample_config["project_dir"] = temp_data_dir
@ -291,10 +287,9 @@ class TestTradingAgentsGraph:
}
# Mock file operations
with patch("pathlib.Path.mkdir"):
with patch("builtins.open", mock_open()) as mock_file:
with patch("json.dump") as mock_json_dump:
graph._log_state("2024-05-20", final_state)
with patch("pathlib.Path.mkdir"), patch("builtins.open", mock_open()):
with patch("json.dump"):
graph._log_state("2024-05-20", final_state)
# Verify logging occurred
assert "2024-05-20" in graph.log_states_dict
@ -305,7 +300,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_reflect_and_remember(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
):
"""Test reflection and memory update functionality."""
sample_config["project_dir"] = temp_data_dir
@ -315,20 +310,19 @@ class TestTradingAgentsGraph:
mock_toolkit.return_value = mock_toolkit_instance
with patch(
"tradingagents.graph.trading_graph.FinancialSituationMemory"
) as mock_memory:
with patch("tradingagents.graph.trading_graph.set_config"):
graph = TradingAgentsGraph(config=sample_config)
"tradingagents.graph.trading_graph.FinancialSituationMemory",
), patch("tradingagents.graph.trading_graph.set_config"):
graph = TradingAgentsGraph(config=sample_config)
# Set up current state
graph.curr_state = {"test": "state"}
# Set up current state
graph.curr_state = {"test": "state"}
# Mock reflector methods
graph.reflector.reflect_bull_researcher = Mock()
graph.reflector.reflect_bear_researcher = Mock()
graph.reflector.reflect_trader = Mock()
graph.reflector.reflect_invest_judge = Mock()
graph.reflector.reflect_risk_manager = Mock()
# Mock reflector methods
graph.reflector.reflect_bull_researcher = Mock()
graph.reflector.reflect_bear_researcher = Mock()
graph.reflector.reflect_trader = Mock()
graph.reflector.reflect_invest_judge = Mock()
graph.reflector.reflect_risk_manager = Mock()
returns_losses = {"return": 0.05, "loss": -0.02}
@ -345,7 +339,7 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_process_signal(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
):
"""Test signal processing functionality."""
sample_config["project_dir"] = temp_data_dir
@ -393,8 +387,8 @@ class TestTradingAgentsGraph:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"):
graph = TradingAgentsGraph(
selected_analysts=selected_analysts, config=sample_config
TradingAgentsGraph(
selected_analysts=selected_analysts, config=sample_config,
)
# Verify graph was set up with selected analysts
@ -415,14 +409,14 @@ class TestTradingAgentsGraphErrorHandling:
# This should still work as the class should use defaults for missing keys
with patch("tradingagents.graph.trading_graph.set_config"):
with pytest.raises(
KeyError
KeyError,
): # Should fail when trying to access missing config keys
TradingAgentsGraph(config=invalid_config)
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_directory_creation_failure(
self, mock_toolkit, mock_chat_openai, sample_config
self, mock_toolkit, mock_chat_openai, sample_config,
):
"""Test handling when directory creation fails."""
sample_config["project_dir"] = "/invalid/path/that/cannot/be/created"

View File

@ -1,40 +1,35 @@
from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.memory import FinancialSituationMemory
from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
from .risk_mgmt.aggresive_debator import create_risky_debator
from .risk_mgmt.conservative_debator import create_safe_debator
from .risk_mgmt.neutral_debator import create_neutral_debator
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .trader.trader import create_trader
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.memory import FinancialSituationMemory
__all__ = [
"FinancialSituationMemory",
"Toolkit",
"AgentState",
"create_msg_delete",
"FinancialSituationMemory",
"InvestDebateState",
"RiskDebateState",
"Toolkit",
"create_bear_researcher",
"create_bull_researcher",
"create_research_manager",
"create_fundamentals_analyst",
"create_market_analyst",
"create_msg_delete",
"create_neutral_debator",
"create_news_analyst",
"create_risky_debator",
"create_research_manager",
"create_risk_manager",
"create_risky_debator",
"create_safe_debator",
"create_social_media_analyst",
"create_trader",

View File

@ -5,7 +5,7 @@ def create_fundamentals_analyst(llm, toolkit):
def fundamentals_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
state["company_of_interest"]
if toolkit.config["online_tools"]:
tools = [toolkit.get_fundamentals_openai]
@ -20,7 +20,7 @@ def create_fundamentals_analyst(llm, toolkit):
system_message = (
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
)
prompt = ChatPromptTemplate.from_messages(
@ -37,7 +37,7 @@ def create_fundamentals_analyst(llm, toolkit):
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
),
MessagesPlaceholder(variable_name="messages"),
]
],
)
prompt = prompt.partial(system_message=system_message)

View File

@ -6,7 +6,7 @@ def create_market_analyst(llm, toolkit):
def market_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
state["company_of_interest"]
if toolkit.config["online_tools"]:
tools = [
@ -45,7 +45,7 @@ Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
""" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
)
prompt = ChatPromptTemplate.from_messages(
@ -62,7 +62,7 @@ Volume-Based Indicators:
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
),
MessagesPlaceholder(variable_name="messages"),
]
],
)
prompt = prompt.partial(system_message=system_message)

View File

@ -17,7 +17,7 @@ def create_news_analyst(llm, toolkit):
system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
)
prompt = ChatPromptTemplate.from_messages(
@ -34,7 +34,7 @@ def create_news_analyst(llm, toolkit):
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
),
MessagesPlaceholder(variable_name="messages"),
]
],
)
prompt = prompt.partial(system_message=system_message)

View File

@ -5,7 +5,7 @@ def create_social_media_analyst(llm, toolkit):
def social_media_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
state["company_of_interest"]
if toolkit.config["online_tools"]:
tools = [toolkit.get_stock_news_openai]
@ -16,7 +16,7 @@ def create_social_media_analyst(llm, toolkit):
system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
)
prompt = ChatPromptTemplate.from_messages(
@ -33,7 +33,7 @@ def create_social_media_analyst(llm, toolkit):
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
),
MessagesPlaceholder(variable_name="messages"),
]
],
)
prompt = prompt.partial(system_message=system_message)

View File

@ -12,7 +12,7 @@ def create_research_manager(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
@ -24,7 +24,7 @@ Additionally, develop a detailed investment plan for the trader. This should inc
Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes:
\"{past_memory_str}\"

View File

@ -1,7 +1,7 @@
def create_risk_manager(llm, memory):
def risk_manager_node(state) -> dict:
company_name = state["company_of_interest"]
state["company_of_interest"]
history = state["risk_debate_state"]["history"]
risk_debate_state = state["risk_debate_state"]
@ -15,7 +15,7 @@ def create_risk_manager(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
@ -32,7 +32,7 @@ Deliverables:
---
**Analysts Debate History:**
**Analysts Debate History:**
{history}
---

View File

@ -14,7 +14,7 @@ def create_bear_researcher(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.

View File

@ -14,7 +14,7 @@ def create_bull_researcher(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.

View File

@ -41,7 +41,7 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
"current_risky_response": argument,
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": risk_debate_state.get(
"current_neutral_response", ""
"current_neutral_response", "",
),
"count": risk_debate_state["count"] + 1,
}

View File

@ -39,11 +39,11 @@ Engage by questioning their optimism and emphasizing the potential downsides the
"neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Safe",
"current_risky_response": risk_debate_state.get(
"current_risky_response", ""
"current_risky_response", "",
),
"current_safe_response": argument,
"current_neutral_response": risk_debate_state.get(
"current_neutral_response", ""
"current_neutral_response", "",
),
"count": risk_debate_state["count"] + 1,
}

View File

@ -39,7 +39,7 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
"neutral_history": neutral_history + "\n" + argument,
"latest_speaker": "Neutral",
"current_risky_response": risk_debate_state.get(
"current_risky_response", ""
"current_risky_response", "",
),
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": argument,

View File

@ -15,7 +15,7 @@ def create_trader(llm, memory):
past_memory_str = ""
if past_memories:
for i, rec in enumerate(past_memories, 1):
for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
else:
past_memory_str = "No past memories found."

View File

@ -1,16 +1,18 @@
from typing import Annotated
from typing_extensions import TypedDict
from tradingagents.agents import *
from langgraph.graph import MessagesState
from typing_extensions import TypedDict
# Import specific agent classes as needed
# Researcher team state
class InvestDebateState(TypedDict):
bull_history: Annotated[
str, "Bullish Conversation history"
str, "Bullish Conversation history",
] # Bullish Conversation history
bear_history: Annotated[
str, "Bearish Conversation history"
str, "Bearish Conversation history",
] # Bullish Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
current_response: Annotated[str, "Latest response"] # Last response
@ -21,24 +23,24 @@ class InvestDebateState(TypedDict):
# Risk management team state
class RiskDebateState(TypedDict):
risky_history: Annotated[
str, "Risky Agent's Conversation history"
str, "Risky Agent's Conversation history",
] # Conversation history
safe_history: Annotated[
str, "Safe Agent's Conversation history"
str, "Safe Agent's Conversation history",
] # Conversation history
neutral_history: Annotated[
str, "Neutral Agent's Conversation history"
str, "Neutral Agent's Conversation history",
] # Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
latest_speaker: Annotated[str, "Analyst that spoke last"]
current_risky_response: Annotated[
str, "Latest response by the risky analyst"
str, "Latest response by the risky analyst",
] # Last response
current_safe_response: Annotated[
str, "Latest response by the safe analyst"
str, "Latest response by the safe analyst",
] # Last response
current_neutral_response: Annotated[
str, "Latest response by the neutral analyst"
str, "Latest response by the neutral analyst",
] # Last response
judge_decision: Annotated[str, "Judge's decision"]
count: Annotated[int, "Length of the current conversation"] # Conversation length
@ -54,13 +56,13 @@ class AgentState(MessagesState):
market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
news_report: Annotated[
str, "Report from the News Researcher of current world affairs"
str, "Report from the News Researcher of current world affairs",
]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
# researcher team discussion step
investment_debate_state: Annotated[
InvestDebateState, "Current state of the debate on if to invest or not"
InvestDebateState, "Current state of the debate on if to invest or not",
]
investment_plan: Annotated[str, "Plan generated by the Analyst"]
@ -68,6 +70,6 @@ class AgentState(MessagesState):
# risk management team discussion step
risk_debate_state: Annotated[
RiskDebateState, "Current state of the debate on evaluating risk"
RiskDebateState, "Current state of the debate on evaluating risk",
]
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]

View File

@ -1,9 +1,10 @@
from langchain_core.messages import HumanMessage
from typing import Annotated
from langchain_core.messages import RemoveMessage
from langchain_core.tools import tool
from datetime import datetime
import tradingagents.dataflows.interface as interface
from typing import Annotated
from langchain_core.messages import HumanMessage, RemoveMessage
from langchain_core.tools import tool
from tradingagents.dataflows import interface
from tradingagents.default_config import DEFAULT_CONFIG
@ -18,7 +19,7 @@ def create_msg_delete():
# Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]}
return {"messages": [*removal_operations, placeholder]}
return delete_messages
@ -53,9 +54,8 @@ class Toolkit:
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
"""
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
return interface.get_reddit_global_news(curr_date, 7, 5)
return global_news_result
@staticmethod
@tool
@ -83,11 +83,10 @@ class Toolkit:
start_date = datetime.strptime(start_date, "%Y-%m-%d")
look_back_days = (end_date - start_date).days
finnhub_news_result = interface.get_finnhub_news(
ticker, end_date_str, look_back_days
return interface.get_finnhub_news(
ticker, end_date_str, look_back_days,
)
return finnhub_news_result
@staticmethod
@tool
@ -107,9 +106,8 @@ class Toolkit:
str: A formatted dataframe containing the latest news about the company on the given date
"""
stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)
return interface.get_reddit_company_news(ticker, curr_date, 7, 5)
return stock_news_results
@staticmethod
@tool
@ -128,9 +126,8 @@ class Toolkit:
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
"""
result_data = interface.get_YFin_data(symbol, start_date, end_date)
return interface.get_YFin_data(symbol, start_date, end_date)
return result_data
@staticmethod
@tool
@ -149,19 +146,18 @@ class Toolkit:
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
"""
result_data = interface.get_YFin_data_online(symbol, start_date, end_date)
return interface.get_YFin_data_online(symbol, start_date, end_date)
return result_data
@staticmethod
@tool
def get_stockstats_indicators_report(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[
str, "technical indicator to get the analysis and report of"
str, "technical indicator to get the analysis and report of",
],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
str, "The current trading date you are trading on, YYYY-mm-dd",
],
look_back_days: Annotated[int, "how many days to look back"] = 30,
) -> str:
@ -176,21 +172,20 @@ class Toolkit:
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
"""
result_stockstats = interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, False
return interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, False,
)
return result_stockstats
@staticmethod
@tool
def get_stockstats_indicators_report_online(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[
str, "technical indicator to get the analysis and report of"
str, "technical indicator to get the analysis and report of",
],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
str, "The current trading date you are trading on, YYYY-mm-dd",
],
look_back_days: Annotated[int, "how many days to look back"] = 30,
) -> str:
@ -205,11 +200,10 @@ class Toolkit:
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
"""
result_stockstats = interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, True
return interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, True,
)
return result_stockstats
@staticmethod
@tool
@ -229,11 +223,10 @@ class Toolkit:
str: a report of the sentiment in the past 30 days starting at curr_date
"""
data_sentiment = interface.get_finnhub_company_insider_sentiment(
ticker, curr_date, 30
return interface.get_finnhub_company_insider_sentiment(
ticker, curr_date, 30,
)
return data_sentiment
@staticmethod
@tool
@ -253,11 +246,10 @@ class Toolkit:
str: a report of the company's insider transactions/trading information in the past 30 days
"""
data_trans = interface.get_finnhub_company_insider_transactions(
ticker, curr_date, 30
return interface.get_finnhub_company_insider_transactions(
ticker, curr_date, 30,
)
return data_trans
@staticmethod
@tool
@ -279,9 +271,8 @@ class Toolkit:
str: a report of the company's most recent balance sheet
"""
data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date)
return interface.get_simfin_balance_sheet(ticker, freq, curr_date)
return data_balance_sheet
@staticmethod
@tool
@ -303,9 +294,8 @@ class Toolkit:
str: a report of the company's most recent cash flow statement
"""
data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date)
return interface.get_simfin_cashflow(ticker, freq, curr_date)
return data_cashflow
@staticmethod
@tool
@ -327,11 +317,10 @@ class Toolkit:
str: a report of the company's most recent income statement
"""
data_income_stmt = interface.get_simfin_income_statements(
ticker, freq, curr_date
return interface.get_simfin_income_statements(
ticker, freq, curr_date,
)
return data_income_stmt
@staticmethod
@tool
@ -349,9 +338,8 @@ class Toolkit:
str: A formatted string containing the latest news from Google News based on the query and date range.
"""
google_news_results = interface.get_google_news(query, curr_date, 7)
return interface.get_google_news(query, curr_date, 7)
return google_news_results
@staticmethod
@tool
@ -368,9 +356,8 @@ class Toolkit:
str: A formatted string containing the latest news about the company on the given date.
"""
openai_news_results = interface.get_stock_news_openai(ticker, curr_date)
return interface.get_stock_news_openai(ticker, curr_date)
return openai_news_results
@staticmethod
@tool
@ -385,9 +372,8 @@ class Toolkit:
str: A formatted string containing the latest macroeconomic news on the given date.
"""
openai_news_results = interface.get_global_news_openai(curr_date)
return interface.get_global_news_openai(curr_date)
return openai_news_results
@staticmethod
@tool
@ -404,8 +390,7 @@ class Toolkit:
str: A formatted string containing the latest fundamental information about the company on the given date.
"""
openai_fundamentals_results = interface.get_fundamentals_openai(
ticker, curr_date
return interface.get_fundamentals_openai(
ticker, curr_date,
)
return openai_fundamentals_results

View File

@ -59,7 +59,7 @@ class FinancialSituationMemory:
"matched_situation": results["documents"][0][i],
"recommendation": results["metadatas"][0][i]["recommendation"],
"similarity_score": 1 - results["distances"][0][i],
}
},
)
return matched_results
@ -94,18 +94,15 @@ if __name__ == "__main__":
# Example query
current_situation = """
Market showing increased volatility in tech sector, with institutional investors
Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations
"""
try:
recommendations = matcher.get_memories(current_situation, n_matches=2)
for i, rec in enumerate(recommendations, 1):
print(f"\nMatch {i}:")
print(f"Similarity Score: {rec['similarity_score']:.2f}")
print(f"Matched Situation: {rec['matched_situation']}")
print(f"Recommendation: {rec['recommendation']}")
for _i, _rec in enumerate(recommendations, 1):
pass
except Exception as e:
print(f"Error during recommendation: {str(e)}")
except Exception:
pass

View File

@ -5,6 +5,7 @@ Loads configuration from environment variables and .env file.
import os
from pathlib import Path
from dotenv import load_dotenv
# Load .env file from project root
@ -20,10 +21,10 @@ def get_config():
"project_dir": str(project_root / "tradingagents"),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": os.getenv(
"TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data"
"TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data",
),
"data_cache_dir": str(
project_root / "tradingagents" / "dataflows" / "data_cache"
project_root / "tradingagents" / "dataflows" / "data_cache",
),
# LLM settings
"llm_provider": os.getenv("LLM_PROVIDER", "openai"),
@ -47,16 +48,17 @@ def get_config():
# Validate required API keys based on provider
if config["llm_provider"] == "openai" and not config["openai_api_key"]:
raise ValueError("OPENAI_API_KEY is required when using OpenAI provider")
elif config["llm_provider"] == "anthropic" and not config["anthropic_api_key"]:
raise ValueError("ANTHROPIC_API_KEY is required when using Anthropic provider")
elif config["llm_provider"] == "google" and not config["google_api_key"]:
raise ValueError("GOOGLE_API_KEY is required when using Google provider")
msg = "OPENAI_API_KEY is required when using OpenAI provider"
raise ValueError(msg)
if config["llm_provider"] == "anthropic" and not config["anthropic_api_key"]:
msg = "ANTHROPIC_API_KEY is required when using Anthropic provider"
raise ValueError(msg)
if config["llm_provider"] == "google" and not config["google_api_key"]:
msg = "GOOGLE_API_KEY is required when using Google provider"
raise ValueError(msg)
if not config["finnhub_api_key"]:
print(
"Warning: FINNHUB_API_KEY not set. Some financial data features may be limited."
)
pass
return config

View File

@ -1,17 +1,13 @@
from .finnhub_utils import get_data_in_range
from .googlenews_utils import getNewsData
from .yfin_utils import YFinanceUtils
from .reddit_utils import fetch_top_from_category
from .stockstats_utils import StockstatsUtils
from .interface import (
# News and sentiment functions
get_finnhub_news,
get_finnhub_company_insider_sentiment,
get_finnhub_company_insider_transactions,
# News and sentiment functions
get_finnhub_news,
get_google_news,
get_reddit_global_news,
get_reddit_company_news,
get_reddit_global_news,
# Financial statements functions
get_simfin_balance_sheet,
get_simfin_cashflow,
@ -19,19 +15,25 @@ from .interface import (
# Technical analysis functions
get_stock_stats_indicators_window,
get_stockstats_indicator,
get_YFin_data,
# Market data functions
get_YFin_data_window,
get_YFin_data,
)
from .reddit_utils import fetch_top_from_category
from .stockstats_utils import StockstatsUtils
from .yfin_utils import YFinanceUtils
__all__ = [
# News and sentiment functions
"get_finnhub_news",
"get_YFin_data",
# Market data functions
"get_YFin_data_window",
"get_finnhub_company_insider_sentiment",
"get_finnhub_company_insider_transactions",
# News and sentiment functions
"get_finnhub_news",
"get_google_news",
"get_reddit_global_news",
"get_reddit_company_news",
"get_reddit_global_news",
# Financial statements functions
"get_simfin_balance_sheet",
"get_simfin_cashflow",
@ -39,7 +41,10 @@ __all__ = [
# Technical analysis functions
"get_stock_stats_indicators_window",
"get_stockstats_indicator",
# Market data functions
"get_YFin_data_window",
"get_YFin_data",
# Utilities and classes
"get_data_in_range",
"getNewsData",
"YFinanceUtils",
"fetch_top_from_category",
"StockstatsUtils",
]

View File

@ -1,9 +1,9 @@
import tradingagents.default_config as default_config
from typing import Dict, Optional
from tradingagents import default_config
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
DATA_DIR: Optional[str] = None
_config: dict | None = None
DATA_DIR: str | None = None
def initialize_config():
@ -14,7 +14,7 @@ def initialize_config():
DATA_DIR = _config["data_dir"]
def set_config(config: Dict):
def set_config(config: dict):
"""Update the configuration with custom values."""
global _config, DATA_DIR
if _config is None:
@ -23,7 +23,7 @@ def set_config(config: Dict):
DATA_DIR = _config["data_dir"]
def get_config() -> Dict:
def get_config() -> dict:
"""Get the current configuration."""
if _config is None:
initialize_config()

View File

@ -22,10 +22,10 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
)
else:
data_path = os.path.join(
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
)
data = open(data_path, "r")
data = open(data_path)
data = json.load(data)
# filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD)

View File

@ -1,20 +1,20 @@
import requests
from bs4 import BeautifulSoup
from datetime import datetime
import time
import random
import logging
import random
import time
from datetime import datetime
from urllib.parse import quote_plus
logger = logging.getLogger(__name__)
import requests
from bs4 import BeautifulSoup
from tenacity import (
retry,
retry_if_result,
stop_after_attempt,
wait_exponential,
retry_if_result,
)
logger = logging.getLogger(__name__)
def is_rate_limited(response):
"""Check if the response indicates we should back off (rate-limited or temporarily unavailable)."""
@ -35,8 +35,7 @@ def _add_jitter(retry_state):
def make_request(url, headers):
"""Make a request with retry logic for rate limiting"""
# The retry decorator already applies exponential backoff with jitter
response = requests.get(url, headers=headers, timeout=(5, 20))
return response
return requests.get(url, headers=headers, timeout=(5, 20))
def getNewsData(query, start_date, end_date):
@ -58,7 +57,7 @@ def getNewsData(query, start_date, end_date):
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/101.0.4951.54 Safari/537.36"
)
),
}
news_results = []
@ -103,7 +102,7 @@ def getNewsData(query, start_date, end_date):
"source": (
source_el.get_text(strip=True) if source_el else ""
),
}
},
)
except Exception as e:
logger.warning("Error processing result: %s", e)
@ -120,7 +119,7 @@ def getNewsData(query, start_date, end_date):
page += 1
except Exception as e:
logger.error("Failed after multiple retries: %s", e)
logger.exception("Failed after multiple retries: %s", e)
break
return news_results

View File

@ -1,17 +1,18 @@
from typing import Annotated
from .reddit_utils import fetch_top_from_category
from .yfin_utils import *
from .stockstats_utils import *
from .googlenews_utils import *
from .finnhub_utils import get_data_in_range
from dateutil.relativedelta import relativedelta
from datetime import datetime
import os
from datetime import datetime
from typing import Annotated
import pandas as pd
from tqdm import tqdm
import yfinance as yf
from dateutil.relativedelta import relativedelta
from openai import OpenAI
from .config import get_config, DATA_DIR
from tqdm import tqdm
from .config import DATA_DIR, get_config
from .finnhub_utils import get_data_in_range
from .googlenews_utils import getNewsData
from .reddit_utils import fetch_top_from_category
from .stockstats_utils import StockstatsUtils
def get_finnhub_news(
@ -84,7 +85,7 @@ def get_finnhub_company_insider_sentiment(
result_str = ""
seen_dicts = []
for date, senti_list in data.items():
for senti_list in data.values():
for entry in senti_list:
if entry not in seen_dicts:
result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n"
@ -126,7 +127,7 @@ def get_finnhub_company_insider_transactions(
result_str = ""
seen_dicts = []
for date, senti_list in data.items():
for senti_list in data.values():
for entry in senti_list:
if entry not in seen_dicts:
result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n"
@ -170,7 +171,6 @@ def get_simfin_balance_sheet(
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No balance sheet available before the given current date.")
return ""
# Get the most recent balance sheet by selecting the row with the latest Publish Date
@ -217,7 +217,6 @@ def get_simfin_cashflow(
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No cash flow statement available before the given current date.")
return ""
# Get the most recent cash flow statement by selecting the row with the latest Publish Date
@ -264,7 +263,6 @@ def get_simfin_income_statements(
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No income statement available before the given current date.")
return ""
# Get the most recent income statement by selecting the row with the latest Publish Date
@ -421,7 +419,7 @@ def get_stock_stats_indicators_window(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
str, "The current trading date you are trading on, YYYY-mm-dd",
],
look_back_days: Annotated[int, "how many days to look back"],
online: Annotated[bool, "to fetch data online or offline"],
@ -501,8 +499,9 @@ def get_stock_stats_indicators_window(
}
if indicator not in best_ind_params:
msg = f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}"
raise ValueError(
f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}"
msg,
)
end_date = curr_date
@ -515,7 +514,7 @@ def get_stock_stats_indicators_window(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
),
)
data["Date"] = pd.to_datetime(data["Date"], utc=True)
dates_in_df = data["Date"].astype(str).str[:10]
@ -525,7 +524,7 @@ def get_stock_stats_indicators_window(
# only do the trading dates
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
)
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
@ -536,28 +535,27 @@ def get_stock_stats_indicators_window(
ind_string = ""
while curr_date >= before:
indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
)
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
curr_date = curr_date - relativedelta(days=1)
result_str = (
return (
f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n"
+ ind_string
+ "\n\n"
+ best_ind_params.get(indicator, "No description available.")
)
return result_str
def get_stockstats_indicator(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
str, "The current trading date you are trading on, YYYY-mm-dd",
],
online: Annotated[bool, "to fetch data online or offline"],
) -> str:
@ -573,10 +571,7 @@ def get_stockstats_indicator(
os.path.join(DATA_DIR, "market_data", "price_data"),
online=online,
)
except Exception as e:
print(
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
)
except Exception:
return ""
return str(indicator_value)
@ -597,7 +592,7 @@ def get_YFin_data_window(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
),
)
# Extract just the date part for comparison
@ -613,7 +608,7 @@ def get_YFin_data_window(
# Set pandas display options to show the full DataFrame
with pd.option_context(
"display.max_rows", None, "display.max_columns", None, "display.width", None
"display.max_rows", None, "display.max_columns", None, "display.width", None,
):
df_string = filtered_data.to_string()
@ -675,12 +670,13 @@ def get_YFin_data(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
),
)
if end_date > "2025-03-25":
msg = f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
raise Exception(
f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
msg,
)
# Extract just the date part for comparison
@ -695,9 +691,8 @@ def get_YFin_data(
filtered_data = filtered_data.drop("DateOnly", axis=1)
# remove the index from the dataframe
filtered_data = filtered_data.reset_index(drop=True)
return filtered_data.reset_index(drop=True)
return filtered_data
def get_stock_news_openai(ticker, curr_date):
@ -713,9 +708,9 @@ def get_stock_news_openai(ticker, curr_date):
{
"type": "input_text",
"text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
}
},
],
}
},
],
text={"format": {"type": "text"}},
reasoning={},
@ -724,7 +719,7 @@ def get_stock_news_openai(ticker, curr_date):
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
},
],
temperature=1,
max_output_tokens=4096,
@ -748,9 +743,9 @@ def get_global_news_openai(curr_date):
{
"type": "input_text",
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
}
},
],
}
},
],
text={"format": {"type": "text"}},
reasoning={},
@ -759,7 +754,7 @@ def get_global_news_openai(curr_date):
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
},
],
temperature=1,
max_output_tokens=4096,
@ -783,9 +778,9 @@ def get_fundamentals_openai(ticker, curr_date):
{
"type": "input_text",
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
}
},
],
}
},
],
text={"format": {"type": "text"}},
reasoning={},
@ -794,7 +789,7 @@ def get_fundamentals_openai(ticker, curr_date):
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
},
],
temperature=1,
max_output_tokens=4096,

View File

@ -1,8 +1,8 @@
import json
from datetime import datetime
from typing import Annotated
import os
import re
from datetime import datetime
from typing import Annotated
ticker_to_company = {
"AAPL": "Apple",
@ -48,11 +48,11 @@ ticker_to_company = {
def fetch_top_from_category(
category: Annotated[
str, "Category to fetch top post from. Collection of subreddits."
str, "Category to fetch top post from. Collection of subreddits.",
],
date: Annotated[str, "Date to fetch top posts from."],
max_limit: Annotated[int, "Maximum number of posts to fetch."],
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
query: Annotated[str | None, "Optional query to search for in the subreddit."] = None,
data_path: Annotated[
str,
"Path to the data folder. Default is 'reddit_data'.",
@ -63,12 +63,13 @@ def fetch_top_from_category(
all_content = []
if max_limit < len(os.listdir(os.path.join(base_path, category))):
msg = "REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
raise ValueError(
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
msg,
)
limit_per_subreddit = max_limit // len(
os.listdir(os.path.join(base_path, category))
os.listdir(os.path.join(base_path, category)),
)
for data_file in os.listdir(os.path.join(base_path, category)):
@ -79,7 +80,7 @@ def fetch_top_from_category(
all_content_curr_subreddit = []
with open(os.path.join(base_path, category, data_file), "rb") as f:
for i, line in enumerate(f):
for _i, line in enumerate(f):
# skip empty lines
if not line.strip():
continue
@ -88,7 +89,7 @@ def fetch_top_from_category(
# select only lines that are from the date
post_date = datetime.utcfromtimestamp(
parsed_line["created_utc"]
parsed_line["created_utc"],
).strftime("%Y-%m-%d")
if post_date != date:
continue
@ -106,7 +107,7 @@ def fetch_top_from_category(
found = False
for term in search_terms:
if re.search(
term, parsed_line["title"], re.IGNORECASE
term, parsed_line["title"], re.IGNORECASE,
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
found = True
break

View File

@ -1,8 +1,10 @@
import os
from typing import Annotated
import pandas as pd
import yfinance as yf
from stockstats import wrap
from typing import Annotated
import os
from .config import get_config
@ -11,10 +13,10 @@ class StockstatsUtils:
def get_stock_stats(
symbol: Annotated[str, "ticker symbol for the company"],
indicator: Annotated[
str, "quantitative indicators based off of the stock data for the company"
str, "quantitative indicators based off of the stock data for the company",
],
curr_date: Annotated[
str, "curr date for retrieving stock price data, YYYY-mm-dd"
str, "curr date for retrieving stock price data, YYYY-mm-dd",
],
data_dir: Annotated[
str,
@ -34,11 +36,12 @@ class StockstatsUtils:
os.path.join(
data_dir,
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
),
)
df = wrap(data)
except FileNotFoundError:
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
msg = "Stockstats fail: Yahoo Finance data not fetched yet!"
raise Exception(msg)
else:
# Get today's date as YYYY-mm-dd to add to cache
today_date = pd.Timestamp.today()
@ -81,7 +84,5 @@ class StockstatsUtils:
matching_rows = df[df["Date"].str.startswith(curr_date)]
if not matching_rows.empty:
indicator_value = matching_rows[indicator].values[0]
return indicator_value
else:
return "N/A: Not a trading day (weekend or holiday)"
return matching_rows[indicator].values[0]
return "N/A: Not a trading day (weekend or holiday)"

View File

@ -1,14 +1,14 @@
import pandas as pd
from datetime import date, timedelta, datetime
from datetime import date, datetime, timedelta
from typing import Annotated
import pandas as pd
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
if save_path:
data.to_csv(save_path)
print(f"{tag} saved to {save_path}")
def get_current_date():
@ -32,7 +32,5 @@ def get_next_weekday(date):
if date.weekday() >= 5:
days_to_add = 7 - date.weekday()
next_weekday = date + timedelta(days=days_to_add)
return next_weekday
else:
return date
return date + timedelta(days=days_to_add)
return date

View File

@ -1,10 +1,12 @@
# gets data/stats
import yfinance as yf
from typing import Annotated, Callable, Any, Optional
from pandas import DataFrame
import pandas as pd
from collections.abc import Callable
from functools import wraps
from typing import Annotated, Any
import pandas as pd
import yfinance as yf
from pandas import DataFrame
from .utils import SavePathType, decorate_all_methods
@ -24,38 +26,36 @@ def init_ticker(func: Callable) -> Callable:
class YFinanceUtils:
def get_stock_data(
symbol: Annotated[str, "ticker symbol"],
self: Annotated[str, "ticker symbol"],
start_date: Annotated[
str, "start date for retrieving stock price data, YYYY-mm-dd"
str, "start date for retrieving stock price data, YYYY-mm-dd",
],
end_date: Annotated[
str, "end date for retrieving stock price data, YYYY-mm-dd"
str, "end date for retrieving stock price data, YYYY-mm-dd",
],
save_path: SavePathType = None,
) -> DataFrame:
"""retrieve stock price data for designated ticker symbol"""
ticker = symbol
ticker = self
# add one day to the end_date so that the data range is inclusive
end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
end_date = end_date.strftime("%Y-%m-%d")
stock_data = ticker.history(start=start_date, end=end_date)
return ticker.history(start=start_date, end=end_date)
# save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
return stock_data
def get_stock_info(
symbol: Annotated[str, "ticker symbol"],
self: Annotated[str, "ticker symbol"],
) -> dict:
"""Fetches and returns latest stock information."""
ticker = symbol
stock_info = ticker.info
return stock_info
ticker = self
return ticker.info
def get_company_info(
symbol: Annotated[str, "ticker symbol"],
save_path: Optional[str] = None,
self: Annotated[str, "ticker symbol"],
save_path: str | None = None,
) -> DataFrame:
"""Fetches and returns company information as a DataFrame."""
ticker = symbol
ticker = self
info = ticker.info
company_info = {
"Company Name": info.get("shortName", "N/A"),
@ -67,42 +67,37 @@ class YFinanceUtils:
company_info_df = DataFrame([company_info])
if save_path:
company_info_df.to_csv(save_path)
print(f"Company info for {ticker.ticker} saved to {save_path}")
return company_info_df
def get_stock_dividends(
symbol: Annotated[str, "ticker symbol"],
save_path: Optional[str] = None,
self: Annotated[str, "ticker symbol"],
save_path: str | None = None,
) -> DataFrame:
"""Fetches and returns the latest dividends data as a DataFrame."""
ticker = symbol
ticker = self
dividends = ticker.dividends
if save_path:
dividends.to_csv(save_path)
print(f"Dividends for {ticker.ticker} saved to {save_path}")
return dividends
def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
def get_income_stmt(self: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest income statement of the company as a DataFrame."""
ticker = symbol
income_stmt = ticker.financials
return income_stmt
ticker = self
return ticker.financials
def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
def get_balance_sheet(self: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest balance sheet of the company as a DataFrame."""
ticker = symbol
balance_sheet = ticker.balance_sheet
return balance_sheet
ticker = self
return ticker.balance_sheet
def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
def get_cash_flow(self: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest cash flow statement of the company as a DataFrame."""
ticker = symbol
cash_flow = ticker.cashflow
return cash_flow
ticker = self
return ticker.cashflow
def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple:
def get_analyst_recommendations(self: Annotated[str, "ticker symbol"]) -> tuple:
"""Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
ticker = symbol
ticker = self
recommendations = ticker.recommendations
if recommendations.empty:
return None, 0 # No recommendations available

View File

@ -1,17 +1,17 @@
# TradingAgents/graph/__init__.py
from .trading_graph import TradingAgentsGraph
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor
from .trading_graph import TradingAgentsGraph
__all__ = [
"TradingAgentsGraph",
"ConditionalLogic",
"GraphSetup",
"Propagator",
"Reflector",
"SignalProcessor",
"TradingAgentsGraph",
]

View File

@ -1,6 +1,7 @@
# TradingAgents/graph/propagation.py
from typing import Dict, Any
from typing import Any
from tradingagents.agents.utils.agent_states import (
InvestDebateState,
RiskDebateState,
@ -15,15 +16,15 @@ class Propagator:
self.max_recur_limit = max_recur_limit
def create_initial_state(
self, company_name: str, trade_date: str
) -> Dict[str, Any]:
self, company_name: str, trade_date: str,
) -> dict[str, Any]:
"""Create the initial state for the agent graph."""
return {
"messages": [("human", company_name)],
"company_of_interest": company_name,
"trade_date": str(trade_date),
"investment_debate_state": InvestDebateState(
{"history": "", "current_response": "", "count": 0}
{"history": "", "current_response": "", "count": 0},
),
"risk_debate_state": RiskDebateState(
{
@ -32,7 +33,7 @@ class Propagator:
"current_safe_response": "",
"current_neutral_response": "",
"count": 0,
}
},
),
"market_report": "",
"fundamentals_report": "",
@ -40,7 +41,7 @@ class Propagator:
"news_report": "",
}
def get_graph_args(self) -> Dict[str, Any]:
def get_graph_args(self) -> dict[str, Any]:
"""Get arguments for the graph invocation."""
return {
"stream_mode": "values",

View File

@ -1,6 +1,7 @@
# TradingAgents/graph/reflection.py
from typing import Dict, Any
from typing import Any
from langchain_openai import ChatOpenAI
@ -15,7 +16,7 @@ class Reflector:
def _get_reflection_prompt(self) -> str:
"""Get the system prompt for reflection."""
return """
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
1. Reasoning:
@ -25,7 +26,7 @@ Your goal is to deliver detailed insights into investment decisions and highligh
- Technical indicators.
- Technical signals.
- Price movement analysis.
- Overall market data analysis
- Overall market data analysis
- News analysis.
- Social media and sentiment analysis.
- Fundamental data analysis.
@ -46,7 +47,7 @@ Your goal is to deliver detailed insights into investment decisions and highligh
Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis.
"""
def _extract_current_situation(self, current_state: Dict[str, Any]) -> str:
def _extract_current_situation(self, current_state: dict[str, Any]) -> str:
"""Extract the current market situation from the state."""
curr_market_report = current_state["market_report"]
curr_sentiment_report = current_state["sentiment_report"]
@ -56,7 +57,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
def _reflect_on_component(
self, component_type: str, report: str, situation: str, returns_losses
self, component_type: str, report: str, situation: str, returns_losses,
) -> str:
"""Generate reflection for a component."""
messages = [
@ -67,8 +68,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
),
]
result = self.quick_thinking_llm.invoke(messages).content
return result
return self.quick_thinking_llm.invoke(messages).content
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
"""Reflect on bull researcher's analysis and update memory."""
@ -76,7 +76,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
result = self._reflect_on_component(
"BULL", bull_debate_history, situation, returns_losses
"BULL", bull_debate_history, situation, returns_losses,
)
bull_memory.add_situations([(situation, result)])
@ -86,7 +86,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
result = self._reflect_on_component(
"BEAR", bear_debate_history, situation, returns_losses
"BEAR", bear_debate_history, situation, returns_losses,
)
bear_memory.add_situations([(situation, result)])
@ -96,7 +96,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
trader_decision = current_state["trader_investment_plan"]
result = self._reflect_on_component(
"TRADER", trader_decision, situation, returns_losses
"TRADER", trader_decision, situation, returns_losses,
)
trader_memory.add_situations([(situation, result)])
@ -106,7 +106,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
judge_decision = current_state["investment_debate_state"]["judge_decision"]
result = self._reflect_on_component(
"INVEST JUDGE", judge_decision, situation, returns_losses
"INVEST JUDGE", judge_decision, situation, returns_losses,
)
invest_judge_memory.add_situations([(situation, result)])
@ -116,6 +116,6 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
judge_decision = current_state["risk_debate_state"]["judge_decision"]
result = self._reflect_on_component(
"RISK JUDGE", judge_decision, situation, returns_losses
"RISK JUDGE", judge_decision, situation, returns_losses,
)
risk_manager_memory.add_situations([(situation, result)])

View File

@ -1,26 +1,26 @@
# TradingAgents/graph/setup.py
from typing import Dict
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from tradingagents.agents import (
create_market_analyst,
create_social_media_analyst,
create_news_analyst,
create_fundamentals_analyst,
create_bull_researcher,
create_bear_researcher,
create_research_manager,
create_trader,
create_risky_debator,
create_neutral_debator,
create_safe_debator,
create_risk_manager,
create_msg_delete,
AgentState,
Toolkit,
create_bear_researcher,
create_bull_researcher,
create_fundamentals_analyst,
create_market_analyst,
create_msg_delete,
create_neutral_debator,
create_news_analyst,
create_research_manager,
create_risk_manager,
create_risky_debator,
create_safe_debator,
create_social_media_analyst,
create_trader,
)
from .conditional_logic import ConditionalLogic
@ -34,7 +34,7 @@ class GraphSetup:
quick_thinking_llm: ChatOpenAI,
deep_thinking_llm: ChatOpenAI,
toolkit: Toolkit,
tool_nodes: Dict[str, ToolNode],
tool_nodes: dict[str, ToolNode],
bull_memory,
bear_memory,
trader_memory,
@ -55,7 +55,7 @@ class GraphSetup:
self.conditional_logic = conditional_logic
def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
self, selected_analysts=None,
):
"""Set up and compile the agent workflow graph.
@ -66,8 +66,11 @@ class GraphSetup:
- "news": News analyst
- "fundamentals": Fundamentals analyst
"""
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals"]
if len(selected_analysts) == 0:
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
msg = "Trading Agents Graph Setup Error: no analysts selected!"
raise ValueError(msg)
# Create analyst nodes
analyst_nodes = {}
@ -76,41 +79,41 @@ class GraphSetup:
if "market" in selected_analysts:
analyst_nodes["market"] = create_market_analyst(
self.quick_thinking_llm, self.toolkit
self.quick_thinking_llm, self.toolkit,
)
delete_nodes["market"] = create_msg_delete()
tool_nodes["market"] = self.tool_nodes["market"]
if "social" in selected_analysts:
analyst_nodes["social"] = create_social_media_analyst(
self.quick_thinking_llm, self.toolkit
self.quick_thinking_llm, self.toolkit,
)
delete_nodes["social"] = create_msg_delete()
tool_nodes["social"] = self.tool_nodes["social"]
if "news" in selected_analysts:
analyst_nodes["news"] = create_news_analyst(
self.quick_thinking_llm, self.toolkit
self.quick_thinking_llm, self.toolkit,
)
delete_nodes["news"] = create_msg_delete()
tool_nodes["news"] = self.tool_nodes["news"]
if "fundamentals" in selected_analysts:
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
self.quick_thinking_llm, self.toolkit
self.quick_thinking_llm, self.toolkit,
)
delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
# Create researcher and manager nodes
bull_researcher_node = create_bull_researcher(
self.quick_thinking_llm, self.bull_memory
self.quick_thinking_llm, self.bull_memory,
)
bear_researcher_node = create_bear_researcher(
self.quick_thinking_llm, self.bear_memory
self.quick_thinking_llm, self.bear_memory,
)
research_manager_node = create_research_manager(
self.deep_thinking_llm, self.invest_judge_memory
self.deep_thinking_llm, self.invest_judge_memory,
)
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
@ -119,7 +122,7 @@ class GraphSetup:
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
safe_analyst = create_safe_debator(self.quick_thinking_llm)
risk_manager_node = create_risk_manager(
self.deep_thinking_llm, self.risk_manager_memory
self.deep_thinking_llm, self.risk_manager_memory,
)
# Create workflow
@ -129,7 +132,7 @@ class GraphSetup:
for analyst_type, node in analyst_nodes.items():
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
workflow.add_node(
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type],
)
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])

View File

@ -1,25 +1,24 @@
# TradingAgents/graph/trading_graph.py
import json
import os
from pathlib import Path
import json
from typing import Dict, Any
from typing import Any
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from tradingagents.agents import Toolkit
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.dataflows.interface import set_config
from tradingagents.default_config import DEFAULT_CONFIG
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor
@ -28,9 +27,9 @@ class TradingAgentsGraph:
def __init__(
self,
selected_analysts=["market", "social", "news", "fundamentals"],
selected_analysts=None,
debug=False,
config: Dict[str, Any] = None,
config: dict[str, Any] | None = None,
):
"""Initialize the trading agents graph and components.
@ -39,6 +38,8 @@ class TradingAgentsGraph:
debug: Whether to run in debug mode
config: Configuration dictionary. If None, uses default config
"""
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals"]
self.debug = debug
self.config = config or DEFAULT_CONFIG
@ -58,7 +59,7 @@ class TradingAgentsGraph:
or self.config["llm_provider"] == "openrouter"
):
self.deep_thinking_llm = ChatOpenAI(
model=self.config["deep_think_llm"], base_url=self.config["backend_url"]
model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
)
self.quick_thinking_llm = ChatOpenAI(
model=self.config["quick_think_llm"],
@ -66,7 +67,7 @@ class TradingAgentsGraph:
)
elif self.config["llm_provider"].lower() == "anthropic":
self.deep_thinking_llm = ChatAnthropic(
model=self.config["deep_think_llm"], base_url=self.config["backend_url"]
model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
)
self.quick_thinking_llm = ChatAnthropic(
model=self.config["quick_think_llm"],
@ -74,13 +75,14 @@ class TradingAgentsGraph:
)
elif self.config["llm_provider"].lower() == "google":
self.deep_thinking_llm = ChatGoogleGenerativeAI(
model=self.config["deep_think_llm"]
model=self.config["deep_think_llm"],
)
self.quick_thinking_llm = ChatGoogleGenerativeAI(
model=self.config["quick_think_llm"]
model=self.config["quick_think_llm"],
)
else:
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
msg = f"Unsupported LLM provider: {self.config['llm_provider']}"
raise ValueError(msg)
self.toolkit = Toolkit(config=self.config)
@ -89,10 +91,10 @@ class TradingAgentsGraph:
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory(
"invest_judge_memory", self.config
"invest_judge_memory", self.config,
)
self.risk_manager_memory = FinancialSituationMemory(
"risk_manager_memory", self.config
"risk_manager_memory", self.config,
)
# Create tool nodes
@ -125,7 +127,7 @@ class TradingAgentsGraph:
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
def _create_tool_nodes(self) -> dict[str, ToolNode]:
"""Create tool nodes for different data sources."""
return {
"market": ToolNode(
@ -136,7 +138,7 @@ class TradingAgentsGraph:
# offline tools
self.toolkit.get_YFin_data,
self.toolkit.get_stockstats_indicators_report,
]
],
),
"social": ToolNode(
[
@ -144,7 +146,7 @@ class TradingAgentsGraph:
self.toolkit.get_stock_news_openai,
# offline tools
self.toolkit.get_reddit_stock_info,
]
],
),
"news": ToolNode(
[
@ -154,7 +156,7 @@ class TradingAgentsGraph:
# offline tools
self.toolkit.get_finnhub_news,
self.toolkit.get_reddit_news,
]
],
),
"fundamentals": ToolNode(
[
@ -166,7 +168,7 @@ class TradingAgentsGraph:
self.toolkit.get_simfin_balance_sheet,
self.toolkit.get_simfin_cashflow,
self.toolkit.get_simfin_income_stmt,
]
],
),
}
@ -177,7 +179,7 @@ class TradingAgentsGraph:
# Initialize state
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date
company_name, trade_date,
)
args = self.propagator.get_graph_args()
@ -250,19 +252,19 @@ class TradingAgentsGraph:
def reflect_and_remember(self, returns_losses):
"""Reflect on decisions and update memory based on returns."""
self.reflector.reflect_bull_researcher(
self.curr_state, returns_losses, self.bull_memory
self.curr_state, returns_losses, self.bull_memory,
)
self.reflector.reflect_bear_researcher(
self.curr_state, returns_losses, self.bear_memory
self.curr_state, returns_losses, self.bear_memory,
)
self.reflector.reflect_trader(
self.curr_state, returns_losses, self.trader_memory
self.curr_state, returns_losses, self.trader_memory,
)
self.reflector.reflect_invest_judge(
self.curr_state, returns_losses, self.invest_judge_memory
self.curr_state, returns_losses, self.invest_judge_memory,
)
self.reflector.reflect_risk_manager(
self.curr_state, returns_losses, self.risk_manager_memory
self.curr_state, returns_losses, self.risk_manager_memory,
)
def process_signal(self, full_signal):