Fix remaining ruff linting errors
- Fixed all F-type errors (undefined names, unused imports) - Applied automatic fixes for code style issues - Ensured CI/CD pipeline passes all checks
This commit is contained in:
parent
4361ed19e4
commit
6f3981412b
251
cli/main.py
251
cli/main.py
|
|
@ -1,31 +1,32 @@
|
||||||
import datetime
|
import datetime
|
||||||
import typer
|
|
||||||
from pathlib import Path
|
|
||||||
from functools import wraps
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.spinner import Spinner
|
|
||||||
from rich.live import Live
|
|
||||||
from rich.columns import Columns
|
|
||||||
from rich.markdown import Markdown
|
|
||||||
from rich.layout import Layout
|
|
||||||
from rich.text import Text
|
|
||||||
from rich.table import Table
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from functools import wraps
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
from rich import box
|
from rich import box
|
||||||
from rich.align import Align
|
from rich.align import Align
|
||||||
|
from rich.columns import Columns
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.layout import Layout
|
||||||
|
from rich.live import Live
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.spinner import Spinner
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
|
||||||
from cli.utils import (
|
from cli.utils import (
|
||||||
get_ticker,
|
|
||||||
get_analysis_date,
|
get_analysis_date,
|
||||||
|
get_ticker,
|
||||||
select_analysts,
|
select_analysts,
|
||||||
select_research_depth,
|
|
||||||
select_shallow_thinking_agent,
|
|
||||||
select_deep_thinking_agent,
|
select_deep_thinking_agent,
|
||||||
select_llm_provider,
|
select_llm_provider,
|
||||||
|
select_research_depth,
|
||||||
|
select_shallow_thinking_agent,
|
||||||
)
|
)
|
||||||
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
@ -136,19 +137,20 @@ class MessageBuffer:
|
||||||
report_parts.append("## Analyst Team Reports")
|
report_parts.append("## Analyst Team Reports")
|
||||||
if self.report_sections["market_report"]:
|
if self.report_sections["market_report"]:
|
||||||
report_parts.append(
|
report_parts.append(
|
||||||
f"### Market Analysis\n{self.report_sections['market_report']}"
|
f"### Market Analysis\n{self.report_sections['market_report']}",
|
||||||
)
|
)
|
||||||
if self.report_sections["sentiment_report"]:
|
if self.report_sections["sentiment_report"]:
|
||||||
report_parts.append(
|
report_parts.append(
|
||||||
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
|
f"### Social Sentiment\n{self.report_sections['sentiment_report']}",
|
||||||
)
|
)
|
||||||
if self.report_sections["news_report"]:
|
if self.report_sections["news_report"]:
|
||||||
report_parts.append(
|
report_parts.append(
|
||||||
f"### News Analysis\n{self.report_sections['news_report']}"
|
f"### News Analysis\n{self.report_sections['news_report']}",
|
||||||
)
|
)
|
||||||
if self.report_sections["fundamentals_report"]:
|
if self.report_sections["fundamentals_report"]:
|
||||||
|
fundamentals = self.report_sections['fundamentals_report']
|
||||||
report_parts.append(
|
report_parts.append(
|
||||||
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
|
f"### Fundamentals Analysis\n{fundamentals}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Research Team Reports
|
# Research Team Reports
|
||||||
|
|
@ -180,10 +182,10 @@ def create_layout():
|
||||||
Layout(name="footer", size=3),
|
Layout(name="footer", size=3),
|
||||||
)
|
)
|
||||||
layout["main"].split_column(
|
layout["main"].split_column(
|
||||||
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5)
|
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5),
|
||||||
)
|
)
|
||||||
layout["upper"].split_row(
|
layout["upper"].split_row(
|
||||||
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3)
|
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3),
|
||||||
)
|
)
|
||||||
return layout
|
return layout
|
||||||
|
|
||||||
|
|
@ -198,7 +200,7 @@ def update_display(layout, spinner_text=None):
|
||||||
border_style="green",
|
border_style="green",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
expand=True,
|
expand=True,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Progress panel showing agent status
|
# Progress panel showing agent status
|
||||||
|
|
@ -235,7 +237,7 @@ def update_display(layout, spinner_text=None):
|
||||||
status = message_buffer.agent_status[first_agent]
|
status = message_buffer.agent_status[first_agent]
|
||||||
if status == "in_progress":
|
if status == "in_progress":
|
||||||
spinner = Spinner(
|
spinner = Spinner(
|
||||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
"dots", text="[blue]in_progress[/blue]", style="bold cyan",
|
||||||
)
|
)
|
||||||
status_cell = spinner
|
status_cell = spinner
|
||||||
else:
|
else:
|
||||||
|
|
@ -252,7 +254,7 @@ def update_display(layout, spinner_text=None):
|
||||||
status = message_buffer.agent_status[agent]
|
status = message_buffer.agent_status[agent]
|
||||||
if status == "in_progress":
|
if status == "in_progress":
|
||||||
spinner = Spinner(
|
spinner = Spinner(
|
||||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
"dots", text="[blue]in_progress[/blue]", style="bold cyan",
|
||||||
)
|
)
|
||||||
status_cell = spinner
|
status_cell = spinner
|
||||||
else:
|
else:
|
||||||
|
|
@ -268,7 +270,7 @@ def update_display(layout, spinner_text=None):
|
||||||
progress_table.add_row("─" * 20, "─" * 20, "─" * 20, style="dim")
|
progress_table.add_row("─" * 20, "─" * 20, "─" * 20, style="dim")
|
||||||
|
|
||||||
layout["progress"].update(
|
layout["progress"].update(
|
||||||
Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2))
|
Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2)),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Messages panel showing recent messages and tool calls
|
# Messages panel showing recent messages and tool calls
|
||||||
|
|
@ -284,7 +286,7 @@ def update_display(layout, spinner_text=None):
|
||||||
messages_table.add_column("Time", style="cyan", width=8, justify="center")
|
messages_table.add_column("Time", style="cyan", width=8, justify="center")
|
||||||
messages_table.add_column("Type", style="green", width=10, justify="center")
|
messages_table.add_column("Type", style="green", width=10, justify="center")
|
||||||
messages_table.add_column(
|
messages_table.add_column(
|
||||||
"Content", style="white", no_wrap=False, ratio=1
|
"Content", style="white", no_wrap=False, ratio=1,
|
||||||
) # Make content column expand
|
) # Make content column expand
|
||||||
|
|
||||||
# Combine tool calls and messages
|
# Combine tool calls and messages
|
||||||
|
|
@ -352,7 +354,7 @@ def update_display(layout, spinner_text=None):
|
||||||
title="Messages & Tools",
|
title="Messages & Tools",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Analysis panel showing current report
|
# Analysis panel showing current report
|
||||||
|
|
@ -363,7 +365,7 @@ def update_display(layout, spinner_text=None):
|
||||||
title="Current Report",
|
title="Current Report",
|
||||||
border_style="green",
|
border_style="green",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layout["analysis"].update(
|
layout["analysis"].update(
|
||||||
|
|
@ -372,7 +374,7 @@ def update_display(layout, spinner_text=None):
|
||||||
title="Current Report",
|
title="Current Report",
|
||||||
border_style="green",
|
border_style="green",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Footer with statistics
|
# Footer with statistics
|
||||||
|
|
@ -386,9 +388,12 @@ def update_display(layout, spinner_text=None):
|
||||||
|
|
||||||
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
|
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
|
||||||
stats_table.add_column("Stats", justify="center")
|
stats_table.add_column("Stats", justify="center")
|
||||||
stats_table.add_row(
|
stats_text = (
|
||||||
f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
|
f"Tool Calls: {tool_calls_count} | "
|
||||||
|
f"LLM Calls: {llm_calls_count} | "
|
||||||
|
f"Generated Reports: {reports_count}"
|
||||||
)
|
)
|
||||||
|
stats_table.add_row(stats_text)
|
||||||
|
|
||||||
layout["footer"].update(Panel(stats_table, border_style="grey50"))
|
layout["footer"].update(Panel(stats_table, border_style="grey50"))
|
||||||
|
|
||||||
|
|
@ -396,14 +401,20 @@ def update_display(layout, spinner_text=None):
|
||||||
def get_user_selections():
|
def get_user_selections():
|
||||||
"""Get all user selections before starting the analysis display."""
|
"""Get all user selections before starting the analysis display."""
|
||||||
# Display ASCII art welcome message
|
# Display ASCII art welcome message
|
||||||
with open("./cli/static/welcome.txt", "r") as f:
|
with open("./cli/static/welcome.txt") as f:
|
||||||
welcome_ascii = f.read()
|
welcome_ascii = f.read()
|
||||||
|
|
||||||
# Create welcome box content
|
# Create welcome box content
|
||||||
welcome_content = f"{welcome_ascii}\n"
|
welcome_content = f"{welcome_ascii}\n"
|
||||||
welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
|
welcome_content += (
|
||||||
|
"[bold green]TradingAgents: "
|
||||||
|
"Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
|
||||||
|
)
|
||||||
welcome_content += "[bold]Workflow Steps:[/bold]\n"
|
welcome_content += "[bold]Workflow Steps:[/bold]\n"
|
||||||
welcome_content += "I. Analyst Team → II. Research Team → III. Trader → IV. Risk Management → V. Portfolio Management\n\n"
|
welcome_content += (
|
||||||
|
"I. Analyst Team → II. Research Team → III. Trader → "
|
||||||
|
"IV. Risk Management → V. Portfolio Management\n\n"
|
||||||
|
)
|
||||||
welcome_content += (
|
welcome_content += (
|
||||||
"[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]"
|
"[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]"
|
||||||
)
|
)
|
||||||
|
|
@ -430,8 +441,8 @@ def get_user_selections():
|
||||||
# Step 1: Ticker symbol
|
# Step 1: Ticker symbol
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY"
|
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
selected_ticker = get_ticker()
|
selected_ticker = get_ticker()
|
||||||
|
|
||||||
|
|
@ -442,40 +453,40 @@ def get_user_selections():
|
||||||
"Step 2: Analysis Date",
|
"Step 2: Analysis Date",
|
||||||
"Enter the analysis date (YYYY-MM-DD)",
|
"Enter the analysis date (YYYY-MM-DD)",
|
||||||
default_date,
|
default_date,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
analysis_date = get_analysis_date()
|
analysis_date = get_analysis_date()
|
||||||
|
|
||||||
# Step 3: Select analysts
|
# Step 3: Select analysts
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis"
|
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
selected_analysts = select_analysts()
|
selected_analysts = select_analysts()
|
||||||
console.print(
|
console.print(
|
||||||
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
|
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 4: Research depth
|
# Step 4: Research depth
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 4: Research Depth", "Select your research depth level"
|
"Step 4: Research Depth", "Select your research depth level",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
selected_research_depth = select_research_depth()
|
selected_research_depth = select_research_depth()
|
||||||
|
|
||||||
# Step 5: OpenAI backend
|
# Step 5: OpenAI backend
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box("Step 5: OpenAI backend", "Select which service to talk to")
|
create_question_box("Step 5: OpenAI backend", "Select which service to talk to"),
|
||||||
)
|
)
|
||||||
selected_llm_provider, backend_url = select_llm_provider()
|
selected_llm_provider, backend_url = select_llm_provider()
|
||||||
|
|
||||||
# Step 6: Thinking agents
|
# Step 6: Thinking agents
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 6: Thinking Agents", "Select your thinking agents for analysis"
|
"Step 6: Thinking Agents", "Select your thinking agents for analysis",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
||||||
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
||||||
|
|
@ -492,28 +503,7 @@ def get_user_selections():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_ticker():
|
# Functions get_ticker and get_analysis_date are imported from cli.utils
|
||||||
"""Get ticker symbol from user input."""
|
|
||||||
return typer.prompt("", default="SPY")
|
|
||||||
|
|
||||||
|
|
||||||
def get_analysis_date():
|
|
||||||
"""Get the analysis date from user input."""
|
|
||||||
while True:
|
|
||||||
date_str = typer.prompt(
|
|
||||||
"", default=datetime.datetime.now().strftime("%Y-%m-%d")
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
# Validate date format and ensure it's not in the future
|
|
||||||
analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d")
|
|
||||||
if analysis_date.date() > datetime.datetime.now().date():
|
|
||||||
console.print("[red]Error: Analysis date cannot be in the future[/red]")
|
|
||||||
continue
|
|
||||||
return date_str
|
|
||||||
except ValueError:
|
|
||||||
console.print(
|
|
||||||
"[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def display_complete_report(final_state):
|
def display_complete_report(final_state):
|
||||||
|
|
@ -531,7 +521,7 @@ def display_complete_report(final_state):
|
||||||
title="Market Analyst",
|
title="Market Analyst",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Social Analyst Report
|
# Social Analyst Report
|
||||||
|
|
@ -542,7 +532,7 @@ def display_complete_report(final_state):
|
||||||
title="Social Analyst",
|
title="Social Analyst",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# News Analyst Report
|
# News Analyst Report
|
||||||
|
|
@ -553,7 +543,7 @@ def display_complete_report(final_state):
|
||||||
title="News Analyst",
|
title="News Analyst",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fundamentals Analyst Report
|
# Fundamentals Analyst Report
|
||||||
|
|
@ -564,7 +554,7 @@ def display_complete_report(final_state):
|
||||||
title="Fundamentals Analyst",
|
title="Fundamentals Analyst",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if analyst_reports:
|
if analyst_reports:
|
||||||
|
|
@ -574,7 +564,7 @@ def display_complete_report(final_state):
|
||||||
title="I. Analyst Team Reports",
|
title="I. Analyst Team Reports",
|
||||||
border_style="cyan",
|
border_style="cyan",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# II. Research Team Reports
|
# II. Research Team Reports
|
||||||
|
|
@ -590,7 +580,7 @@ def display_complete_report(final_state):
|
||||||
title="Bull Researcher",
|
title="Bull Researcher",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bear Researcher Analysis
|
# Bear Researcher Analysis
|
||||||
|
|
@ -601,7 +591,7 @@ def display_complete_report(final_state):
|
||||||
title="Bear Researcher",
|
title="Bear Researcher",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Research Manager Decision
|
# Research Manager Decision
|
||||||
|
|
@ -612,7 +602,7 @@ def display_complete_report(final_state):
|
||||||
title="Research Manager",
|
title="Research Manager",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if research_reports:
|
if research_reports:
|
||||||
|
|
@ -622,7 +612,7 @@ def display_complete_report(final_state):
|
||||||
title="II. Research Team Decision",
|
title="II. Research Team Decision",
|
||||||
border_style="magenta",
|
border_style="magenta",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# III. Trading Team Reports
|
# III. Trading Team Reports
|
||||||
|
|
@ -638,7 +628,7 @@ def display_complete_report(final_state):
|
||||||
title="III. Trading Team Plan",
|
title="III. Trading Team Plan",
|
||||||
border_style="yellow",
|
border_style="yellow",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# IV. Risk Management Team Reports
|
# IV. Risk Management Team Reports
|
||||||
|
|
@ -654,7 +644,7 @@ def display_complete_report(final_state):
|
||||||
title="Aggressive Analyst",
|
title="Aggressive Analyst",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Conservative (Safe) Analyst Analysis
|
# Conservative (Safe) Analyst Analysis
|
||||||
|
|
@ -665,7 +655,7 @@ def display_complete_report(final_state):
|
||||||
title="Conservative Analyst",
|
title="Conservative Analyst",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Neutral Analyst Analysis
|
# Neutral Analyst Analysis
|
||||||
|
|
@ -676,7 +666,7 @@ def display_complete_report(final_state):
|
||||||
title="Neutral Analyst",
|
title="Neutral Analyst",
|
||||||
border_style="blue",
|
border_style="blue",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if risk_reports:
|
if risk_reports:
|
||||||
|
|
@ -686,7 +676,7 @@ def display_complete_report(final_state):
|
||||||
title="IV. Risk Management Team Decision",
|
title="IV. Risk Management Team Decision",
|
||||||
border_style="red",
|
border_style="red",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# V. Portfolio Manager Decision
|
# V. Portfolio Manager Decision
|
||||||
|
|
@ -702,7 +692,7 @@ def display_complete_report(final_state):
|
||||||
title="V. Portfolio Manager Decision",
|
title="V. Portfolio Manager Decision",
|
||||||
border_style="green",
|
border_style="green",
|
||||||
padding=(1, 2),
|
padding=(1, 2),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -717,7 +707,7 @@ def extract_content_string(content):
|
||||||
"""Extract string content from various message formats."""
|
"""Extract string content from various message formats."""
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
return content
|
return content
|
||||||
elif isinstance(content, list):
|
if isinstance(content, list):
|
||||||
# Handle Anthropic's list format
|
# Handle Anthropic's list format
|
||||||
text_parts = []
|
text_parts = []
|
||||||
for item in content:
|
for item in content:
|
||||||
|
|
@ -729,8 +719,7 @@ def extract_content_string(content):
|
||||||
else:
|
else:
|
||||||
text_parts.append(str(item))
|
text_parts.append(str(item))
|
||||||
return " ".join(text_parts)
|
return " ".join(text_parts)
|
||||||
else:
|
return str(content)
|
||||||
return str(content)
|
|
||||||
|
|
||||||
|
|
||||||
def run_analysis():
|
def run_analysis():
|
||||||
|
|
@ -748,7 +737,7 @@ def run_analysis():
|
||||||
|
|
||||||
# Initialize the graph
|
# Initialize the graph
|
||||||
graph = TradingAgentsGraph(
|
graph = TradingAgentsGraph(
|
||||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
[analyst.value for analyst in selections["analysts"]], config=config, debug=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create result directory
|
# Create result directory
|
||||||
|
|
@ -807,23 +796,23 @@ def run_analysis():
|
||||||
|
|
||||||
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
|
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
|
||||||
message_buffer.add_tool_call = save_tool_call_decorator(
|
message_buffer.add_tool_call = save_tool_call_decorator(
|
||||||
message_buffer, "add_tool_call"
|
message_buffer, "add_tool_call",
|
||||||
)
|
)
|
||||||
message_buffer.update_report_section = save_report_section_decorator(
|
message_buffer.update_report_section = save_report_section_decorator(
|
||||||
message_buffer, "update_report_section"
|
message_buffer, "update_report_section",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now start the display layout
|
# Now start the display layout
|
||||||
layout = create_layout()
|
layout = create_layout()
|
||||||
|
|
||||||
with Live(layout, refresh_per_second=4) as live:
|
with Live(layout, refresh_per_second=4):
|
||||||
# Initial display
|
# Initial display
|
||||||
update_display(layout)
|
update_display(layout)
|
||||||
|
|
||||||
# Add initial messages
|
# Add initial messages
|
||||||
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"System", f"Analysis date: {selections['analysis_date']}"
|
"System", f"Analysis date: {selections['analysis_date']}",
|
||||||
)
|
)
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"System",
|
"System",
|
||||||
|
|
@ -854,7 +843,7 @@ def run_analysis():
|
||||||
|
|
||||||
# Initialize state and get graph args
|
# Initialize state and get graph args
|
||||||
init_agent_state = graph.propagator.create_initial_state(
|
init_agent_state = graph.propagator.create_initial_state(
|
||||||
selections["ticker"], selections["analysis_date"]
|
selections["ticker"], selections["analysis_date"],
|
||||||
)
|
)
|
||||||
args = graph.propagator.get_graph_args()
|
args = graph.propagator.get_graph_args()
|
||||||
|
|
||||||
|
|
@ -868,7 +857,7 @@ def run_analysis():
|
||||||
# Extract message content and type
|
# Extract message content and type
|
||||||
if hasattr(last_message, "content"):
|
if hasattr(last_message, "content"):
|
||||||
content = extract_content_string(
|
content = extract_content_string(
|
||||||
last_message.content
|
last_message.content,
|
||||||
) # Use the helper function
|
) # Use the helper function
|
||||||
msg_type = "Reasoning"
|
msg_type = "Reasoning"
|
||||||
else:
|
else:
|
||||||
|
|
@ -884,65 +873,64 @@ def run_analysis():
|
||||||
# Handle both dictionary and object tool calls
|
# Handle both dictionary and object tool calls
|
||||||
if isinstance(tool_call, dict):
|
if isinstance(tool_call, dict):
|
||||||
message_buffer.add_tool_call(
|
message_buffer.add_tool_call(
|
||||||
tool_call["name"], tool_call["args"]
|
tool_call["name"], tool_call["args"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||||
|
|
||||||
# Update reports and agent status based on chunk content
|
# Update reports and agent status based on chunk content
|
||||||
# Analyst Team Reports
|
# Analyst Team Reports
|
||||||
if "market_report" in chunk and chunk["market_report"]:
|
if chunk.get("market_report"):
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"market_report", chunk["market_report"]
|
"market_report", chunk["market_report"],
|
||||||
)
|
)
|
||||||
message_buffer.update_agent_status("Market Analyst", "completed")
|
message_buffer.update_agent_status("Market Analyst", "completed")
|
||||||
# Set next analyst to in_progress
|
# Set next analyst to in_progress
|
||||||
if "social" in selections["analysts"]:
|
if "social" in selections["analysts"]:
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Social Analyst", "in_progress"
|
"Social Analyst", "in_progress",
|
||||||
)
|
)
|
||||||
|
|
||||||
if "sentiment_report" in chunk and chunk["sentiment_report"]:
|
if chunk.get("sentiment_report"):
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"sentiment_report", chunk["sentiment_report"]
|
"sentiment_report", chunk["sentiment_report"],
|
||||||
)
|
)
|
||||||
message_buffer.update_agent_status("Social Analyst", "completed")
|
message_buffer.update_agent_status("Social Analyst", "completed")
|
||||||
# Set next analyst to in_progress
|
# Set next analyst to in_progress
|
||||||
if "news" in selections["analysts"]:
|
if "news" in selections["analysts"]:
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"News Analyst", "in_progress"
|
"News Analyst", "in_progress",
|
||||||
)
|
)
|
||||||
|
|
||||||
if "news_report" in chunk and chunk["news_report"]:
|
if chunk.get("news_report"):
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"news_report", chunk["news_report"]
|
"news_report", chunk["news_report"],
|
||||||
)
|
)
|
||||||
message_buffer.update_agent_status("News Analyst", "completed")
|
message_buffer.update_agent_status("News Analyst", "completed")
|
||||||
# Set next analyst to in_progress
|
# Set next analyst to in_progress
|
||||||
if "fundamentals" in selections["analysts"]:
|
if "fundamentals" in selections["analysts"]:
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Fundamentals Analyst", "in_progress"
|
"Fundamentals Analyst", "in_progress",
|
||||||
)
|
)
|
||||||
|
|
||||||
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
|
if chunk.get("fundamentals_report"):
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"fundamentals_report", chunk["fundamentals_report"]
|
"fundamentals_report", chunk["fundamentals_report"],
|
||||||
)
|
)
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Fundamentals Analyst", "completed"
|
"Fundamentals Analyst", "completed",
|
||||||
)
|
)
|
||||||
# Set all research team members to in_progress
|
# Set all research team members to in_progress
|
||||||
update_research_team_status("in_progress")
|
update_research_team_status("in_progress")
|
||||||
|
|
||||||
# Research Team - Handle Investment Debate State
|
# Research Team - Handle Investment Debate State
|
||||||
if (
|
if (
|
||||||
"investment_debate_state" in chunk
|
chunk.get("investment_debate_state")
|
||||||
and chunk["investment_debate_state"]
|
|
||||||
):
|
):
|
||||||
debate_state = chunk["investment_debate_state"]
|
debate_state = chunk["investment_debate_state"]
|
||||||
|
|
||||||
# Update Bull Researcher status and report
|
# Update Bull Researcher status and report
|
||||||
if "bull_history" in debate_state and debate_state["bull_history"]:
|
if debate_state.get("bull_history"):
|
||||||
# Keep all research team members in progress
|
# Keep all research team members in progress
|
||||||
update_research_team_status("in_progress")
|
update_research_team_status("in_progress")
|
||||||
# Extract latest bull response
|
# Extract latest bull response
|
||||||
|
|
@ -957,7 +945,7 @@ def run_analysis():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update Bear Researcher status and report
|
# Update Bear Researcher status and report
|
||||||
if "bear_history" in debate_state and debate_state["bear_history"]:
|
if debate_state.get("bear_history"):
|
||||||
# Keep all research team members in progress
|
# Keep all research team members in progress
|
||||||
update_research_team_status("in_progress")
|
update_research_team_status("in_progress")
|
||||||
# Extract latest bear response
|
# Extract latest bear response
|
||||||
|
|
@ -973,8 +961,7 @@ def run_analysis():
|
||||||
|
|
||||||
# Update Research Manager status and final decision
|
# Update Research Manager status and final decision
|
||||||
if (
|
if (
|
||||||
"judge_decision" in debate_state
|
debate_state.get("judge_decision")
|
||||||
and debate_state["judge_decision"]
|
|
||||||
):
|
):
|
||||||
# Keep all research team members in progress until final decision
|
# Keep all research team members in progress until final decision
|
||||||
update_research_team_status("in_progress")
|
update_research_team_status("in_progress")
|
||||||
|
|
@ -991,31 +978,29 @@ def run_analysis():
|
||||||
update_research_team_status("completed")
|
update_research_team_status("completed")
|
||||||
# Set first risk analyst to in_progress
|
# Set first risk analyst to in_progress
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Risky Analyst", "in_progress"
|
"Risky Analyst", "in_progress",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trading Team
|
# Trading Team
|
||||||
if (
|
if (
|
||||||
"trader_investment_plan" in chunk
|
chunk.get("trader_investment_plan")
|
||||||
and chunk["trader_investment_plan"]
|
|
||||||
):
|
):
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"trader_investment_plan", chunk["trader_investment_plan"]
|
"trader_investment_plan", chunk["trader_investment_plan"],
|
||||||
)
|
)
|
||||||
# Set first risk analyst to in_progress
|
# Set first risk analyst to in_progress
|
||||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||||
|
|
||||||
# Risk Management Team - Handle Risk Debate State
|
# Risk Management Team - Handle Risk Debate State
|
||||||
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
|
if chunk.get("risk_debate_state"):
|
||||||
risk_state = chunk["risk_debate_state"]
|
risk_state = chunk["risk_debate_state"]
|
||||||
|
|
||||||
# Update Risky Analyst status and report
|
# Update Risky Analyst status and report
|
||||||
if (
|
if (
|
||||||
"current_risky_response" in risk_state
|
risk_state.get("current_risky_response")
|
||||||
and risk_state["current_risky_response"]
|
|
||||||
):
|
):
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Risky Analyst", "in_progress"
|
"Risky Analyst", "in_progress",
|
||||||
)
|
)
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"Reasoning",
|
"Reasoning",
|
||||||
|
|
@ -1029,11 +1014,10 @@ def run_analysis():
|
||||||
|
|
||||||
# Update Safe Analyst status and report
|
# Update Safe Analyst status and report
|
||||||
if (
|
if (
|
||||||
"current_safe_response" in risk_state
|
risk_state.get("current_safe_response")
|
||||||
and risk_state["current_safe_response"]
|
|
||||||
):
|
):
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Safe Analyst", "in_progress"
|
"Safe Analyst", "in_progress",
|
||||||
)
|
)
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"Reasoning",
|
"Reasoning",
|
||||||
|
|
@ -1047,11 +1031,10 @@ def run_analysis():
|
||||||
|
|
||||||
# Update Neutral Analyst status and report
|
# Update Neutral Analyst status and report
|
||||||
if (
|
if (
|
||||||
"current_neutral_response" in risk_state
|
risk_state.get("current_neutral_response")
|
||||||
and risk_state["current_neutral_response"]
|
|
||||||
):
|
):
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Neutral Analyst", "in_progress"
|
"Neutral Analyst", "in_progress",
|
||||||
)
|
)
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"Reasoning",
|
"Reasoning",
|
||||||
|
|
@ -1064,9 +1047,9 @@ def run_analysis():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update Portfolio Manager status and final decision
|
# Update Portfolio Manager status and final decision
|
||||||
if "judge_decision" in risk_state and risk_state["judge_decision"]:
|
if risk_state.get("judge_decision"):
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Portfolio Manager", "in_progress"
|
"Portfolio Manager", "in_progress",
|
||||||
)
|
)
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"Reasoning",
|
"Reasoning",
|
||||||
|
|
@ -1081,10 +1064,10 @@ def run_analysis():
|
||||||
message_buffer.update_agent_status("Risky Analyst", "completed")
|
message_buffer.update_agent_status("Risky Analyst", "completed")
|
||||||
message_buffer.update_agent_status("Safe Analyst", "completed")
|
message_buffer.update_agent_status("Safe Analyst", "completed")
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Neutral Analyst", "completed"
|
"Neutral Analyst", "completed",
|
||||||
)
|
)
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Portfolio Manager", "completed"
|
"Portfolio Manager", "completed",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the display
|
# Update the display
|
||||||
|
|
@ -1094,18 +1077,18 @@ def run_analysis():
|
||||||
|
|
||||||
# Get final state and decision
|
# Get final state and decision
|
||||||
final_state = trace[-1]
|
final_state = trace[-1]
|
||||||
decision = graph.process_signal(final_state["final_trade_decision"])
|
graph.process_signal(final_state["final_trade_decision"])
|
||||||
|
|
||||||
# Update all agent statuses to completed
|
# Update all agent statuses to completed
|
||||||
for agent in message_buffer.agent_status:
|
for agent in message_buffer.agent_status:
|
||||||
message_buffer.update_agent_status(agent, "completed")
|
message_buffer.update_agent_status(agent, "completed")
|
||||||
|
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"Analysis", f"Completed analysis for {selections['analysis_date']}"
|
"Analysis", f"Completed analysis for {selections['analysis_date']}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update final report sections
|
# Update final report sections
|
||||||
for section in message_buffer.report_sections.keys():
|
for section in message_buffer.report_sections:
|
||||||
if section in final_state:
|
if section in final_state:
|
||||||
message_buffer.update_report_section(section, final_state[section])
|
message_buffer.update_report_section(section, final_state[section])
|
||||||
|
|
||||||
|
|
|
||||||
37
cli/utils.py
37
cli/utils.py
|
|
@ -1,5 +1,7 @@
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
import questionary
|
import questionary
|
||||||
from typing import List
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
|
|
@ -23,13 +25,13 @@ def get_ticker() -> str:
|
||||||
[
|
[
|
||||||
("text", "fg:green"),
|
("text", "fg:green"),
|
||||||
("highlighted", "noinherit"),
|
("highlighted", "noinherit"),
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if not ticker:
|
if not ticker:
|
||||||
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
|
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
|
||||||
exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
return ticker.strip().upper()
|
return ticker.strip().upper()
|
||||||
|
|
||||||
|
|
@ -56,18 +58,18 @@ def get_analysis_date() -> str:
|
||||||
[
|
[
|
||||||
("text", "fg:green"),
|
("text", "fg:green"),
|
||||||
("highlighted", "noinherit"),
|
("highlighted", "noinherit"),
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if not date:
|
if not date:
|
||||||
console.print("\n[red]No date provided. Exiting...[/red]")
|
console.print("\n[red]No date provided. Exiting...[/red]")
|
||||||
exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
return date.strip()
|
return date.strip()
|
||||||
|
|
||||||
|
|
||||||
def select_analysts() -> List[AnalystType]:
|
def select_analysts() -> list[AnalystType]:
|
||||||
"""Select analysts using an interactive checkbox."""
|
"""Select analysts using an interactive checkbox."""
|
||||||
choices = questionary.checkbox(
|
choices = questionary.checkbox(
|
||||||
"Select Your [Analysts Team]:",
|
"Select Your [Analysts Team]:",
|
||||||
|
|
@ -82,13 +84,13 @@ def select_analysts() -> List[AnalystType]:
|
||||||
("selected", "fg:green noinherit"),
|
("selected", "fg:green noinherit"),
|
||||||
("highlighted", "noinherit"),
|
("highlighted", "noinherit"),
|
||||||
("pointer", "noinherit"),
|
("pointer", "noinherit"),
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if not choices:
|
if not choices:
|
||||||
console.print("\n[red]No analysts selected. Exiting...[/red]")
|
console.print("\n[red]No analysts selected. Exiting...[/red]")
|
||||||
exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
return choices
|
return choices
|
||||||
|
|
||||||
|
|
@ -114,13 +116,13 @@ def select_research_depth() -> int:
|
||||||
("selected", "fg:yellow noinherit"),
|
("selected", "fg:yellow noinherit"),
|
||||||
("highlighted", "fg:yellow noinherit"),
|
("highlighted", "fg:yellow noinherit"),
|
||||||
("pointer", "fg:yellow noinherit"),
|
("pointer", "fg:yellow noinherit"),
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print("\n[red]No research depth selected. Exiting...[/red]")
|
console.print("\n[red]No research depth selected. Exiting...[/red]")
|
||||||
exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
@ -200,15 +202,15 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||||
("selected", "fg:magenta noinherit"),
|
("selected", "fg:magenta noinherit"),
|
||||||
("highlighted", "fg:magenta noinherit"),
|
("highlighted", "fg:magenta noinherit"),
|
||||||
("pointer", "fg:magenta noinherit"),
|
("pointer", "fg:magenta noinherit"),
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print(
|
console.print(
|
||||||
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
|
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]",
|
||||||
)
|
)
|
||||||
exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
@ -292,13 +294,13 @@ def select_deep_thinking_agent(provider) -> str:
|
||||||
("selected", "fg:magenta noinherit"),
|
("selected", "fg:magenta noinherit"),
|
||||||
("highlighted", "fg:magenta noinherit"),
|
("highlighted", "fg:magenta noinherit"),
|
||||||
("pointer", "fg:magenta noinherit"),
|
("pointer", "fg:magenta noinherit"),
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
||||||
exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
@ -326,15 +328,14 @@ def select_llm_provider() -> tuple[str, str]:
|
||||||
("selected", "fg:magenta noinherit"),
|
("selected", "fg:magenta noinherit"),
|
||||||
("highlighted", "fg:magenta noinherit"),
|
("highlighted", "fg:magenta noinherit"),
|
||||||
("pointer", "fg:magenta noinherit"),
|
("pointer", "fg:magenta noinherit"),
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
||||||
exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
display_name, url = choice
|
display_name, url = choice
|
||||||
print(f"You selected: {display_name}\tURL: {url}")
|
|
||||||
|
|
||||||
return display_name, url
|
return display_name, url
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,7 @@
|
||||||
import os,sys # Multiple imports on one line (Ruff will fix)
|
|
||||||
from typing import List,Dict # Multiple imports on one line
|
|
||||||
|
|
||||||
def poorly_formatted_function(x,y,z): # Missing type hints
|
def poorly_formatted_function(x,y,z): # Missing type hints
|
||||||
"""This function has formatting issues."""
|
"""This function has formatting issues."""
|
||||||
result=x+y*z # Missing spaces around operators
|
result=x+y*z # Missing spaces around operators
|
||||||
unused_variable = 42 # Unused variable (Ruff will detect)
|
|
||||||
if result>100: # Missing spaces
|
if result>100: # Missing spaces
|
||||||
print( "Result is large" ) # Extra spaces in parentheses
|
print( "Result is large" ) # Extra spaces in parentheses
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -23,11 +23,11 @@ def run_command(cmd, description=""):
|
||||||
print(f"✅ {description or 'Command completed successfully'}")
|
print(f"✅ {description or 'Command completed successfully'}")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
print(f"❌ Command failed:")
|
print("❌ Command failed:")
|
||||||
print(result.stderr)
|
print(result.stderr)
|
||||||
return False
|
return False
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
print(f"⏱️ Command timed out")
|
print("⏱️ Command timed out")
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Error running command: {e}")
|
print(f"❌ Error running command: {e}")
|
||||||
|
|
@ -73,7 +73,7 @@ def main():
|
||||||
|
|
||||||
# Summary
|
# Summary
|
||||||
print("\n" + "=" * 50)
|
print("\n" + "=" * 50)
|
||||||
print(f"📊 Test Setup Verification Results:")
|
print("📊 Test Setup Verification Results:")
|
||||||
print(f"✅ Successful: {success_count}/{total_tests}")
|
print(f"✅ Successful: {success_count}/{total_tests}")
|
||||||
print(f"❌ Failed: {total_tests - success_count}/{total_tests}")
|
print(f"❌ Failed: {total_tests - success_count}/{total_tests}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
"""Pytest configuration and shared fixtures for TradingAgents tests."""
|
"""Pytest configuration and shared fixtures for TradingAgents tests."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import pytest
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import Mock, MagicMock
|
from unittest.mock import Mock
|
||||||
from datetime import date, datetime
|
|
||||||
from typing import Dict, Any
|
import pytest
|
||||||
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
@ -22,7 +21,7 @@ def sample_config():
|
||||||
"deep_think_llm": "gpt-4o-mini",
|
"deep_think_llm": "gpt-4o-mini",
|
||||||
"quick_think_llm": "gpt-4o-mini",
|
"quick_think_llm": "gpt-4o-mini",
|
||||||
"project_dir": "/tmp/test_tradingagents",
|
"project_dir": "/tmp/test_tradingagents",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
@ -174,7 +173,7 @@ def mock_memory():
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
"""Configure pytest with custom markers."""
|
"""Configure pytest with custom markers."""
|
||||||
config.addinivalue_line(
|
config.addinivalue_line(
|
||||||
"markers", "integration: mark test as integration test (slow)"
|
"markers", "integration: mark test as integration test (slow)",
|
||||||
)
|
)
|
||||||
config.addinivalue_line("markers", "unit: mark test as unit test (fast)")
|
config.addinivalue_line("markers", "unit: mark test as unit test (fast)")
|
||||||
config.addinivalue_line("markers", "api: mark test as requiring API access")
|
config.addinivalue_line("markers", "api: mark test as requiring API access")
|
||||||
|
|
|
||||||
|
|
@ -2,14 +2,14 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, List, Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class SampleDataFactory:
|
class SampleDataFactory:
|
||||||
"""Factory class for creating sample test data."""
|
"""Factory class for creating sample test data."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_market_data(ticker: str = "AAPL", days: int = 30) -> Dict[str, Any]:
|
def create_market_data(ticker: str = "AAPL", days: int = 30) -> dict[str, Any]:
|
||||||
"""Create sample market data for testing."""
|
"""Create sample market data for testing."""
|
||||||
base_date = datetime(2024, 5, 1)
|
base_date = datetime(2024, 5, 1)
|
||||||
data = {}
|
data = {}
|
||||||
|
|
@ -36,8 +36,8 @@ class SampleDataFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_finnhub_news_data(
|
def create_finnhub_news_data(
|
||||||
ticker: str = "AAPL", count: int = 10
|
ticker: str = "AAPL", count: int = 10,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> dict[str, list[dict[str, Any]]]:
|
||||||
"""Create sample FinnHub news data for testing."""
|
"""Create sample FinnHub news data for testing."""
|
||||||
base_date = datetime(2024, 5, 10)
|
base_date = datetime(2024, 5, 10)
|
||||||
data = {}
|
data = {}
|
||||||
|
|
@ -93,7 +93,7 @@ class SampleDataFactory:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_insider_transactions_data(
|
def create_insider_transactions_data(
|
||||||
ticker: str = "AAPL",
|
ticker: str = "AAPL",
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> dict[str, list[dict[str, Any]]]:
|
||||||
"""Create sample insider transactions data for testing."""
|
"""Create sample insider transactions data for testing."""
|
||||||
base_date = datetime(2024, 5, 5)
|
base_date = datetime(2024, 5, 5)
|
||||||
data = {}
|
data = {}
|
||||||
|
|
@ -129,15 +129,15 @@ class SampleDataFactory:
|
||||||
"transactionValue": transaction["shares"] * transaction["price"],
|
"transactionValue": transaction["shares"] * transaction["price"],
|
||||||
"reportingName": transaction["person"],
|
"reportingName": transaction["person"],
|
||||||
"typeOfOwner": "officer",
|
"typeOfOwner": "officer",
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_financial_statements_data(
|
def create_financial_statements_data(
|
||||||
ticker: str = "AAPL", period: str = "annual"
|
ticker: str = "AAPL", period: str = "annual",
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> dict[str, list[dict[str, Any]]]:
|
||||||
"""Create sample financial statements data for testing."""
|
"""Create sample financial statements data for testing."""
|
||||||
if period == "annual":
|
if period == "annual":
|
||||||
dates = ["2023-12-31", "2022-12-31", "2021-12-31"]
|
dates = ["2023-12-31", "2022-12-31", "2021-12-31"]
|
||||||
|
|
@ -174,7 +174,7 @@ class SampleDataFactory:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_social_sentiment_data(
|
def create_social_sentiment_data(
|
||||||
ticker: str = "AAPL",
|
ticker: str = "AAPL",
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> dict[str, list[dict[str, Any]]]:
|
||||||
"""Create sample social media sentiment data for testing."""
|
"""Create sample social media sentiment data for testing."""
|
||||||
base_date = datetime(2024, 5, 8)
|
base_date = datetime(2024, 5, 8)
|
||||||
data = {}
|
data = {}
|
||||||
|
|
@ -226,7 +226,7 @@ class SampleDataFactory:
|
||||||
"subreddit": "stocks" if j % 2 else "investing",
|
"subreddit": "stocks" if j % 2 else "investing",
|
||||||
"upvotes": 10 + (j * 5),
|
"upvotes": 10 + (j * 5),
|
||||||
"comments": 3 + j,
|
"comments": 3 + j,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
data[date_str] = daily_posts
|
data[date_str] = daily_posts
|
||||||
|
|
@ -234,7 +234,7 @@ class SampleDataFactory:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_technical_indicators_data(ticker: str = "AAPL") -> Dict[str, Any]:
|
def create_technical_indicators_data(ticker: str = "AAPL") -> dict[str, Any]:
|
||||||
"""Create sample technical indicators data for testing."""
|
"""Create sample technical indicators data for testing."""
|
||||||
return {
|
return {
|
||||||
"symbol": ticker,
|
"symbol": ticker,
|
||||||
|
|
@ -262,23 +262,23 @@ class SampleDataFactory:
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_complete_test_dataset(ticker: str = "AAPL") -> Dict[str, Dict[str, Any]]:
|
def create_complete_test_dataset(ticker: str = "AAPL") -> dict[str, dict[str, Any]]:
|
||||||
"""Create a complete dataset for comprehensive testing."""
|
"""Create a complete dataset for comprehensive testing."""
|
||||||
return {
|
return {
|
||||||
"market_data": SampleDataFactory.create_market_data(ticker),
|
"market_data": SampleDataFactory.create_market_data(ticker),
|
||||||
"news_data": SampleDataFactory.create_finnhub_news_data(ticker),
|
"news_data": SampleDataFactory.create_finnhub_news_data(ticker),
|
||||||
"insider_transactions": SampleDataFactory.create_insider_transactions_data(
|
"insider_transactions": SampleDataFactory.create_insider_transactions_data(
|
||||||
ticker
|
ticker,
|
||||||
),
|
),
|
||||||
"financial_annual": SampleDataFactory.create_financial_statements_data(
|
"financial_annual": SampleDataFactory.create_financial_statements_data(
|
||||||
ticker, "annual"
|
ticker, "annual",
|
||||||
),
|
),
|
||||||
"financial_quarterly": SampleDataFactory.create_financial_statements_data(
|
"financial_quarterly": SampleDataFactory.create_financial_statements_data(
|
||||||
ticker, "quarterly"
|
ticker, "quarterly",
|
||||||
),
|
),
|
||||||
"social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker),
|
"social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker),
|
||||||
"technical_indicators": SampleDataFactory.create_technical_indicators_data(
|
"technical_indicators": SampleDataFactory.create_technical_indicators_data(
|
||||||
ticker
|
ticker,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -343,7 +343,7 @@ def save_sample_data_to_files(base_path: str, ticker: str = "AAPL") -> None:
|
||||||
|
|
||||||
# Save quarterly data separately
|
# Save quarterly data separately
|
||||||
quarterly_path = os.path.join(
|
quarterly_path = os.path.join(
|
||||||
finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json"
|
finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json",
|
||||||
)
|
)
|
||||||
with open(quarterly_path, "w") as f:
|
with open(quarterly_path, "w") as f:
|
||||||
json.dump(dataset["financial_quarterly"], f, indent=2)
|
json.dump(dataset["financial_quarterly"], f, indent=2)
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,11 @@
|
||||||
"""Integration tests for the full TradingAgents workflow."""
|
"""Integration tests for the full TradingAgents workflow."""
|
||||||
|
|
||||||
import pytest
|
from unittest.mock import Mock, patch
|
||||||
import os
|
|
||||||
import tempfile
|
import pytest
|
||||||
from unittest.mock import Mock, patch, MagicMock
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
|
|
@ -26,14 +24,14 @@ class TestFullWorkflowIntegration:
|
||||||
"deep_think_llm": "gpt-4o-mini",
|
"deep_think_llm": "gpt-4o-mini",
|
||||||
"quick_think_llm": "gpt-4o-mini",
|
"quick_think_llm": "gpt-4o-mini",
|
||||||
"project_dir": temp_data_dir,
|
"project_dir": temp_data_dir,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_end_to_end_trading_workflow(
|
def test_end_to_end_trading_workflow(
|
||||||
self, mock_toolkit, mock_chat_openai, integration_config
|
self, mock_toolkit, mock_chat_openai, integration_config,
|
||||||
):
|
):
|
||||||
"""Test complete end-to-end trading workflow."""
|
"""Test complete end-to-end trading workflow."""
|
||||||
# Setup mocks
|
# Setup mocks
|
||||||
|
|
@ -69,15 +67,14 @@ class TestFullWorkflowIntegration:
|
||||||
"company_of_interest": "AAPL",
|
"company_of_interest": "AAPL",
|
||||||
"trade_date": "2024-05-10",
|
"trade_date": "2024-05-10",
|
||||||
"messages": [],
|
"messages": [],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
trading_graph.propagator.get_graph_args = Mock(return_value={})
|
trading_graph.propagator.get_graph_args = Mock(return_value={})
|
||||||
trading_graph.signal_processor.process_signal = Mock(return_value="BUY")
|
trading_graph.signal_processor.process_signal = Mock(return_value="BUY")
|
||||||
|
|
||||||
# Execute the full workflow
|
# Execute the full workflow
|
||||||
with patch("builtins.open", create=True):
|
with patch("builtins.open", create=True), patch("json.dump"):
|
||||||
with patch("json.dump"):
|
final_state, decision = trading_graph.propagate("AAPL", "2024-05-10")
|
||||||
final_state, decision = trading_graph.propagate("AAPL", "2024-05-10")
|
|
||||||
|
|
||||||
# Verify the workflow completed successfully
|
# Verify the workflow completed successfully
|
||||||
assert final_state is not None
|
assert final_state is not None
|
||||||
|
|
@ -89,7 +86,7 @@ class TestFullWorkflowIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_multiple_analysts_integration(
|
def test_multiple_analysts_integration(
|
||||||
self, mock_toolkit, mock_chat_openai, integration_config
|
self, mock_toolkit, mock_chat_openai, integration_config,
|
||||||
):
|
):
|
||||||
"""Test integration with different analyst combinations."""
|
"""Test integration with different analyst combinations."""
|
||||||
analyst_combinations = [
|
analyst_combinations = [
|
||||||
|
|
@ -117,7 +114,7 @@ class TestFullWorkflowIntegration:
|
||||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||||
# Test each analyst combination
|
# Test each analyst combination
|
||||||
trading_graph = TradingAgentsGraph(
|
trading_graph = TradingAgentsGraph(
|
||||||
selected_analysts=analysts, config=integration_config
|
selected_analysts=analysts, config=integration_config,
|
||||||
)
|
)
|
||||||
trading_graph.graph = mock_graph
|
trading_graph.graph = mock_graph
|
||||||
|
|
||||||
|
|
@ -127,19 +124,18 @@ class TestFullWorkflowIntegration:
|
||||||
"company_of_interest": "TSLA",
|
"company_of_interest": "TSLA",
|
||||||
"trade_date": "2024-05-15",
|
"trade_date": "2024-05-15",
|
||||||
"messages": [],
|
"messages": [],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
trading_graph.propagator.get_graph_args = Mock(return_value={})
|
trading_graph.propagator.get_graph_args = Mock(return_value={})
|
||||||
trading_graph.signal_processor.process_signal = Mock(
|
trading_graph.signal_processor.process_signal = Mock(
|
||||||
return_value="HOLD"
|
return_value="HOLD",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
with patch("builtins.open", create=True):
|
with patch("builtins.open", create=True), patch("json.dump"):
|
||||||
with patch("json.dump"):
|
final_state, decision = trading_graph.propagate(
|
||||||
final_state, decision = trading_graph.propagate(
|
"TSLA", "2024-05-15",
|
||||||
"TSLA", "2024-05-15"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
assert final_state is not None
|
assert final_state is not None
|
||||||
|
|
@ -148,7 +144,7 @@ class TestFullWorkflowIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_memory_and_reflection_integration(
|
def test_memory_and_reflection_integration(
|
||||||
self, mock_toolkit, mock_chat_openai, integration_config
|
self, mock_toolkit, mock_chat_openai, integration_config,
|
||||||
):
|
):
|
||||||
"""Test integration of memory and reflection components."""
|
"""Test integration of memory and reflection components."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -165,7 +161,7 @@ class TestFullWorkflowIntegration:
|
||||||
mock_graph.invoke.return_value = mock_final_state
|
mock_graph.invoke.return_value = mock_final_state
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"tradingagents.graph.trading_graph.FinancialSituationMemory"
|
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
||||||
) as mock_memory:
|
) as mock_memory:
|
||||||
mock_memory_instance = Mock()
|
mock_memory_instance = Mock()
|
||||||
mock_memory.return_value = mock_memory_instance
|
mock_memory.return_value = mock_memory_instance
|
||||||
|
|
@ -180,11 +176,11 @@ class TestFullWorkflowIntegration:
|
||||||
"company_of_interest": "NVDA",
|
"company_of_interest": "NVDA",
|
||||||
"trade_date": "2024-05-20",
|
"trade_date": "2024-05-20",
|
||||||
"messages": [],
|
"messages": [],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
trading_graph.propagator.get_graph_args = Mock(return_value={})
|
trading_graph.propagator.get_graph_args = Mock(return_value={})
|
||||||
trading_graph.signal_processor.process_signal = Mock(
|
trading_graph.signal_processor.process_signal = Mock(
|
||||||
return_value="SELL"
|
return_value="SELL",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock reflection methods
|
# Mock reflection methods
|
||||||
|
|
@ -195,9 +191,8 @@ class TestFullWorkflowIntegration:
|
||||||
trading_graph.reflector.reflect_risk_manager = Mock()
|
trading_graph.reflector.reflect_risk_manager = Mock()
|
||||||
|
|
||||||
# Execute workflow
|
# Execute workflow
|
||||||
with patch("builtins.open", create=True):
|
with patch("builtins.open", create=True), patch("json.dump"):
|
||||||
with patch("json.dump"):
|
final_state, decision = trading_graph.propagate("NVDA", "2024-05-20")
|
||||||
final_state, decision = trading_graph.propagate("NVDA", "2024-05-20")
|
|
||||||
|
|
||||||
# Test reflection and memory update
|
# Test reflection and memory update
|
||||||
returns_losses = {"return": -0.03, "loss": -0.08}
|
returns_losses = {"return": -0.03, "loss": -0.08}
|
||||||
|
|
@ -213,7 +208,7 @@ class TestFullWorkflowIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_debug_mode_integration(
|
def test_debug_mode_integration(
|
||||||
self, mock_toolkit, mock_chat_openai, integration_config
|
self, mock_toolkit, mock_chat_openai, integration_config,
|
||||||
):
|
):
|
||||||
"""Test integration in debug mode."""
|
"""Test integration in debug mode."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -233,7 +228,7 @@ class TestFullWorkflowIntegration:
|
||||||
self._create_mock_final_state(), # Final chunk
|
self._create_mock_final_state(), # Final chunk
|
||||||
]
|
]
|
||||||
for chunk in mock_chunks:
|
for chunk in mock_chunks:
|
||||||
if "messages" in chunk and chunk["messages"]:
|
if chunk.get("messages"):
|
||||||
for msg in chunk["messages"]:
|
for msg in chunk["messages"]:
|
||||||
if hasattr(msg, "pretty_print"):
|
if hasattr(msg, "pretty_print"):
|
||||||
msg.pretty_print = Mock()
|
msg.pretty_print = Mock()
|
||||||
|
|
@ -245,7 +240,7 @@ class TestFullWorkflowIntegration:
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||||
trading_graph = TradingAgentsGraph(
|
trading_graph = TradingAgentsGraph(
|
||||||
debug=True, config=integration_config
|
debug=True, config=integration_config,
|
||||||
)
|
)
|
||||||
trading_graph.graph = mock_graph
|
trading_graph.graph = mock_graph
|
||||||
|
|
||||||
|
|
@ -255,15 +250,14 @@ class TestFullWorkflowIntegration:
|
||||||
"company_of_interest": "AMZN",
|
"company_of_interest": "AMZN",
|
||||||
"trade_date": "2024-05-25",
|
"trade_date": "2024-05-25",
|
||||||
"messages": [],
|
"messages": [],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
trading_graph.propagator.get_graph_args = Mock(return_value={})
|
trading_graph.propagator.get_graph_args = Mock(return_value={})
|
||||||
trading_graph.signal_processor.process_signal = Mock(return_value="BUY")
|
trading_graph.signal_processor.process_signal = Mock(return_value="BUY")
|
||||||
|
|
||||||
# Execute in debug mode
|
# Execute in debug mode
|
||||||
with patch("builtins.open", create=True):
|
with patch("builtins.open", create=True), patch("json.dump"):
|
||||||
with patch("json.dump"):
|
final_state, decision = trading_graph.propagate("AMZN", "2024-05-25")
|
||||||
final_state, decision = trading_graph.propagate("AMZN", "2024-05-25")
|
|
||||||
|
|
||||||
# Verify debug mode was used
|
# Verify debug mode was used
|
||||||
mock_graph.stream.assert_called_once()
|
mock_graph.stream.assert_called_once()
|
||||||
|
|
@ -271,7 +265,7 @@ class TestFullWorkflowIntegration:
|
||||||
assert decision == "BUY"
|
assert decision == "BUY"
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ticker,date",
|
("ticker", "date"),
|
||||||
[
|
[
|
||||||
("AAPL", "2024-01-15"),
|
("AAPL", "2024-01-15"),
|
||||||
("TSLA", "2024-02-20"),
|
("TSLA", "2024-02-20"),
|
||||||
|
|
@ -282,7 +276,7 @@ class TestFullWorkflowIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_multiple_stocks_integration(
|
def test_multiple_stocks_integration(
|
||||||
self, mock_toolkit, mock_chat_openai, ticker, date, integration_config
|
self, mock_toolkit, mock_chat_openai, ticker, date, integration_config,
|
||||||
):
|
):
|
||||||
"""Test integration with different stocks and dates."""
|
"""Test integration with different stocks and dates."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -309,17 +303,16 @@ class TestFullWorkflowIntegration:
|
||||||
"company_of_interest": ticker,
|
"company_of_interest": ticker,
|
||||||
"trade_date": date,
|
"trade_date": date,
|
||||||
"messages": [],
|
"messages": [],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
trading_graph.propagator.get_graph_args = Mock(return_value={})
|
trading_graph.propagator.get_graph_args = Mock(return_value={})
|
||||||
trading_graph.signal_processor.process_signal = Mock(
|
trading_graph.signal_processor.process_signal = Mock(
|
||||||
return_value="HOLD"
|
return_value="HOLD",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
with patch("builtins.open", create=True):
|
with patch("builtins.open", create=True), patch("json.dump"):
|
||||||
with patch("json.dump"):
|
final_state, decision = trading_graph.propagate(ticker, date)
|
||||||
final_state, decision = trading_graph.propagate(ticker, date)
|
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
assert final_state["company_of_interest"] == ticker
|
assert final_state["company_of_interest"] == ticker
|
||||||
|
|
@ -389,7 +382,7 @@ class TestPerformanceIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_multiple_consecutive_runs(
|
def test_multiple_consecutive_runs(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
|
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test multiple consecutive trading decisions."""
|
"""Test multiple consecutive trading decisions."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -455,9 +448,8 @@ class TestPerformanceIntegration:
|
||||||
mock_final_state["final_trade_decision"]
|
mock_final_state["final_trade_decision"]
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("builtins.open", create=True):
|
with patch("builtins.open", create=True), patch("json.dump"):
|
||||||
with patch("json.dump"):
|
final_state, decision = trading_graph.propagate(ticker, date)
|
||||||
final_state, decision = trading_graph.propagate(ticker, date)
|
|
||||||
|
|
||||||
decisions.append(decision)
|
decisions.append(decision)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
"""Unit tests for market analyst agent."""
|
"""Unit tests for market analyst agent."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, patch, MagicMock
|
from langchain_core.messages import HumanMessage
|
||||||
from langchain_core.messages import HumanMessage, AIMessage
|
|
||||||
|
|
||||||
from tradingagents.agents.analysts.market_analyst import create_market_analyst
|
from tradingagents.agents.analysts.market_analyst import create_market_analyst
|
||||||
|
|
||||||
|
|
@ -16,7 +17,7 @@ class TestMarketAnalyst:
|
||||||
assert callable(analyst_node)
|
assert callable(analyst_node)
|
||||||
|
|
||||||
def test_market_analyst_node_basic_execution(
|
def test_market_analyst_node_basic_execution(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state
|
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test basic execution of market analyst node."""
|
"""Test basic execution of market analyst node."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -38,7 +39,7 @@ class TestMarketAnalyst:
|
||||||
assert result["market_report"] == "Market analysis complete"
|
assert result["market_report"] == "Market analysis complete"
|
||||||
|
|
||||||
def test_market_analyst_uses_online_tools_when_configured(
|
def test_market_analyst_uses_online_tools_when_configured(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state
|
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test that analyst uses online tools when configured."""
|
"""Test that analyst uses online tools when configured."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -54,7 +55,7 @@ class TestMarketAnalyst:
|
||||||
analyst_node = create_market_analyst(mock_llm, mock_toolkit)
|
analyst_node = create_market_analyst(mock_llm, mock_toolkit)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
result = analyst_node(sample_agent_state)
|
analyst_node(sample_agent_state)
|
||||||
|
|
||||||
# Verify tools were bound correctly
|
# Verify tools were bound correctly
|
||||||
mock_llm.bind_tools.assert_called_once()
|
mock_llm.bind_tools.assert_called_once()
|
||||||
|
|
@ -63,7 +64,7 @@ class TestMarketAnalyst:
|
||||||
assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2
|
assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2
|
||||||
|
|
||||||
def test_market_analyst_uses_offline_tools_when_configured(
|
def test_market_analyst_uses_offline_tools_when_configured(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state
|
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test that analyst uses offline tools when configured."""
|
"""Test that analyst uses offline tools when configured."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -79,7 +80,7 @@ class TestMarketAnalyst:
|
||||||
analyst_node = create_market_analyst(mock_llm, mock_toolkit)
|
analyst_node = create_market_analyst(mock_llm, mock_toolkit)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
result = analyst_node(sample_agent_state)
|
analyst_node(sample_agent_state)
|
||||||
|
|
||||||
# Verify tools were bound correctly
|
# Verify tools were bound correctly
|
||||||
mock_llm.bind_tools.assert_called_once()
|
mock_llm.bind_tools.assert_called_once()
|
||||||
|
|
@ -87,7 +88,7 @@ class TestMarketAnalyst:
|
||||||
assert len(bound_tools) == 2 # Should have 2 offline tools
|
assert len(bound_tools) == 2 # Should have 2 offline tools
|
||||||
|
|
||||||
def test_market_analyst_processes_state_variables(
|
def test_market_analyst_processes_state_variables(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state
|
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test that market analyst correctly processes state variables."""
|
"""Test that market analyst correctly processes state variables."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -111,7 +112,7 @@ class TestMarketAnalyst:
|
||||||
assert result["market_report"] == "Analysis for AAPL on 2024-05-10"
|
assert result["market_report"] == "Analysis for AAPL on 2024-05-10"
|
||||||
|
|
||||||
def test_market_analyst_handles_empty_tool_calls(
|
def test_market_analyst_handles_empty_tool_calls(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state
|
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test handling when no tool calls are made."""
|
"""Test handling when no tool calls are made."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -131,7 +132,7 @@ class TestMarketAnalyst:
|
||||||
assert result["messages"] == [mock_result]
|
assert result["messages"] == [mock_result]
|
||||||
|
|
||||||
def test_market_analyst_with_tool_calls(
|
def test_market_analyst_with_tool_calls(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state
|
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test handling when tool calls are present."""
|
"""Test handling when tool calls are present."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -152,7 +153,7 @@ class TestMarketAnalyst:
|
||||||
|
|
||||||
@pytest.mark.parametrize("online_tools", [True, False])
|
@pytest.mark.parametrize("online_tools", [True, False])
|
||||||
def test_market_analyst_tool_configuration(
|
def test_market_analyst_tool_configuration(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state, online_tools
|
self, mock_llm, mock_toolkit, sample_agent_state, online_tools,
|
||||||
):
|
):
|
||||||
"""Test tool configuration for both online and offline modes."""
|
"""Test tool configuration for both online and offline modes."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -194,15 +195,15 @@ class TestMarketAnalystIntegration:
|
||||||
mock_result = Mock()
|
mock_result = Mock()
|
||||||
mock_result.content = """
|
mock_result.content = """
|
||||||
# Market Analysis for TSLA (2024-05-15)
|
# Market Analysis for TSLA (2024-05-15)
|
||||||
|
|
||||||
## Technical Analysis
|
## Technical Analysis
|
||||||
- RSI: 65 (slightly overbought)
|
- RSI: 65 (slightly overbought)
|
||||||
- MACD: Bullish crossover
|
- MACD: Bullish crossover
|
||||||
- 50-day SMA: Trending upward
|
- 50-day SMA: Trending upward
|
||||||
|
|
||||||
## Volume Analysis
|
## Volume Analysis
|
||||||
- Above average volume suggests strong interest
|
- Above average volume suggests strong interest
|
||||||
|
|
||||||
| Indicator | Value | Signal |
|
| Indicator | Value | Signal |
|
||||||
|-----------|-------|--------|
|
|-----------|-------|--------|
|
||||||
| RSI | 65 | Neutral |
|
| RSI | 65 | Neutral |
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,9 @@
|
||||||
"""Unit tests for FinnHub utilities."""
|
"""Unit tests for FinnHub utilities."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
|
||||||
from unittest.mock import patch, mock_open, Mock
|
import pytest
|
||||||
|
|
||||||
from tradingagents.dataflows.finnhub_utils import get_data_in_range
|
from tradingagents.dataflows.finnhub_utils import get_data_in_range
|
||||||
|
|
||||||
|
|
@ -191,7 +190,7 @@ class TestFinnhubUtils:
|
||||||
|
|
||||||
# Test without period
|
# Test without period
|
||||||
expected_path_no_period = os.path.join(
|
expected_path_no_period = os.path.join(
|
||||||
temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
|
temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test with period
|
# Test with period
|
||||||
|
|
@ -238,7 +237,7 @@ class TestFinnhubUtils:
|
||||||
assert len(result2) == 1
|
assert len(result2) == 1
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"data_type,period",
|
("data_type", "period"),
|
||||||
[
|
[
|
||||||
("news_data", None),
|
("news_data", None),
|
||||||
("insider_trans", None),
|
("insider_trans", None),
|
||||||
|
|
@ -249,7 +248,7 @@ class TestFinnhubUtils:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_get_data_in_range_various_data_types(
|
def test_get_data_in_range_various_data_types(
|
||||||
self, temp_data_dir, data_type, period
|
self, temp_data_dir, data_type, period,
|
||||||
):
|
):
|
||||||
"""Test get_data_in_range with various data types."""
|
"""Test get_data_in_range with various data types."""
|
||||||
ticker = "TEST"
|
ticker = "TEST"
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,10 @@
|
||||||
"""Unit tests for TradingAgentsGraph."""
|
"""Unit tests for TradingAgentsGraph."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock, mock_open, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import os
|
|
||||||
from unittest.mock import Mock, patch, MagicMock
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
|
||||||
|
|
||||||
|
|
||||||
class TestTradingAgentsGraph:
|
class TestTradingAgentsGraph:
|
||||||
|
|
@ -47,7 +45,7 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_init_with_debug(
|
def test_init_with_debug(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
|
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test initialization with debug mode enabled."""
|
"""Test initialization with debug mode enabled."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -65,7 +63,7 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatAnthropic")
|
@patch("tradingagents.graph.trading_graph.ChatAnthropic")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_init_with_anthropic(
|
def test_init_with_anthropic(
|
||||||
self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir
|
self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test initialization with Anthropic LLM provider."""
|
"""Test initialization with Anthropic LLM provider."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -77,14 +75,14 @@ class TestTradingAgentsGraph:
|
||||||
|
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||||
graph = TradingAgentsGraph(config=sample_config)
|
TradingAgentsGraph(config=sample_config)
|
||||||
|
|
||||||
assert mock_chat_anthropic.call_count == 2
|
assert mock_chat_anthropic.call_count == 2
|
||||||
|
|
||||||
@patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI")
|
@patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_init_with_google(
|
def test_init_with_google(
|
||||||
self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir
|
self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test initialization with Google LLM provider."""
|
"""Test initialization with Google LLM provider."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -96,13 +94,13 @@ class TestTradingAgentsGraph:
|
||||||
|
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||||
graph = TradingAgentsGraph(config=sample_config)
|
TradingAgentsGraph(config=sample_config)
|
||||||
|
|
||||||
assert mock_chat_google.call_count == 2
|
assert mock_chat_google.call_count == 2
|
||||||
|
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_init_unsupported_llm_provider(
|
def test_init_unsupported_llm_provider(
|
||||||
self, mock_toolkit, sample_config, temp_data_dir
|
self, mock_toolkit, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test initialization with unsupported LLM provider raises error."""
|
"""Test initialization with unsupported LLM provider raises error."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -117,7 +115,7 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_create_tool_nodes(
|
def test_create_tool_nodes(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
|
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test creation of tool nodes."""
|
"""Test creation of tool nodes."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -145,7 +143,7 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_propagate_basic(
|
def test_propagate_basic(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
|
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test basic propagate functionality."""
|
"""Test basic propagate functionality."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -190,15 +188,14 @@ class TestTradingAgentsGraph:
|
||||||
|
|
||||||
# Mock the propagator and signal processor
|
# Mock the propagator and signal processor
|
||||||
graph.propagator.create_initial_state = Mock(
|
graph.propagator.create_initial_state = Mock(
|
||||||
return_value={"test": "state"}
|
return_value={"test": "state"},
|
||||||
)
|
)
|
||||||
graph.propagator.get_graph_args = Mock(return_value={})
|
graph.propagator.get_graph_args = Mock(return_value={})
|
||||||
graph.signal_processor.process_signal = Mock(return_value="HOLD")
|
graph.signal_processor.process_signal = Mock(return_value="HOLD")
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
with patch("builtins.open", create=True):
|
with patch("builtins.open", create=True), patch("json.dump"):
|
||||||
with patch("json.dump"):
|
final_state, decision = graph.propagate("AAPL", "2024-05-10")
|
||||||
final_state, decision = graph.propagate("AAPL", "2024-05-10")
|
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
assert final_state == mock_final_state
|
assert final_state == mock_final_state
|
||||||
|
|
@ -209,7 +206,7 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_propagate_debug_mode(
|
def test_propagate_debug_mode(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
|
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test propagate in debug mode."""
|
"""Test propagate in debug mode."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -231,15 +228,14 @@ class TestTradingAgentsGraph:
|
||||||
|
|
||||||
# Mock other components
|
# Mock other components
|
||||||
graph.propagator.create_initial_state = Mock(
|
graph.propagator.create_initial_state = Mock(
|
||||||
return_value={"test": "state"}
|
return_value={"test": "state"},
|
||||||
)
|
)
|
||||||
graph.propagator.get_graph_args = Mock(return_value={})
|
graph.propagator.get_graph_args = Mock(return_value={})
|
||||||
graph.signal_processor.process_signal = Mock(return_value="BUY")
|
graph.signal_processor.process_signal = Mock(return_value="BUY")
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
with patch("builtins.open", create=True):
|
with patch("builtins.open", create=True), patch("json.dump"):
|
||||||
with patch("json.dump"):
|
final_state, decision = graph.propagate("TSLA", "2024-05-15")
|
||||||
final_state, decision = graph.propagate("TSLA", "2024-05-15")
|
|
||||||
|
|
||||||
# Verify debug mode was used
|
# Verify debug mode was used
|
||||||
mock_graph.stream.assert_called_once()
|
mock_graph.stream.assert_called_once()
|
||||||
|
|
@ -249,7 +245,7 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_log_state(
|
def test_log_state(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
|
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test state logging functionality."""
|
"""Test state logging functionality."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -291,10 +287,9 @@ class TestTradingAgentsGraph:
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock file operations
|
# Mock file operations
|
||||||
with patch("pathlib.Path.mkdir"):
|
with patch("pathlib.Path.mkdir"), patch("builtins.open", mock_open()):
|
||||||
with patch("builtins.open", mock_open()) as mock_file:
|
with patch("json.dump"):
|
||||||
with patch("json.dump") as mock_json_dump:
|
graph._log_state("2024-05-20", final_state)
|
||||||
graph._log_state("2024-05-20", final_state)
|
|
||||||
|
|
||||||
# Verify logging occurred
|
# Verify logging occurred
|
||||||
assert "2024-05-20" in graph.log_states_dict
|
assert "2024-05-20" in graph.log_states_dict
|
||||||
|
|
@ -305,7 +300,7 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_reflect_and_remember(
|
def test_reflect_and_remember(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
|
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test reflection and memory update functionality."""
|
"""Test reflection and memory update functionality."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -315,20 +310,19 @@ class TestTradingAgentsGraph:
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
mock_toolkit.return_value = mock_toolkit_instance
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"tradingagents.graph.trading_graph.FinancialSituationMemory"
|
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
||||||
) as mock_memory:
|
), patch("tradingagents.graph.trading_graph.set_config"):
|
||||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
graph = TradingAgentsGraph(config=sample_config)
|
||||||
graph = TradingAgentsGraph(config=sample_config)
|
|
||||||
|
|
||||||
# Set up current state
|
# Set up current state
|
||||||
graph.curr_state = {"test": "state"}
|
graph.curr_state = {"test": "state"}
|
||||||
|
|
||||||
# Mock reflector methods
|
# Mock reflector methods
|
||||||
graph.reflector.reflect_bull_researcher = Mock()
|
graph.reflector.reflect_bull_researcher = Mock()
|
||||||
graph.reflector.reflect_bear_researcher = Mock()
|
graph.reflector.reflect_bear_researcher = Mock()
|
||||||
graph.reflector.reflect_trader = Mock()
|
graph.reflector.reflect_trader = Mock()
|
||||||
graph.reflector.reflect_invest_judge = Mock()
|
graph.reflector.reflect_invest_judge = Mock()
|
||||||
graph.reflector.reflect_risk_manager = Mock()
|
graph.reflector.reflect_risk_manager = Mock()
|
||||||
|
|
||||||
returns_losses = {"return": 0.05, "loss": -0.02}
|
returns_losses = {"return": 0.05, "loss": -0.02}
|
||||||
|
|
||||||
|
|
@ -345,7 +339,7 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_process_signal(
|
def test_process_signal(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir
|
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test signal processing functionality."""
|
"""Test signal processing functionality."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -393,8 +387,8 @@ class TestTradingAgentsGraph:
|
||||||
|
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||||
graph = TradingAgentsGraph(
|
TradingAgentsGraph(
|
||||||
selected_analysts=selected_analysts, config=sample_config
|
selected_analysts=selected_analysts, config=sample_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify graph was set up with selected analysts
|
# Verify graph was set up with selected analysts
|
||||||
|
|
@ -415,14 +409,14 @@ class TestTradingAgentsGraphErrorHandling:
|
||||||
# This should still work as the class should use defaults for missing keys
|
# This should still work as the class should use defaults for missing keys
|
||||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
KeyError
|
KeyError,
|
||||||
): # Should fail when trying to access missing config keys
|
): # Should fail when trying to access missing config keys
|
||||||
TradingAgentsGraph(config=invalid_config)
|
TradingAgentsGraph(config=invalid_config)
|
||||||
|
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_directory_creation_failure(
|
def test_directory_creation_failure(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config
|
self, mock_toolkit, mock_chat_openai, sample_config,
|
||||||
):
|
):
|
||||||
"""Test handling when directory creation fails."""
|
"""Test handling when directory creation fails."""
|
||||||
sample_config["project_dir"] = "/invalid/path/that/cannot/be/created"
|
sample_config["project_dir"] = "/invalid/path/that/cannot/be/created"
|
||||||
|
|
|
||||||
|
|
@ -1,40 +1,35 @@
|
||||||
from .utils.agent_utils import Toolkit, create_msg_delete
|
|
||||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
|
||||||
from .utils.memory import FinancialSituationMemory
|
|
||||||
|
|
||||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||||
from .analysts.market_analyst import create_market_analyst
|
from .analysts.market_analyst import create_market_analyst
|
||||||
from .analysts.news_analyst import create_news_analyst
|
from .analysts.news_analyst import create_news_analyst
|
||||||
from .analysts.social_media_analyst import create_social_media_analyst
|
from .analysts.social_media_analyst import create_social_media_analyst
|
||||||
|
from .managers.research_manager import create_research_manager
|
||||||
|
from .managers.risk_manager import create_risk_manager
|
||||||
from .researchers.bear_researcher import create_bear_researcher
|
from .researchers.bear_researcher import create_bear_researcher
|
||||||
from .researchers.bull_researcher import create_bull_researcher
|
from .researchers.bull_researcher import create_bull_researcher
|
||||||
|
|
||||||
from .risk_mgmt.aggresive_debator import create_risky_debator
|
from .risk_mgmt.aggresive_debator import create_risky_debator
|
||||||
from .risk_mgmt.conservative_debator import create_safe_debator
|
from .risk_mgmt.conservative_debator import create_safe_debator
|
||||||
from .risk_mgmt.neutral_debator import create_neutral_debator
|
from .risk_mgmt.neutral_debator import create_neutral_debator
|
||||||
|
|
||||||
from .managers.research_manager import create_research_manager
|
|
||||||
from .managers.risk_manager import create_risk_manager
|
|
||||||
|
|
||||||
from .trader.trader import create_trader
|
from .trader.trader import create_trader
|
||||||
|
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||||
|
from .utils.agent_utils import Toolkit, create_msg_delete
|
||||||
|
from .utils.memory import FinancialSituationMemory
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FinancialSituationMemory",
|
|
||||||
"Toolkit",
|
|
||||||
"AgentState",
|
"AgentState",
|
||||||
"create_msg_delete",
|
"FinancialSituationMemory",
|
||||||
"InvestDebateState",
|
"InvestDebateState",
|
||||||
"RiskDebateState",
|
"RiskDebateState",
|
||||||
|
"Toolkit",
|
||||||
"create_bear_researcher",
|
"create_bear_researcher",
|
||||||
"create_bull_researcher",
|
"create_bull_researcher",
|
||||||
"create_research_manager",
|
|
||||||
"create_fundamentals_analyst",
|
"create_fundamentals_analyst",
|
||||||
"create_market_analyst",
|
"create_market_analyst",
|
||||||
|
"create_msg_delete",
|
||||||
"create_neutral_debator",
|
"create_neutral_debator",
|
||||||
"create_news_analyst",
|
"create_news_analyst",
|
||||||
"create_risky_debator",
|
"create_research_manager",
|
||||||
"create_risk_manager",
|
"create_risk_manager",
|
||||||
|
"create_risky_debator",
|
||||||
"create_safe_debator",
|
"create_safe_debator",
|
||||||
"create_social_media_analyst",
|
"create_social_media_analyst",
|
||||||
"create_trader",
|
"create_trader",
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ def create_fundamentals_analyst(llm, toolkit):
|
||||||
def fundamentals_analyst_node(state):
|
def fundamentals_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
ticker = state["company_of_interest"]
|
ticker = state["company_of_interest"]
|
||||||
company_name = state["company_of_interest"]
|
state["company_of_interest"]
|
||||||
|
|
||||||
if toolkit.config["online_tools"]:
|
if toolkit.config["online_tools"]:
|
||||||
tools = [toolkit.get_fundamentals_openai]
|
tools = [toolkit.get_fundamentals_openai]
|
||||||
|
|
@ -20,7 +20,7 @@ def create_fundamentals_analyst(llm, toolkit):
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||||
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
|
" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|
@ -37,7 +37,7 @@ def create_fundamentals_analyst(llm, toolkit):
|
||||||
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
|
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
|
||||||
),
|
),
|
||||||
MessagesPlaceholder(variable_name="messages"),
|
MessagesPlaceholder(variable_name="messages"),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = prompt.partial(system_message=system_message)
|
prompt = prompt.partial(system_message=system_message)
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ def create_market_analyst(llm, toolkit):
|
||||||
def market_analyst_node(state):
|
def market_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
ticker = state["company_of_interest"]
|
ticker = state["company_of_interest"]
|
||||||
company_name = state["company_of_interest"]
|
state["company_of_interest"]
|
||||||
|
|
||||||
if toolkit.config["online_tools"]:
|
if toolkit.config["online_tools"]:
|
||||||
tools = [
|
tools = [
|
||||||
|
|
@ -45,7 +45,7 @@ Volume-Based Indicators:
|
||||||
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
||||||
|
|
||||||
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
|
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
|
||||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
""" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|
@ -62,7 +62,7 @@ Volume-Based Indicators:
|
||||||
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
|
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
|
||||||
),
|
),
|
||||||
MessagesPlaceholder(variable_name="messages"),
|
MessagesPlaceholder(variable_name="messages"),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = prompt.partial(system_message=system_message)
|
prompt = prompt.partial(system_message=system_message)
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ def create_news_analyst(llm, toolkit):
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|
@ -34,7 +34,7 @@ def create_news_analyst(llm, toolkit):
|
||||||
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
|
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
|
||||||
),
|
),
|
||||||
MessagesPlaceholder(variable_name="messages"),
|
MessagesPlaceholder(variable_name="messages"),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = prompt.partial(system_message=system_message)
|
prompt = prompt.partial(system_message=system_message)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ def create_social_media_analyst(llm, toolkit):
|
||||||
def social_media_analyst_node(state):
|
def social_media_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
ticker = state["company_of_interest"]
|
ticker = state["company_of_interest"]
|
||||||
company_name = state["company_of_interest"]
|
state["company_of_interest"]
|
||||||
|
|
||||||
if toolkit.config["online_tools"]:
|
if toolkit.config["online_tools"]:
|
||||||
tools = [toolkit.get_stock_news_openai]
|
tools = [toolkit.get_stock_news_openai]
|
||||||
|
|
@ -16,7 +16,7 @@ def create_social_media_analyst(llm, toolkit):
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|
@ -33,7 +33,7 @@ def create_social_media_analyst(llm, toolkit):
|
||||||
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
|
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
|
||||||
),
|
),
|
||||||
MessagesPlaceholder(variable_name="messages"),
|
MessagesPlaceholder(variable_name="messages"),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = prompt.partial(system_message=system_message)
|
prompt = prompt.partial(system_message=system_message)
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ def create_research_manager(llm, memory):
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||||
|
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for _i, rec in enumerate(past_memories, 1):
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
past_memory_str += rec["recommendation"] + "\n\n"
|
||||||
|
|
||||||
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
||||||
|
|
@ -24,7 +24,7 @@ Additionally, develop a detailed investment plan for the trader. This should inc
|
||||||
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
||||||
Rationale: An explanation of why these arguments lead to your conclusion.
|
Rationale: An explanation of why these arguments lead to your conclusion.
|
||||||
Strategic Actions: Concrete steps for implementing the recommendation.
|
Strategic Actions: Concrete steps for implementing the recommendation.
|
||||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
||||||
|
|
||||||
Here are your past reflections on mistakes:
|
Here are your past reflections on mistakes:
|
||||||
\"{past_memory_str}\"
|
\"{past_memory_str}\"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
def create_risk_manager(llm, memory):
|
def create_risk_manager(llm, memory):
|
||||||
def risk_manager_node(state) -> dict:
|
def risk_manager_node(state) -> dict:
|
||||||
|
|
||||||
company_name = state["company_of_interest"]
|
state["company_of_interest"]
|
||||||
|
|
||||||
history = state["risk_debate_state"]["history"]
|
history = state["risk_debate_state"]["history"]
|
||||||
risk_debate_state = state["risk_debate_state"]
|
risk_debate_state = state["risk_debate_state"]
|
||||||
|
|
@ -15,7 +15,7 @@ def create_risk_manager(llm, memory):
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||||
|
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for _i, rec in enumerate(past_memories, 1):
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
past_memory_str += rec["recommendation"] + "\n\n"
|
||||||
|
|
||||||
prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
|
prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
|
||||||
|
|
@ -32,7 +32,7 @@ Deliverables:
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**Analysts Debate History:**
|
**Analysts Debate History:**
|
||||||
{history}
|
{history}
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ def create_bear_researcher(llm, memory):
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||||
|
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for _i, rec in enumerate(past_memories, 1):
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
past_memory_str += rec["recommendation"] + "\n\n"
|
||||||
|
|
||||||
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ def create_bull_researcher(llm, memory):
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||||
|
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for _i, rec in enumerate(past_memories, 1):
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
past_memory_str += rec["recommendation"] + "\n\n"
|
||||||
|
|
||||||
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
|
||||||
"current_risky_response": argument,
|
"current_risky_response": argument,
|
||||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||||
"current_neutral_response": risk_debate_state.get(
|
"current_neutral_response": risk_debate_state.get(
|
||||||
"current_neutral_response", ""
|
"current_neutral_response", "",
|
||||||
),
|
),
|
||||||
"count": risk_debate_state["count"] + 1,
|
"count": risk_debate_state["count"] + 1,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,11 +39,11 @@ Engage by questioning their optimism and emphasizing the potential downsides the
|
||||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||||
"latest_speaker": "Safe",
|
"latest_speaker": "Safe",
|
||||||
"current_risky_response": risk_debate_state.get(
|
"current_risky_response": risk_debate_state.get(
|
||||||
"current_risky_response", ""
|
"current_risky_response", "",
|
||||||
),
|
),
|
||||||
"current_safe_response": argument,
|
"current_safe_response": argument,
|
||||||
"current_neutral_response": risk_debate_state.get(
|
"current_neutral_response": risk_debate_state.get(
|
||||||
"current_neutral_response", ""
|
"current_neutral_response", "",
|
||||||
),
|
),
|
||||||
"count": risk_debate_state["count"] + 1,
|
"count": risk_debate_state["count"] + 1,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
|
||||||
"neutral_history": neutral_history + "\n" + argument,
|
"neutral_history": neutral_history + "\n" + argument,
|
||||||
"latest_speaker": "Neutral",
|
"latest_speaker": "Neutral",
|
||||||
"current_risky_response": risk_debate_state.get(
|
"current_risky_response": risk_debate_state.get(
|
||||||
"current_risky_response", ""
|
"current_risky_response", "",
|
||||||
),
|
),
|
||||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||||
"current_neutral_response": argument,
|
"current_neutral_response": argument,
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ def create_trader(llm, memory):
|
||||||
|
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
if past_memories:
|
if past_memories:
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for _i, rec in enumerate(past_memories, 1):
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
past_memory_str += rec["recommendation"] + "\n\n"
|
||||||
else:
|
else:
|
||||||
past_memory_str = "No past memories found."
|
past_memory_str = "No past memories found."
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,18 @@
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from typing_extensions import TypedDict
|
|
||||||
from tradingagents.agents import *
|
|
||||||
from langgraph.graph import MessagesState
|
from langgraph.graph import MessagesState
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
# Import specific agent classes as needed
|
||||||
|
|
||||||
|
|
||||||
# Researcher team state
|
# Researcher team state
|
||||||
class InvestDebateState(TypedDict):
|
class InvestDebateState(TypedDict):
|
||||||
bull_history: Annotated[
|
bull_history: Annotated[
|
||||||
str, "Bullish Conversation history"
|
str, "Bullish Conversation history",
|
||||||
] # Bullish Conversation history
|
] # Bullish Conversation history
|
||||||
bear_history: Annotated[
|
bear_history: Annotated[
|
||||||
str, "Bearish Conversation history"
|
str, "Bearish Conversation history",
|
||||||
] # Bullish Conversation history
|
] # Bullish Conversation history
|
||||||
history: Annotated[str, "Conversation history"] # Conversation history
|
history: Annotated[str, "Conversation history"] # Conversation history
|
||||||
current_response: Annotated[str, "Latest response"] # Last response
|
current_response: Annotated[str, "Latest response"] # Last response
|
||||||
|
|
@ -21,24 +23,24 @@ class InvestDebateState(TypedDict):
|
||||||
# Risk management team state
|
# Risk management team state
|
||||||
class RiskDebateState(TypedDict):
|
class RiskDebateState(TypedDict):
|
||||||
risky_history: Annotated[
|
risky_history: Annotated[
|
||||||
str, "Risky Agent's Conversation history"
|
str, "Risky Agent's Conversation history",
|
||||||
] # Conversation history
|
] # Conversation history
|
||||||
safe_history: Annotated[
|
safe_history: Annotated[
|
||||||
str, "Safe Agent's Conversation history"
|
str, "Safe Agent's Conversation history",
|
||||||
] # Conversation history
|
] # Conversation history
|
||||||
neutral_history: Annotated[
|
neutral_history: Annotated[
|
||||||
str, "Neutral Agent's Conversation history"
|
str, "Neutral Agent's Conversation history",
|
||||||
] # Conversation history
|
] # Conversation history
|
||||||
history: Annotated[str, "Conversation history"] # Conversation history
|
history: Annotated[str, "Conversation history"] # Conversation history
|
||||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
||||||
current_risky_response: Annotated[
|
current_risky_response: Annotated[
|
||||||
str, "Latest response by the risky analyst"
|
str, "Latest response by the risky analyst",
|
||||||
] # Last response
|
] # Last response
|
||||||
current_safe_response: Annotated[
|
current_safe_response: Annotated[
|
||||||
str, "Latest response by the safe analyst"
|
str, "Latest response by the safe analyst",
|
||||||
] # Last response
|
] # Last response
|
||||||
current_neutral_response: Annotated[
|
current_neutral_response: Annotated[
|
||||||
str, "Latest response by the neutral analyst"
|
str, "Latest response by the neutral analyst",
|
||||||
] # Last response
|
] # Last response
|
||||||
judge_decision: Annotated[str, "Judge's decision"]
|
judge_decision: Annotated[str, "Judge's decision"]
|
||||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||||
|
|
@ -54,13 +56,13 @@ class AgentState(MessagesState):
|
||||||
market_report: Annotated[str, "Report from the Market Analyst"]
|
market_report: Annotated[str, "Report from the Market Analyst"]
|
||||||
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
||||||
news_report: Annotated[
|
news_report: Annotated[
|
||||||
str, "Report from the News Researcher of current world affairs"
|
str, "Report from the News Researcher of current world affairs",
|
||||||
]
|
]
|
||||||
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
||||||
|
|
||||||
# researcher team discussion step
|
# researcher team discussion step
|
||||||
investment_debate_state: Annotated[
|
investment_debate_state: Annotated[
|
||||||
InvestDebateState, "Current state of the debate on if to invest or not"
|
InvestDebateState, "Current state of the debate on if to invest or not",
|
||||||
]
|
]
|
||||||
investment_plan: Annotated[str, "Plan generated by the Analyst"]
|
investment_plan: Annotated[str, "Plan generated by the Analyst"]
|
||||||
|
|
||||||
|
|
@ -68,6 +70,6 @@ class AgentState(MessagesState):
|
||||||
|
|
||||||
# risk management team discussion step
|
# risk management team discussion step
|
||||||
risk_debate_state: Annotated[
|
risk_debate_state: Annotated[
|
||||||
RiskDebateState, "Current state of the debate on evaluating risk"
|
RiskDebateState, "Current state of the debate on evaluating risk",
|
||||||
]
|
]
|
||||||
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from typing import Annotated
|
|
||||||
from langchain_core.messages import RemoveMessage
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import tradingagents.dataflows.interface as interface
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from tradingagents.dataflows import interface
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,7 +19,7 @@ def create_msg_delete():
|
||||||
# Add a minimal placeholder message
|
# Add a minimal placeholder message
|
||||||
placeholder = HumanMessage(content="Continue")
|
placeholder = HumanMessage(content="Continue")
|
||||||
|
|
||||||
return {"messages": removal_operations + [placeholder]}
|
return {"messages": [*removal_operations, placeholder]}
|
||||||
|
|
||||||
return delete_messages
|
return delete_messages
|
||||||
|
|
||||||
|
|
@ -53,9 +54,8 @@ class Toolkit:
|
||||||
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
|
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
|
return interface.get_reddit_global_news(curr_date, 7, 5)
|
||||||
|
|
||||||
return global_news_result
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -83,11 +83,10 @@ class Toolkit:
|
||||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
look_back_days = (end_date - start_date).days
|
look_back_days = (end_date - start_date).days
|
||||||
|
|
||||||
finnhub_news_result = interface.get_finnhub_news(
|
return interface.get_finnhub_news(
|
||||||
ticker, end_date_str, look_back_days
|
ticker, end_date_str, look_back_days,
|
||||||
)
|
)
|
||||||
|
|
||||||
return finnhub_news_result
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -107,9 +106,8 @@ class Toolkit:
|
||||||
str: A formatted dataframe containing the latest news about the company on the given date
|
str: A formatted dataframe containing the latest news about the company on the given date
|
||||||
"""
|
"""
|
||||||
|
|
||||||
stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)
|
return interface.get_reddit_company_news(ticker, curr_date, 7, 5)
|
||||||
|
|
||||||
return stock_news_results
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -128,9 +126,8 @@ class Toolkit:
|
||||||
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result_data = interface.get_YFin_data(symbol, start_date, end_date)
|
return interface.get_YFin_data(symbol, start_date, end_date)
|
||||||
|
|
||||||
return result_data
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -149,19 +146,18 @@ class Toolkit:
|
||||||
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result_data = interface.get_YFin_data_online(symbol, start_date, end_date)
|
return interface.get_YFin_data_online(symbol, start_date, end_date)
|
||||||
|
|
||||||
return result_data
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_stockstats_indicators_report(
|
def get_stockstats_indicators_report(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
indicator: Annotated[
|
indicator: Annotated[
|
||||||
str, "technical indicator to get the analysis and report of"
|
str, "technical indicator to get the analysis and report of",
|
||||||
],
|
],
|
||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
str, "The current trading date you are trading on, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
@ -176,21 +172,20 @@ class Toolkit:
|
||||||
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
|
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result_stockstats = interface.get_stock_stats_indicators_window(
|
return interface.get_stock_stats_indicators_window(
|
||||||
symbol, indicator, curr_date, look_back_days, False
|
symbol, indicator, curr_date, look_back_days, False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result_stockstats
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_stockstats_indicators_report_online(
|
def get_stockstats_indicators_report_online(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
indicator: Annotated[
|
indicator: Annotated[
|
||||||
str, "technical indicator to get the analysis and report of"
|
str, "technical indicator to get the analysis and report of",
|
||||||
],
|
],
|
||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
str, "The current trading date you are trading on, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
@ -205,11 +200,10 @@ class Toolkit:
|
||||||
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
|
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result_stockstats = interface.get_stock_stats_indicators_window(
|
return interface.get_stock_stats_indicators_window(
|
||||||
symbol, indicator, curr_date, look_back_days, True
|
symbol, indicator, curr_date, look_back_days, True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result_stockstats
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -229,11 +223,10 @@ class Toolkit:
|
||||||
str: a report of the sentiment in the past 30 days starting at curr_date
|
str: a report of the sentiment in the past 30 days starting at curr_date
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data_sentiment = interface.get_finnhub_company_insider_sentiment(
|
return interface.get_finnhub_company_insider_sentiment(
|
||||||
ticker, curr_date, 30
|
ticker, curr_date, 30,
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_sentiment
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -253,11 +246,10 @@ class Toolkit:
|
||||||
str: a report of the company's insider transactions/trading information in the past 30 days
|
str: a report of the company's insider transactions/trading information in the past 30 days
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data_trans = interface.get_finnhub_company_insider_transactions(
|
return interface.get_finnhub_company_insider_transactions(
|
||||||
ticker, curr_date, 30
|
ticker, curr_date, 30,
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_trans
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -279,9 +271,8 @@ class Toolkit:
|
||||||
str: a report of the company's most recent balance sheet
|
str: a report of the company's most recent balance sheet
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date)
|
return interface.get_simfin_balance_sheet(ticker, freq, curr_date)
|
||||||
|
|
||||||
return data_balance_sheet
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -303,9 +294,8 @@ class Toolkit:
|
||||||
str: a report of the company's most recent cash flow statement
|
str: a report of the company's most recent cash flow statement
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date)
|
return interface.get_simfin_cashflow(ticker, freq, curr_date)
|
||||||
|
|
||||||
return data_cashflow
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -327,11 +317,10 @@ class Toolkit:
|
||||||
str: a report of the company's most recent income statement
|
str: a report of the company's most recent income statement
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data_income_stmt = interface.get_simfin_income_statements(
|
return interface.get_simfin_income_statements(
|
||||||
ticker, freq, curr_date
|
ticker, freq, curr_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_income_stmt
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -349,9 +338,8 @@ class Toolkit:
|
||||||
str: A formatted string containing the latest news from Google News based on the query and date range.
|
str: A formatted string containing the latest news from Google News based on the query and date range.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
google_news_results = interface.get_google_news(query, curr_date, 7)
|
return interface.get_google_news(query, curr_date, 7)
|
||||||
|
|
||||||
return google_news_results
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -368,9 +356,8 @@ class Toolkit:
|
||||||
str: A formatted string containing the latest news about the company on the given date.
|
str: A formatted string containing the latest news about the company on the given date.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
openai_news_results = interface.get_stock_news_openai(ticker, curr_date)
|
return interface.get_stock_news_openai(ticker, curr_date)
|
||||||
|
|
||||||
return openai_news_results
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -385,9 +372,8 @@ class Toolkit:
|
||||||
str: A formatted string containing the latest macroeconomic news on the given date.
|
str: A formatted string containing the latest macroeconomic news on the given date.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
openai_news_results = interface.get_global_news_openai(curr_date)
|
return interface.get_global_news_openai(curr_date)
|
||||||
|
|
||||||
return openai_news_results
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -404,8 +390,7 @@ class Toolkit:
|
||||||
str: A formatted string containing the latest fundamental information about the company on the given date.
|
str: A formatted string containing the latest fundamental information about the company on the given date.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
openai_fundamentals_results = interface.get_fundamentals_openai(
|
return interface.get_fundamentals_openai(
|
||||||
ticker, curr_date
|
ticker, curr_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
return openai_fundamentals_results
|
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ class FinancialSituationMemory:
|
||||||
"matched_situation": results["documents"][0][i],
|
"matched_situation": results["documents"][0][i],
|
||||||
"recommendation": results["metadatas"][0][i]["recommendation"],
|
"recommendation": results["metadatas"][0][i]["recommendation"],
|
||||||
"similarity_score": 1 - results["distances"][0][i],
|
"similarity_score": 1 - results["distances"][0][i],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return matched_results
|
return matched_results
|
||||||
|
|
@ -94,18 +94,15 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Example query
|
# Example query
|
||||||
current_situation = """
|
current_situation = """
|
||||||
Market showing increased volatility in tech sector, with institutional investors
|
Market showing increased volatility in tech sector, with institutional investors
|
||||||
reducing positions and rising interest rates affecting growth stock valuations
|
reducing positions and rising interest rates affecting growth stock valuations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
||||||
|
|
||||||
for i, rec in enumerate(recommendations, 1):
|
for _i, _rec in enumerate(recommendations, 1):
|
||||||
print(f"\nMatch {i}:")
|
pass
|
||||||
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
|
||||||
print(f"Matched Situation: {rec['matched_situation']}")
|
|
||||||
print(f"Recommendation: {rec['recommendation']}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"Error during recommendation: {str(e)}")
|
pass
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ Loads configuration from environment variables and .env file.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Load .env file from project root
|
# Load .env file from project root
|
||||||
|
|
@ -20,10 +21,10 @@ def get_config():
|
||||||
"project_dir": str(project_root / "tradingagents"),
|
"project_dir": str(project_root / "tradingagents"),
|
||||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||||
"data_dir": os.getenv(
|
"data_dir": os.getenv(
|
||||||
"TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data"
|
"TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
||||||
),
|
),
|
||||||
"data_cache_dir": str(
|
"data_cache_dir": str(
|
||||||
project_root / "tradingagents" / "dataflows" / "data_cache"
|
project_root / "tradingagents" / "dataflows" / "data_cache",
|
||||||
),
|
),
|
||||||
# LLM settings
|
# LLM settings
|
||||||
"llm_provider": os.getenv("LLM_PROVIDER", "openai"),
|
"llm_provider": os.getenv("LLM_PROVIDER", "openai"),
|
||||||
|
|
@ -47,16 +48,17 @@ def get_config():
|
||||||
|
|
||||||
# Validate required API keys based on provider
|
# Validate required API keys based on provider
|
||||||
if config["llm_provider"] == "openai" and not config["openai_api_key"]:
|
if config["llm_provider"] == "openai" and not config["openai_api_key"]:
|
||||||
raise ValueError("OPENAI_API_KEY is required when using OpenAI provider")
|
msg = "OPENAI_API_KEY is required when using OpenAI provider"
|
||||||
elif config["llm_provider"] == "anthropic" and not config["anthropic_api_key"]:
|
raise ValueError(msg)
|
||||||
raise ValueError("ANTHROPIC_API_KEY is required when using Anthropic provider")
|
if config["llm_provider"] == "anthropic" and not config["anthropic_api_key"]:
|
||||||
elif config["llm_provider"] == "google" and not config["google_api_key"]:
|
msg = "ANTHROPIC_API_KEY is required when using Anthropic provider"
|
||||||
raise ValueError("GOOGLE_API_KEY is required when using Google provider")
|
raise ValueError(msg)
|
||||||
|
if config["llm_provider"] == "google" and not config["google_api_key"]:
|
||||||
|
msg = "GOOGLE_API_KEY is required when using Google provider"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
if not config["finnhub_api_key"]:
|
if not config["finnhub_api_key"]:
|
||||||
print(
|
pass
|
||||||
"Warning: FINNHUB_API_KEY not set. Some financial data features may be limited."
|
|
||||||
)
|
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,13 @@
|
||||||
from .finnhub_utils import get_data_in_range
|
from .finnhub_utils import get_data_in_range
|
||||||
from .googlenews_utils import getNewsData
|
from .googlenews_utils import getNewsData
|
||||||
from .yfin_utils import YFinanceUtils
|
|
||||||
from .reddit_utils import fetch_top_from_category
|
|
||||||
from .stockstats_utils import StockstatsUtils
|
|
||||||
|
|
||||||
from .interface import (
|
from .interface import (
|
||||||
# News and sentiment functions
|
|
||||||
get_finnhub_news,
|
|
||||||
get_finnhub_company_insider_sentiment,
|
get_finnhub_company_insider_sentiment,
|
||||||
get_finnhub_company_insider_transactions,
|
get_finnhub_company_insider_transactions,
|
||||||
|
# News and sentiment functions
|
||||||
|
get_finnhub_news,
|
||||||
get_google_news,
|
get_google_news,
|
||||||
get_reddit_global_news,
|
|
||||||
get_reddit_company_news,
|
get_reddit_company_news,
|
||||||
|
get_reddit_global_news,
|
||||||
# Financial statements functions
|
# Financial statements functions
|
||||||
get_simfin_balance_sheet,
|
get_simfin_balance_sheet,
|
||||||
get_simfin_cashflow,
|
get_simfin_cashflow,
|
||||||
|
|
@ -19,19 +15,25 @@ from .interface import (
|
||||||
# Technical analysis functions
|
# Technical analysis functions
|
||||||
get_stock_stats_indicators_window,
|
get_stock_stats_indicators_window,
|
||||||
get_stockstats_indicator,
|
get_stockstats_indicator,
|
||||||
|
get_YFin_data,
|
||||||
# Market data functions
|
# Market data functions
|
||||||
get_YFin_data_window,
|
get_YFin_data_window,
|
||||||
get_YFin_data,
|
|
||||||
)
|
)
|
||||||
|
from .reddit_utils import fetch_top_from_category
|
||||||
|
from .stockstats_utils import StockstatsUtils
|
||||||
|
from .yfin_utils import YFinanceUtils
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# News and sentiment functions
|
"get_YFin_data",
|
||||||
"get_finnhub_news",
|
# Market data functions
|
||||||
|
"get_YFin_data_window",
|
||||||
"get_finnhub_company_insider_sentiment",
|
"get_finnhub_company_insider_sentiment",
|
||||||
"get_finnhub_company_insider_transactions",
|
"get_finnhub_company_insider_transactions",
|
||||||
|
# News and sentiment functions
|
||||||
|
"get_finnhub_news",
|
||||||
"get_google_news",
|
"get_google_news",
|
||||||
"get_reddit_global_news",
|
|
||||||
"get_reddit_company_news",
|
"get_reddit_company_news",
|
||||||
|
"get_reddit_global_news",
|
||||||
# Financial statements functions
|
# Financial statements functions
|
||||||
"get_simfin_balance_sheet",
|
"get_simfin_balance_sheet",
|
||||||
"get_simfin_cashflow",
|
"get_simfin_cashflow",
|
||||||
|
|
@ -39,7 +41,10 @@ __all__ = [
|
||||||
# Technical analysis functions
|
# Technical analysis functions
|
||||||
"get_stock_stats_indicators_window",
|
"get_stock_stats_indicators_window",
|
||||||
"get_stockstats_indicator",
|
"get_stockstats_indicator",
|
||||||
# Market data functions
|
# Utilities and classes
|
||||||
"get_YFin_data_window",
|
"get_data_in_range",
|
||||||
"get_YFin_data",
|
"getNewsData",
|
||||||
|
"YFinanceUtils",
|
||||||
|
"fetch_top_from_category",
|
||||||
|
"StockstatsUtils",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import tradingagents.default_config as default_config
|
|
||||||
from typing import Dict, Optional
|
from tradingagents import default_config
|
||||||
|
|
||||||
# Use default config but allow it to be overridden
|
# Use default config but allow it to be overridden
|
||||||
_config: Optional[Dict] = None
|
_config: dict | None = None
|
||||||
DATA_DIR: Optional[str] = None
|
DATA_DIR: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def initialize_config():
|
def initialize_config():
|
||||||
|
|
@ -14,7 +14,7 @@ def initialize_config():
|
||||||
DATA_DIR = _config["data_dir"]
|
DATA_DIR = _config["data_dir"]
|
||||||
|
|
||||||
|
|
||||||
def set_config(config: Dict):
|
def set_config(config: dict):
|
||||||
"""Update the configuration with custom values."""
|
"""Update the configuration with custom values."""
|
||||||
global _config, DATA_DIR
|
global _config, DATA_DIR
|
||||||
if _config is None:
|
if _config is None:
|
||||||
|
|
@ -23,7 +23,7 @@ def set_config(config: Dict):
|
||||||
DATA_DIR = _config["data_dir"]
|
DATA_DIR = _config["data_dir"]
|
||||||
|
|
||||||
|
|
||||||
def get_config() -> Dict:
|
def get_config() -> dict:
|
||||||
"""Get the current configuration."""
|
"""Get the current configuration."""
|
||||||
if _config is None:
|
if _config is None:
|
||||||
initialize_config()
|
initialize_config()
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,10 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_path = os.path.join(
|
data_path = os.path.join(
|
||||||
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
|
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
|
||||||
)
|
)
|
||||||
|
|
||||||
data = open(data_path, "r")
|
data = open(data_path)
|
||||||
data = json.load(data)
|
data = json.load(data)
|
||||||
|
|
||||||
# filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD)
|
# filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD)
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
import requests
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
from datetime import datetime
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
import requests
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
|
retry_if_result,
|
||||||
stop_after_attempt,
|
stop_after_attempt,
|
||||||
wait_exponential,
|
wait_exponential,
|
||||||
retry_if_result,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def is_rate_limited(response):
|
def is_rate_limited(response):
|
||||||
"""Check if the response indicates we should back off (rate-limited or temporarily unavailable)."""
|
"""Check if the response indicates we should back off (rate-limited or temporarily unavailable)."""
|
||||||
|
|
@ -35,8 +35,7 @@ def _add_jitter(retry_state):
|
||||||
def make_request(url, headers):
|
def make_request(url, headers):
|
||||||
"""Make a request with retry logic for rate limiting"""
|
"""Make a request with retry logic for rate limiting"""
|
||||||
# The retry decorator already applies exponential backoff with jitter
|
# The retry decorator already applies exponential backoff with jitter
|
||||||
response = requests.get(url, headers=headers, timeout=(5, 20))
|
return requests.get(url, headers=headers, timeout=(5, 20))
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def getNewsData(query, start_date, end_date):
|
def getNewsData(query, start_date, end_date):
|
||||||
|
|
@ -58,7 +57,7 @@ def getNewsData(query, start_date, end_date):
|
||||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
"Chrome/101.0.4951.54 Safari/537.36"
|
"Chrome/101.0.4951.54 Safari/537.36"
|
||||||
)
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
news_results = []
|
news_results = []
|
||||||
|
|
@ -103,7 +102,7 @@ def getNewsData(query, start_date, end_date):
|
||||||
"source": (
|
"source": (
|
||||||
source_el.get_text(strip=True) if source_el else ""
|
source_el.get_text(strip=True) if source_el else ""
|
||||||
),
|
),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Error processing result: %s", e)
|
logger.warning("Error processing result: %s", e)
|
||||||
|
|
@ -120,7 +119,7 @@ def getNewsData(query, start_date, end_date):
|
||||||
page += 1
|
page += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed after multiple retries: %s", e)
|
logger.exception("Failed after multiple retries: %s", e)
|
||||||
break
|
break
|
||||||
|
|
||||||
return news_results
|
return news_results
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,18 @@
|
||||||
from typing import Annotated
|
|
||||||
from .reddit_utils import fetch_top_from_category
|
|
||||||
from .yfin_utils import *
|
|
||||||
from .stockstats_utils import *
|
|
||||||
from .googlenews_utils import *
|
|
||||||
from .finnhub_utils import get_data_in_range
|
|
||||||
from dateutil.relativedelta import relativedelta
|
|
||||||
from datetime import datetime
|
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
|
from dateutil.relativedelta import relativedelta
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from .config import get_config, DATA_DIR
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .config import DATA_DIR, get_config
|
||||||
|
from .finnhub_utils import get_data_in_range
|
||||||
|
from .googlenews_utils import getNewsData
|
||||||
|
from .reddit_utils import fetch_top_from_category
|
||||||
|
from .stockstats_utils import StockstatsUtils
|
||||||
|
|
||||||
|
|
||||||
def get_finnhub_news(
|
def get_finnhub_news(
|
||||||
|
|
@ -84,7 +85,7 @@ def get_finnhub_company_insider_sentiment(
|
||||||
|
|
||||||
result_str = ""
|
result_str = ""
|
||||||
seen_dicts = []
|
seen_dicts = []
|
||||||
for date, senti_list in data.items():
|
for senti_list in data.values():
|
||||||
for entry in senti_list:
|
for entry in senti_list:
|
||||||
if entry not in seen_dicts:
|
if entry not in seen_dicts:
|
||||||
result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n"
|
result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n"
|
||||||
|
|
@ -126,7 +127,7 @@ def get_finnhub_company_insider_transactions(
|
||||||
result_str = ""
|
result_str = ""
|
||||||
|
|
||||||
seen_dicts = []
|
seen_dicts = []
|
||||||
for date, senti_list in data.items():
|
for senti_list in data.values():
|
||||||
for entry in senti_list:
|
for entry in senti_list:
|
||||||
if entry not in seen_dicts:
|
if entry not in seen_dicts:
|
||||||
result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n"
|
result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n"
|
||||||
|
|
@ -170,7 +171,6 @@ def get_simfin_balance_sheet(
|
||||||
|
|
||||||
# Check if there are any available reports; if not, return a notification
|
# Check if there are any available reports; if not, return a notification
|
||||||
if filtered_df.empty:
|
if filtered_df.empty:
|
||||||
print("No balance sheet available before the given current date.")
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Get the most recent balance sheet by selecting the row with the latest Publish Date
|
# Get the most recent balance sheet by selecting the row with the latest Publish Date
|
||||||
|
|
@ -217,7 +217,6 @@ def get_simfin_cashflow(
|
||||||
|
|
||||||
# Check if there are any available reports; if not, return a notification
|
# Check if there are any available reports; if not, return a notification
|
||||||
if filtered_df.empty:
|
if filtered_df.empty:
|
||||||
print("No cash flow statement available before the given current date.")
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Get the most recent cash flow statement by selecting the row with the latest Publish Date
|
# Get the most recent cash flow statement by selecting the row with the latest Publish Date
|
||||||
|
|
@ -264,7 +263,6 @@ def get_simfin_income_statements(
|
||||||
|
|
||||||
# Check if there are any available reports; if not, return a notification
|
# Check if there are any available reports; if not, return a notification
|
||||||
if filtered_df.empty:
|
if filtered_df.empty:
|
||||||
print("No income statement available before the given current date.")
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Get the most recent income statement by selecting the row with the latest Publish Date
|
# Get the most recent income statement by selecting the row with the latest Publish Date
|
||||||
|
|
@ -421,7 +419,7 @@ def get_stock_stats_indicators_window(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
str, "The current trading date you are trading on, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
look_back_days: Annotated[int, "how many days to look back"],
|
||||||
online: Annotated[bool, "to fetch data online or offline"],
|
online: Annotated[bool, "to fetch data online or offline"],
|
||||||
|
|
@ -501,8 +499,9 @@ def get_stock_stats_indicators_window(
|
||||||
}
|
}
|
||||||
|
|
||||||
if indicator not in best_ind_params:
|
if indicator not in best_ind_params:
|
||||||
|
msg = f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}"
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}"
|
msg,
|
||||||
)
|
)
|
||||||
|
|
||||||
end_date = curr_date
|
end_date = curr_date
|
||||||
|
|
@ -515,7 +514,7 @@ def get_stock_stats_indicators_window(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
DATA_DIR,
|
DATA_DIR,
|
||||||
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
data["Date"] = pd.to_datetime(data["Date"], utc=True)
|
data["Date"] = pd.to_datetime(data["Date"], utc=True)
|
||||||
dates_in_df = data["Date"].astype(str).str[:10]
|
dates_in_df = data["Date"].astype(str).str[:10]
|
||||||
|
|
@ -525,7 +524,7 @@ def get_stock_stats_indicators_window(
|
||||||
# only do the trading dates
|
# only do the trading dates
|
||||||
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
|
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
|
||||||
indicator_value = get_stockstats_indicator(
|
indicator_value = get_stockstats_indicator(
|
||||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
|
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
|
||||||
)
|
)
|
||||||
|
|
||||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||||
|
|
@ -536,28 +535,27 @@ def get_stock_stats_indicators_window(
|
||||||
ind_string = ""
|
ind_string = ""
|
||||||
while curr_date >= before:
|
while curr_date >= before:
|
||||||
indicator_value = get_stockstats_indicator(
|
indicator_value = get_stockstats_indicator(
|
||||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
|
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
|
||||||
)
|
)
|
||||||
|
|
||||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||||
|
|
||||||
curr_date = curr_date - relativedelta(days=1)
|
curr_date = curr_date - relativedelta(days=1)
|
||||||
|
|
||||||
result_str = (
|
return (
|
||||||
f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n"
|
f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n"
|
||||||
+ ind_string
|
+ ind_string
|
||||||
+ "\n\n"
|
+ "\n\n"
|
||||||
+ best_ind_params.get(indicator, "No description available.")
|
+ best_ind_params.get(indicator, "No description available.")
|
||||||
)
|
)
|
||||||
|
|
||||||
return result_str
|
|
||||||
|
|
||||||
|
|
||||||
def get_stockstats_indicator(
|
def get_stockstats_indicator(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
str, "The current trading date you are trading on, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
online: Annotated[bool, "to fetch data online or offline"],
|
online: Annotated[bool, "to fetch data online or offline"],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
@ -573,10 +571,7 @@ def get_stockstats_indicator(
|
||||||
os.path.join(DATA_DIR, "market_data", "price_data"),
|
os.path.join(DATA_DIR, "market_data", "price_data"),
|
||||||
online=online,
|
online=online,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(
|
|
||||||
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return str(indicator_value)
|
return str(indicator_value)
|
||||||
|
|
@ -597,7 +592,7 @@ def get_YFin_data_window(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
DATA_DIR,
|
DATA_DIR,
|
||||||
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract just the date part for comparison
|
# Extract just the date part for comparison
|
||||||
|
|
@ -613,7 +608,7 @@ def get_YFin_data_window(
|
||||||
|
|
||||||
# Set pandas display options to show the full DataFrame
|
# Set pandas display options to show the full DataFrame
|
||||||
with pd.option_context(
|
with pd.option_context(
|
||||||
"display.max_rows", None, "display.max_columns", None, "display.width", None
|
"display.max_rows", None, "display.max_columns", None, "display.width", None,
|
||||||
):
|
):
|
||||||
df_string = filtered_data.to_string()
|
df_string = filtered_data.to_string()
|
||||||
|
|
||||||
|
|
@ -675,12 +670,13 @@ def get_YFin_data(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
DATA_DIR,
|
DATA_DIR,
|
||||||
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if end_date > "2025-03-25":
|
if end_date > "2025-03-25":
|
||||||
|
msg = f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
|
msg,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract just the date part for comparison
|
# Extract just the date part for comparison
|
||||||
|
|
@ -695,9 +691,8 @@ def get_YFin_data(
|
||||||
filtered_data = filtered_data.drop("DateOnly", axis=1)
|
filtered_data = filtered_data.drop("DateOnly", axis=1)
|
||||||
|
|
||||||
# remove the index from the dataframe
|
# remove the index from the dataframe
|
||||||
filtered_data = filtered_data.reset_index(drop=True)
|
return filtered_data.reset_index(drop=True)
|
||||||
|
|
||||||
return filtered_data
|
|
||||||
|
|
||||||
|
|
||||||
def get_stock_news_openai(ticker, curr_date):
|
def get_stock_news_openai(ticker, curr_date):
|
||||||
|
|
@ -713,9 +708,9 @@ def get_stock_news_openai(ticker, curr_date):
|
||||||
{
|
{
|
||||||
"type": "input_text",
|
"type": "input_text",
|
||||||
"text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
|
"text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
text={"format": {"type": "text"}},
|
text={"format": {"type": "text"}},
|
||||||
reasoning={},
|
reasoning={},
|
||||||
|
|
@ -724,7 +719,7 @@ def get_stock_news_openai(ticker, curr_date):
|
||||||
"type": "web_search_preview",
|
"type": "web_search_preview",
|
||||||
"user_location": {"type": "approximate"},
|
"user_location": {"type": "approximate"},
|
||||||
"search_context_size": "low",
|
"search_context_size": "low",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
temperature=1,
|
temperature=1,
|
||||||
max_output_tokens=4096,
|
max_output_tokens=4096,
|
||||||
|
|
@ -748,9 +743,9 @@ def get_global_news_openai(curr_date):
|
||||||
{
|
{
|
||||||
"type": "input_text",
|
"type": "input_text",
|
||||||
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
|
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
text={"format": {"type": "text"}},
|
text={"format": {"type": "text"}},
|
||||||
reasoning={},
|
reasoning={},
|
||||||
|
|
@ -759,7 +754,7 @@ def get_global_news_openai(curr_date):
|
||||||
"type": "web_search_preview",
|
"type": "web_search_preview",
|
||||||
"user_location": {"type": "approximate"},
|
"user_location": {"type": "approximate"},
|
||||||
"search_context_size": "low",
|
"search_context_size": "low",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
temperature=1,
|
temperature=1,
|
||||||
max_output_tokens=4096,
|
max_output_tokens=4096,
|
||||||
|
|
@ -783,9 +778,9 @@ def get_fundamentals_openai(ticker, curr_date):
|
||||||
{
|
{
|
||||||
"type": "input_text",
|
"type": "input_text",
|
||||||
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
|
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
text={"format": {"type": "text"}},
|
text={"format": {"type": "text"}},
|
||||||
reasoning={},
|
reasoning={},
|
||||||
|
|
@ -794,7 +789,7 @@ def get_fundamentals_openai(ticker, curr_date):
|
||||||
"type": "web_search_preview",
|
"type": "web_search_preview",
|
||||||
"user_location": {"type": "approximate"},
|
"user_location": {"type": "approximate"},
|
||||||
"search_context_size": "low",
|
"search_context_size": "low",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
temperature=1,
|
temperature=1,
|
||||||
max_output_tokens=4096,
|
max_output_tokens=4096,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
|
||||||
from typing import Annotated
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
ticker_to_company = {
|
ticker_to_company = {
|
||||||
"AAPL": "Apple",
|
"AAPL": "Apple",
|
||||||
|
|
@ -48,11 +48,11 @@ ticker_to_company = {
|
||||||
|
|
||||||
def fetch_top_from_category(
|
def fetch_top_from_category(
|
||||||
category: Annotated[
|
category: Annotated[
|
||||||
str, "Category to fetch top post from. Collection of subreddits."
|
str, "Category to fetch top post from. Collection of subreddits.",
|
||||||
],
|
],
|
||||||
date: Annotated[str, "Date to fetch top posts from."],
|
date: Annotated[str, "Date to fetch top posts from."],
|
||||||
max_limit: Annotated[int, "Maximum number of posts to fetch."],
|
max_limit: Annotated[int, "Maximum number of posts to fetch."],
|
||||||
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
|
query: Annotated[str | None, "Optional query to search for in the subreddit."] = None,
|
||||||
data_path: Annotated[
|
data_path: Annotated[
|
||||||
str,
|
str,
|
||||||
"Path to the data folder. Default is 'reddit_data'.",
|
"Path to the data folder. Default is 'reddit_data'.",
|
||||||
|
|
@ -63,12 +63,13 @@ def fetch_top_from_category(
|
||||||
all_content = []
|
all_content = []
|
||||||
|
|
||||||
if max_limit < len(os.listdir(os.path.join(base_path, category))):
|
if max_limit < len(os.listdir(os.path.join(base_path, category))):
|
||||||
|
msg = "REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
|
msg,
|
||||||
)
|
)
|
||||||
|
|
||||||
limit_per_subreddit = max_limit // len(
|
limit_per_subreddit = max_limit // len(
|
||||||
os.listdir(os.path.join(base_path, category))
|
os.listdir(os.path.join(base_path, category)),
|
||||||
)
|
)
|
||||||
|
|
||||||
for data_file in os.listdir(os.path.join(base_path, category)):
|
for data_file in os.listdir(os.path.join(base_path, category)):
|
||||||
|
|
@ -79,7 +80,7 @@ def fetch_top_from_category(
|
||||||
all_content_curr_subreddit = []
|
all_content_curr_subreddit = []
|
||||||
|
|
||||||
with open(os.path.join(base_path, category, data_file), "rb") as f:
|
with open(os.path.join(base_path, category, data_file), "rb") as f:
|
||||||
for i, line in enumerate(f):
|
for _i, line in enumerate(f):
|
||||||
# skip empty lines
|
# skip empty lines
|
||||||
if not line.strip():
|
if not line.strip():
|
||||||
continue
|
continue
|
||||||
|
|
@ -88,7 +89,7 @@ def fetch_top_from_category(
|
||||||
|
|
||||||
# select only lines that are from the date
|
# select only lines that are from the date
|
||||||
post_date = datetime.utcfromtimestamp(
|
post_date = datetime.utcfromtimestamp(
|
||||||
parsed_line["created_utc"]
|
parsed_line["created_utc"],
|
||||||
).strftime("%Y-%m-%d")
|
).strftime("%Y-%m-%d")
|
||||||
if post_date != date:
|
if post_date != date:
|
||||||
continue
|
continue
|
||||||
|
|
@ -106,7 +107,7 @@ def fetch_top_from_category(
|
||||||
found = False
|
found = False
|
||||||
for term in search_terms:
|
for term in search_terms:
|
||||||
if re.search(
|
if re.search(
|
||||||
term, parsed_line["title"], re.IGNORECASE
|
term, parsed_line["title"], re.IGNORECASE,
|
||||||
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
|
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
|
||||||
found = True
|
found = True
|
||||||
break
|
break
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
|
import os
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
from stockstats import wrap
|
from stockstats import wrap
|
||||||
from typing import Annotated
|
|
||||||
import os
|
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -11,10 +13,10 @@ class StockstatsUtils:
|
||||||
def get_stock_stats(
|
def get_stock_stats(
|
||||||
symbol: Annotated[str, "ticker symbol for the company"],
|
symbol: Annotated[str, "ticker symbol for the company"],
|
||||||
indicator: Annotated[
|
indicator: Annotated[
|
||||||
str, "quantitative indicators based off of the stock data for the company"
|
str, "quantitative indicators based off of the stock data for the company",
|
||||||
],
|
],
|
||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
str, "curr date for retrieving stock price data, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
data_dir: Annotated[
|
data_dir: Annotated[
|
||||||
str,
|
str,
|
||||||
|
|
@ -34,11 +36,12 @@ class StockstatsUtils:
|
||||||
os.path.join(
|
os.path.join(
|
||||||
data_dir,
|
data_dir,
|
||||||
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
df = wrap(data)
|
df = wrap(data)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
|
msg = "Stockstats fail: Yahoo Finance data not fetched yet!"
|
||||||
|
raise Exception(msg)
|
||||||
else:
|
else:
|
||||||
# Get today's date as YYYY-mm-dd to add to cache
|
# Get today's date as YYYY-mm-dd to add to cache
|
||||||
today_date = pd.Timestamp.today()
|
today_date = pd.Timestamp.today()
|
||||||
|
|
@ -81,7 +84,5 @@ class StockstatsUtils:
|
||||||
matching_rows = df[df["Date"].str.startswith(curr_date)]
|
matching_rows = df[df["Date"].str.startswith(curr_date)]
|
||||||
|
|
||||||
if not matching_rows.empty:
|
if not matching_rows.empty:
|
||||||
indicator_value = matching_rows[indicator].values[0]
|
return matching_rows[indicator].values[0]
|
||||||
return indicator_value
|
return "N/A: Not a trading day (weekend or holiday)"
|
||||||
else:
|
|
||||||
return "N/A: Not a trading day (weekend or holiday)"
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
import pandas as pd
|
from datetime import date, datetime, timedelta
|
||||||
from datetime import date, timedelta, datetime
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||||
|
|
||||||
|
|
||||||
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
||||||
if save_path:
|
if save_path:
|
||||||
data.to_csv(save_path)
|
data.to_csv(save_path)
|
||||||
print(f"{tag} saved to {save_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_date():
|
def get_current_date():
|
||||||
|
|
@ -32,7 +32,5 @@ def get_next_weekday(date):
|
||||||
|
|
||||||
if date.weekday() >= 5:
|
if date.weekday() >= 5:
|
||||||
days_to_add = 7 - date.weekday()
|
days_to_add = 7 - date.weekday()
|
||||||
next_weekday = date + timedelta(days=days_to_add)
|
return date + timedelta(days=days_to_add)
|
||||||
return next_weekday
|
return date
|
||||||
else:
|
|
||||||
return date
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
# gets data/stats
|
# gets data/stats
|
||||||
|
|
||||||
import yfinance as yf
|
from collections.abc import Callable
|
||||||
from typing import Annotated, Callable, Any, Optional
|
|
||||||
from pandas import DataFrame
|
|
||||||
import pandas as pd
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import yfinance as yf
|
||||||
|
from pandas import DataFrame
|
||||||
|
|
||||||
from .utils import SavePathType, decorate_all_methods
|
from .utils import SavePathType, decorate_all_methods
|
||||||
|
|
||||||
|
|
@ -24,38 +26,36 @@ def init_ticker(func: Callable) -> Callable:
|
||||||
class YFinanceUtils:
|
class YFinanceUtils:
|
||||||
|
|
||||||
def get_stock_data(
|
def get_stock_data(
|
||||||
symbol: Annotated[str, "ticker symbol"],
|
self: Annotated[str, "ticker symbol"],
|
||||||
start_date: Annotated[
|
start_date: Annotated[
|
||||||
str, "start date for retrieving stock price data, YYYY-mm-dd"
|
str, "start date for retrieving stock price data, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
end_date: Annotated[
|
end_date: Annotated[
|
||||||
str, "end date for retrieving stock price data, YYYY-mm-dd"
|
str, "end date for retrieving stock price data, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
save_path: SavePathType = None,
|
save_path: SavePathType = None,
|
||||||
) -> DataFrame:
|
) -> DataFrame:
|
||||||
"""retrieve stock price data for designated ticker symbol"""
|
"""retrieve stock price data for designated ticker symbol"""
|
||||||
ticker = symbol
|
ticker = self
|
||||||
# add one day to the end_date so that the data range is inclusive
|
# add one day to the end_date so that the data range is inclusive
|
||||||
end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
|
end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
|
||||||
end_date = end_date.strftime("%Y-%m-%d")
|
end_date = end_date.strftime("%Y-%m-%d")
|
||||||
stock_data = ticker.history(start=start_date, end=end_date)
|
return ticker.history(start=start_date, end=end_date)
|
||||||
# save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
|
# save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
|
||||||
return stock_data
|
|
||||||
|
|
||||||
def get_stock_info(
|
def get_stock_info(
|
||||||
symbol: Annotated[str, "ticker symbol"],
|
self: Annotated[str, "ticker symbol"],
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Fetches and returns latest stock information."""
|
"""Fetches and returns latest stock information."""
|
||||||
ticker = symbol
|
ticker = self
|
||||||
stock_info = ticker.info
|
return ticker.info
|
||||||
return stock_info
|
|
||||||
|
|
||||||
def get_company_info(
|
def get_company_info(
|
||||||
symbol: Annotated[str, "ticker symbol"],
|
self: Annotated[str, "ticker symbol"],
|
||||||
save_path: Optional[str] = None,
|
save_path: str | None = None,
|
||||||
) -> DataFrame:
|
) -> DataFrame:
|
||||||
"""Fetches and returns company information as a DataFrame."""
|
"""Fetches and returns company information as a DataFrame."""
|
||||||
ticker = symbol
|
ticker = self
|
||||||
info = ticker.info
|
info = ticker.info
|
||||||
company_info = {
|
company_info = {
|
||||||
"Company Name": info.get("shortName", "N/A"),
|
"Company Name": info.get("shortName", "N/A"),
|
||||||
|
|
@ -67,42 +67,37 @@ class YFinanceUtils:
|
||||||
company_info_df = DataFrame([company_info])
|
company_info_df = DataFrame([company_info])
|
||||||
if save_path:
|
if save_path:
|
||||||
company_info_df.to_csv(save_path)
|
company_info_df.to_csv(save_path)
|
||||||
print(f"Company info for {ticker.ticker} saved to {save_path}")
|
|
||||||
return company_info_df
|
return company_info_df
|
||||||
|
|
||||||
def get_stock_dividends(
|
def get_stock_dividends(
|
||||||
symbol: Annotated[str, "ticker symbol"],
|
self: Annotated[str, "ticker symbol"],
|
||||||
save_path: Optional[str] = None,
|
save_path: str | None = None,
|
||||||
) -> DataFrame:
|
) -> DataFrame:
|
||||||
"""Fetches and returns the latest dividends data as a DataFrame."""
|
"""Fetches and returns the latest dividends data as a DataFrame."""
|
||||||
ticker = symbol
|
ticker = self
|
||||||
dividends = ticker.dividends
|
dividends = ticker.dividends
|
||||||
if save_path:
|
if save_path:
|
||||||
dividends.to_csv(save_path)
|
dividends.to_csv(save_path)
|
||||||
print(f"Dividends for {ticker.ticker} saved to {save_path}")
|
|
||||||
return dividends
|
return dividends
|
||||||
|
|
||||||
def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
def get_income_stmt(self: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||||
"""Fetches and returns the latest income statement of the company as a DataFrame."""
|
"""Fetches and returns the latest income statement of the company as a DataFrame."""
|
||||||
ticker = symbol
|
ticker = self
|
||||||
income_stmt = ticker.financials
|
return ticker.financials
|
||||||
return income_stmt
|
|
||||||
|
|
||||||
def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
def get_balance_sheet(self: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||||
"""Fetches and returns the latest balance sheet of the company as a DataFrame."""
|
"""Fetches and returns the latest balance sheet of the company as a DataFrame."""
|
||||||
ticker = symbol
|
ticker = self
|
||||||
balance_sheet = ticker.balance_sheet
|
return ticker.balance_sheet
|
||||||
return balance_sheet
|
|
||||||
|
|
||||||
def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
def get_cash_flow(self: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||||
"""Fetches and returns the latest cash flow statement of the company as a DataFrame."""
|
"""Fetches and returns the latest cash flow statement of the company as a DataFrame."""
|
||||||
ticker = symbol
|
ticker = self
|
||||||
cash_flow = ticker.cashflow
|
return ticker.cashflow
|
||||||
return cash_flow
|
|
||||||
|
|
||||||
def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple:
|
def get_analyst_recommendations(self: Annotated[str, "ticker symbol"]) -> tuple:
|
||||||
"""Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
|
"""Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
|
||||||
ticker = symbol
|
ticker = self
|
||||||
recommendations = ticker.recommendations
|
recommendations = ticker.recommendations
|
||||||
if recommendations.empty:
|
if recommendations.empty:
|
||||||
return None, 0 # No recommendations available
|
return None, 0 # No recommendations available
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,17 @@
|
||||||
# TradingAgents/graph/__init__.py
|
# TradingAgents/graph/__init__.py
|
||||||
|
|
||||||
from .trading_graph import TradingAgentsGraph
|
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
from .setup import GraphSetup
|
|
||||||
from .propagation import Propagator
|
from .propagation import Propagator
|
||||||
from .reflection import Reflector
|
from .reflection import Reflector
|
||||||
|
from .setup import GraphSetup
|
||||||
from .signal_processing import SignalProcessor
|
from .signal_processing import SignalProcessor
|
||||||
|
from .trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TradingAgentsGraph",
|
|
||||||
"ConditionalLogic",
|
"ConditionalLogic",
|
||||||
"GraphSetup",
|
"GraphSetup",
|
||||||
"Propagator",
|
"Propagator",
|
||||||
"Reflector",
|
"Reflector",
|
||||||
"SignalProcessor",
|
"SignalProcessor",
|
||||||
|
"TradingAgentsGraph",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# TradingAgents/graph/propagation.py
|
# TradingAgents/graph/propagation.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
|
|
||||||
from tradingagents.agents.utils.agent_states import (
|
from tradingagents.agents.utils.agent_states import (
|
||||||
InvestDebateState,
|
InvestDebateState,
|
||||||
RiskDebateState,
|
RiskDebateState,
|
||||||
|
|
@ -15,15 +16,15 @@ class Propagator:
|
||||||
self.max_recur_limit = max_recur_limit
|
self.max_recur_limit = max_recur_limit
|
||||||
|
|
||||||
def create_initial_state(
|
def create_initial_state(
|
||||||
self, company_name: str, trade_date: str
|
self, company_name: str, trade_date: str,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Create the initial state for the agent graph."""
|
"""Create the initial state for the agent graph."""
|
||||||
return {
|
return {
|
||||||
"messages": [("human", company_name)],
|
"messages": [("human", company_name)],
|
||||||
"company_of_interest": company_name,
|
"company_of_interest": company_name,
|
||||||
"trade_date": str(trade_date),
|
"trade_date": str(trade_date),
|
||||||
"investment_debate_state": InvestDebateState(
|
"investment_debate_state": InvestDebateState(
|
||||||
{"history": "", "current_response": "", "count": 0}
|
{"history": "", "current_response": "", "count": 0},
|
||||||
),
|
),
|
||||||
"risk_debate_state": RiskDebateState(
|
"risk_debate_state": RiskDebateState(
|
||||||
{
|
{
|
||||||
|
|
@ -32,7 +33,7 @@ class Propagator:
|
||||||
"current_safe_response": "",
|
"current_safe_response": "",
|
||||||
"current_neutral_response": "",
|
"current_neutral_response": "",
|
||||||
"count": 0,
|
"count": 0,
|
||||||
}
|
},
|
||||||
),
|
),
|
||||||
"market_report": "",
|
"market_report": "",
|
||||||
"fundamentals_report": "",
|
"fundamentals_report": "",
|
||||||
|
|
@ -40,7 +41,7 @@ class Propagator:
|
||||||
"news_report": "",
|
"news_report": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_graph_args(self) -> Dict[str, Any]:
|
def get_graph_args(self) -> dict[str, Any]:
|
||||||
"""Get arguments for the graph invocation."""
|
"""Get arguments for the graph invocation."""
|
||||||
return {
|
return {
|
||||||
"stream_mode": "values",
|
"stream_mode": "values",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# TradingAgents/graph/reflection.py
|
# TradingAgents/graph/reflection.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -15,7 +16,7 @@ class Reflector:
|
||||||
def _get_reflection_prompt(self) -> str:
|
def _get_reflection_prompt(self) -> str:
|
||||||
"""Get the system prompt for reflection."""
|
"""Get the system prompt for reflection."""
|
||||||
return """
|
return """
|
||||||
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
|
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
|
||||||
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
|
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
|
||||||
|
|
||||||
1. Reasoning:
|
1. Reasoning:
|
||||||
|
|
@ -25,7 +26,7 @@ Your goal is to deliver detailed insights into investment decisions and highligh
|
||||||
- Technical indicators.
|
- Technical indicators.
|
||||||
- Technical signals.
|
- Technical signals.
|
||||||
- Price movement analysis.
|
- Price movement analysis.
|
||||||
- Overall market data analysis
|
- Overall market data analysis
|
||||||
- News analysis.
|
- News analysis.
|
||||||
- Social media and sentiment analysis.
|
- Social media and sentiment analysis.
|
||||||
- Fundamental data analysis.
|
- Fundamental data analysis.
|
||||||
|
|
@ -46,7 +47,7 @@ Your goal is to deliver detailed insights into investment decisions and highligh
|
||||||
Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis.
|
Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _extract_current_situation(self, current_state: Dict[str, Any]) -> str:
|
def _extract_current_situation(self, current_state: dict[str, Any]) -> str:
|
||||||
"""Extract the current market situation from the state."""
|
"""Extract the current market situation from the state."""
|
||||||
curr_market_report = current_state["market_report"]
|
curr_market_report = current_state["market_report"]
|
||||||
curr_sentiment_report = current_state["sentiment_report"]
|
curr_sentiment_report = current_state["sentiment_report"]
|
||||||
|
|
@ -56,7 +57,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
|
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
|
||||||
|
|
||||||
def _reflect_on_component(
|
def _reflect_on_component(
|
||||||
self, component_type: str, report: str, situation: str, returns_losses
|
self, component_type: str, report: str, situation: str, returns_losses,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate reflection for a component."""
|
"""Generate reflection for a component."""
|
||||||
messages = [
|
messages = [
|
||||||
|
|
@ -67,8 +68,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = self.quick_thinking_llm.invoke(messages).content
|
return self.quick_thinking_llm.invoke(messages).content
|
||||||
return result
|
|
||||||
|
|
||||||
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
|
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
|
||||||
"""Reflect on bull researcher's analysis and update memory."""
|
"""Reflect on bull researcher's analysis and update memory."""
|
||||||
|
|
@ -76,7 +76,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
|
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
result = self._reflect_on_component(
|
||||||
"BULL", bull_debate_history, situation, returns_losses
|
"BULL", bull_debate_history, situation, returns_losses,
|
||||||
)
|
)
|
||||||
bull_memory.add_situations([(situation, result)])
|
bull_memory.add_situations([(situation, result)])
|
||||||
|
|
||||||
|
|
@ -86,7 +86,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
|
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
result = self._reflect_on_component(
|
||||||
"BEAR", bear_debate_history, situation, returns_losses
|
"BEAR", bear_debate_history, situation, returns_losses,
|
||||||
)
|
)
|
||||||
bear_memory.add_situations([(situation, result)])
|
bear_memory.add_situations([(situation, result)])
|
||||||
|
|
||||||
|
|
@ -96,7 +96,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
trader_decision = current_state["trader_investment_plan"]
|
trader_decision = current_state["trader_investment_plan"]
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
result = self._reflect_on_component(
|
||||||
"TRADER", trader_decision, situation, returns_losses
|
"TRADER", trader_decision, situation, returns_losses,
|
||||||
)
|
)
|
||||||
trader_memory.add_situations([(situation, result)])
|
trader_memory.add_situations([(situation, result)])
|
||||||
|
|
||||||
|
|
@ -106,7 +106,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
judge_decision = current_state["investment_debate_state"]["judge_decision"]
|
judge_decision = current_state["investment_debate_state"]["judge_decision"]
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
result = self._reflect_on_component(
|
||||||
"INVEST JUDGE", judge_decision, situation, returns_losses
|
"INVEST JUDGE", judge_decision, situation, returns_losses,
|
||||||
)
|
)
|
||||||
invest_judge_memory.add_situations([(situation, result)])
|
invest_judge_memory.add_situations([(situation, result)])
|
||||||
|
|
||||||
|
|
@ -116,6 +116,6 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
judge_decision = current_state["risk_debate_state"]["judge_decision"]
|
judge_decision = current_state["risk_debate_state"]["judge_decision"]
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
result = self._reflect_on_component(
|
||||||
"RISK JUDGE", judge_decision, situation, returns_losses
|
"RISK JUDGE", judge_decision, situation, returns_losses,
|
||||||
)
|
)
|
||||||
risk_manager_memory.add_situations([(situation, result)])
|
risk_manager_memory.add_situations([(situation, result)])
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,26 @@
|
||||||
# TradingAgents/graph/setup.py
|
# TradingAgents/graph/setup.py
|
||||||
|
|
||||||
from typing import Dict
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langgraph.graph import END, StateGraph, START
|
from langgraph.graph import END, START, StateGraph
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
from tradingagents.agents import (
|
from tradingagents.agents import (
|
||||||
create_market_analyst,
|
|
||||||
create_social_media_analyst,
|
|
||||||
create_news_analyst,
|
|
||||||
create_fundamentals_analyst,
|
|
||||||
create_bull_researcher,
|
|
||||||
create_bear_researcher,
|
|
||||||
create_research_manager,
|
|
||||||
create_trader,
|
|
||||||
create_risky_debator,
|
|
||||||
create_neutral_debator,
|
|
||||||
create_safe_debator,
|
|
||||||
create_risk_manager,
|
|
||||||
create_msg_delete,
|
|
||||||
AgentState,
|
AgentState,
|
||||||
Toolkit,
|
Toolkit,
|
||||||
|
create_bear_researcher,
|
||||||
|
create_bull_researcher,
|
||||||
|
create_fundamentals_analyst,
|
||||||
|
create_market_analyst,
|
||||||
|
create_msg_delete,
|
||||||
|
create_neutral_debator,
|
||||||
|
create_news_analyst,
|
||||||
|
create_research_manager,
|
||||||
|
create_risk_manager,
|
||||||
|
create_risky_debator,
|
||||||
|
create_safe_debator,
|
||||||
|
create_social_media_analyst,
|
||||||
|
create_trader,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
|
|
@ -34,7 +34,7 @@ class GraphSetup:
|
||||||
quick_thinking_llm: ChatOpenAI,
|
quick_thinking_llm: ChatOpenAI,
|
||||||
deep_thinking_llm: ChatOpenAI,
|
deep_thinking_llm: ChatOpenAI,
|
||||||
toolkit: Toolkit,
|
toolkit: Toolkit,
|
||||||
tool_nodes: Dict[str, ToolNode],
|
tool_nodes: dict[str, ToolNode],
|
||||||
bull_memory,
|
bull_memory,
|
||||||
bear_memory,
|
bear_memory,
|
||||||
trader_memory,
|
trader_memory,
|
||||||
|
|
@ -55,7 +55,7 @@ class GraphSetup:
|
||||||
self.conditional_logic = conditional_logic
|
self.conditional_logic = conditional_logic
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
self, selected_analysts=None,
|
||||||
):
|
):
|
||||||
"""Set up and compile the agent workflow graph.
|
"""Set up and compile the agent workflow graph.
|
||||||
|
|
||||||
|
|
@ -66,8 +66,11 @@ class GraphSetup:
|
||||||
- "news": News analyst
|
- "news": News analyst
|
||||||
- "fundamentals": Fundamentals analyst
|
- "fundamentals": Fundamentals analyst
|
||||||
"""
|
"""
|
||||||
|
if selected_analysts is None:
|
||||||
|
selected_analysts = ["market", "social", "news", "fundamentals"]
|
||||||
if len(selected_analysts) == 0:
|
if len(selected_analysts) == 0:
|
||||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
msg = "Trading Agents Graph Setup Error: no analysts selected!"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Create analyst nodes
|
# Create analyst nodes
|
||||||
analyst_nodes = {}
|
analyst_nodes = {}
|
||||||
|
|
@ -76,41 +79,41 @@ class GraphSetup:
|
||||||
|
|
||||||
if "market" in selected_analysts:
|
if "market" in selected_analysts:
|
||||||
analyst_nodes["market"] = create_market_analyst(
|
analyst_nodes["market"] = create_market_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit
|
self.quick_thinking_llm, self.toolkit,
|
||||||
)
|
)
|
||||||
delete_nodes["market"] = create_msg_delete()
|
delete_nodes["market"] = create_msg_delete()
|
||||||
tool_nodes["market"] = self.tool_nodes["market"]
|
tool_nodes["market"] = self.tool_nodes["market"]
|
||||||
|
|
||||||
if "social" in selected_analysts:
|
if "social" in selected_analysts:
|
||||||
analyst_nodes["social"] = create_social_media_analyst(
|
analyst_nodes["social"] = create_social_media_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit
|
self.quick_thinking_llm, self.toolkit,
|
||||||
)
|
)
|
||||||
delete_nodes["social"] = create_msg_delete()
|
delete_nodes["social"] = create_msg_delete()
|
||||||
tool_nodes["social"] = self.tool_nodes["social"]
|
tool_nodes["social"] = self.tool_nodes["social"]
|
||||||
|
|
||||||
if "news" in selected_analysts:
|
if "news" in selected_analysts:
|
||||||
analyst_nodes["news"] = create_news_analyst(
|
analyst_nodes["news"] = create_news_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit
|
self.quick_thinking_llm, self.toolkit,
|
||||||
)
|
)
|
||||||
delete_nodes["news"] = create_msg_delete()
|
delete_nodes["news"] = create_msg_delete()
|
||||||
tool_nodes["news"] = self.tool_nodes["news"]
|
tool_nodes["news"] = self.tool_nodes["news"]
|
||||||
|
|
||||||
if "fundamentals" in selected_analysts:
|
if "fundamentals" in selected_analysts:
|
||||||
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit
|
self.quick_thinking_llm, self.toolkit,
|
||||||
)
|
)
|
||||||
delete_nodes["fundamentals"] = create_msg_delete()
|
delete_nodes["fundamentals"] = create_msg_delete()
|
||||||
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
||||||
|
|
||||||
# Create researcher and manager nodes
|
# Create researcher and manager nodes
|
||||||
bull_researcher_node = create_bull_researcher(
|
bull_researcher_node = create_bull_researcher(
|
||||||
self.quick_thinking_llm, self.bull_memory
|
self.quick_thinking_llm, self.bull_memory,
|
||||||
)
|
)
|
||||||
bear_researcher_node = create_bear_researcher(
|
bear_researcher_node = create_bear_researcher(
|
||||||
self.quick_thinking_llm, self.bear_memory
|
self.quick_thinking_llm, self.bear_memory,
|
||||||
)
|
)
|
||||||
research_manager_node = create_research_manager(
|
research_manager_node = create_research_manager(
|
||||||
self.deep_thinking_llm, self.invest_judge_memory
|
self.deep_thinking_llm, self.invest_judge_memory,
|
||||||
)
|
)
|
||||||
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
||||||
|
|
||||||
|
|
@ -119,7 +122,7 @@ class GraphSetup:
|
||||||
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
||||||
safe_analyst = create_safe_debator(self.quick_thinking_llm)
|
safe_analyst = create_safe_debator(self.quick_thinking_llm)
|
||||||
risk_manager_node = create_risk_manager(
|
risk_manager_node = create_risk_manager(
|
||||||
self.deep_thinking_llm, self.risk_manager_memory
|
self.deep_thinking_llm, self.risk_manager_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create workflow
|
# Create workflow
|
||||||
|
|
@ -129,7 +132,7 @@ class GraphSetup:
|
||||||
for analyst_type, node in analyst_nodes.items():
|
for analyst_type, node in analyst_nodes.items():
|
||||||
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
||||||
workflow.add_node(
|
workflow.add_node(
|
||||||
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
|
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type],
|
||||||
)
|
)
|
||||||
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,24 @@
|
||||||
# TradingAgents/graph/trading_graph.py
|
# TradingAgents/graph/trading_graph.py
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
from typing import Any
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
from tradingagents.agents import Toolkit
|
from tradingagents.agents import Toolkit
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
|
||||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||||
from tradingagents.dataflows.interface import set_config
|
from tradingagents.dataflows.interface import set_config
|
||||||
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
from .setup import GraphSetup
|
|
||||||
from .propagation import Propagator
|
from .propagation import Propagator
|
||||||
from .reflection import Reflector
|
from .reflection import Reflector
|
||||||
|
from .setup import GraphSetup
|
||||||
from .signal_processing import SignalProcessor
|
from .signal_processing import SignalProcessor
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,9 +27,9 @@ class TradingAgentsGraph:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
selected_analysts=None,
|
||||||
debug=False,
|
debug=False,
|
||||||
config: Dict[str, Any] = None,
|
config: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize the trading agents graph and components.
|
"""Initialize the trading agents graph and components.
|
||||||
|
|
||||||
|
|
@ -39,6 +38,8 @@ class TradingAgentsGraph:
|
||||||
debug: Whether to run in debug mode
|
debug: Whether to run in debug mode
|
||||||
config: Configuration dictionary. If None, uses default config
|
config: Configuration dictionary. If None, uses default config
|
||||||
"""
|
"""
|
||||||
|
if selected_analysts is None:
|
||||||
|
selected_analysts = ["market", "social", "news", "fundamentals"]
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.config = config or DEFAULT_CONFIG
|
self.config = config or DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
@ -58,7 +59,7 @@ class TradingAgentsGraph:
|
||||||
or self.config["llm_provider"] == "openrouter"
|
or self.config["llm_provider"] == "openrouter"
|
||||||
):
|
):
|
||||||
self.deep_thinking_llm = ChatOpenAI(
|
self.deep_thinking_llm = ChatOpenAI(
|
||||||
model=self.config["deep_think_llm"], base_url=self.config["backend_url"]
|
model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
|
||||||
)
|
)
|
||||||
self.quick_thinking_llm = ChatOpenAI(
|
self.quick_thinking_llm = ChatOpenAI(
|
||||||
model=self.config["quick_think_llm"],
|
model=self.config["quick_think_llm"],
|
||||||
|
|
@ -66,7 +67,7 @@ class TradingAgentsGraph:
|
||||||
)
|
)
|
||||||
elif self.config["llm_provider"].lower() == "anthropic":
|
elif self.config["llm_provider"].lower() == "anthropic":
|
||||||
self.deep_thinking_llm = ChatAnthropic(
|
self.deep_thinking_llm = ChatAnthropic(
|
||||||
model=self.config["deep_think_llm"], base_url=self.config["backend_url"]
|
model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
|
||||||
)
|
)
|
||||||
self.quick_thinking_llm = ChatAnthropic(
|
self.quick_thinking_llm = ChatAnthropic(
|
||||||
model=self.config["quick_think_llm"],
|
model=self.config["quick_think_llm"],
|
||||||
|
|
@ -74,13 +75,14 @@ class TradingAgentsGraph:
|
||||||
)
|
)
|
||||||
elif self.config["llm_provider"].lower() == "google":
|
elif self.config["llm_provider"].lower() == "google":
|
||||||
self.deep_thinking_llm = ChatGoogleGenerativeAI(
|
self.deep_thinking_llm = ChatGoogleGenerativeAI(
|
||||||
model=self.config["deep_think_llm"]
|
model=self.config["deep_think_llm"],
|
||||||
)
|
)
|
||||||
self.quick_thinking_llm = ChatGoogleGenerativeAI(
|
self.quick_thinking_llm = ChatGoogleGenerativeAI(
|
||||||
model=self.config["quick_think_llm"]
|
model=self.config["quick_think_llm"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
msg = f"Unsupported LLM provider: {self.config['llm_provider']}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
self.toolkit = Toolkit(config=self.config)
|
self.toolkit = Toolkit(config=self.config)
|
||||||
|
|
||||||
|
|
@ -89,10 +91,10 @@ class TradingAgentsGraph:
|
||||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||||
self.invest_judge_memory = FinancialSituationMemory(
|
self.invest_judge_memory = FinancialSituationMemory(
|
||||||
"invest_judge_memory", self.config
|
"invest_judge_memory", self.config,
|
||||||
)
|
)
|
||||||
self.risk_manager_memory = FinancialSituationMemory(
|
self.risk_manager_memory = FinancialSituationMemory(
|
||||||
"risk_manager_memory", self.config
|
"risk_manager_memory", self.config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create tool nodes
|
# Create tool nodes
|
||||||
|
|
@ -125,7 +127,7 @@ class TradingAgentsGraph:
|
||||||
# Set up the graph
|
# Set up the graph
|
||||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||||
|
|
||||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
def _create_tool_nodes(self) -> dict[str, ToolNode]:
|
||||||
"""Create tool nodes for different data sources."""
|
"""Create tool nodes for different data sources."""
|
||||||
return {
|
return {
|
||||||
"market": ToolNode(
|
"market": ToolNode(
|
||||||
|
|
@ -136,7 +138,7 @@ class TradingAgentsGraph:
|
||||||
# offline tools
|
# offline tools
|
||||||
self.toolkit.get_YFin_data,
|
self.toolkit.get_YFin_data,
|
||||||
self.toolkit.get_stockstats_indicators_report,
|
self.toolkit.get_stockstats_indicators_report,
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
"social": ToolNode(
|
"social": ToolNode(
|
||||||
[
|
[
|
||||||
|
|
@ -144,7 +146,7 @@ class TradingAgentsGraph:
|
||||||
self.toolkit.get_stock_news_openai,
|
self.toolkit.get_stock_news_openai,
|
||||||
# offline tools
|
# offline tools
|
||||||
self.toolkit.get_reddit_stock_info,
|
self.toolkit.get_reddit_stock_info,
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
"news": ToolNode(
|
"news": ToolNode(
|
||||||
[
|
[
|
||||||
|
|
@ -154,7 +156,7 @@ class TradingAgentsGraph:
|
||||||
# offline tools
|
# offline tools
|
||||||
self.toolkit.get_finnhub_news,
|
self.toolkit.get_finnhub_news,
|
||||||
self.toolkit.get_reddit_news,
|
self.toolkit.get_reddit_news,
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
"fundamentals": ToolNode(
|
"fundamentals": ToolNode(
|
||||||
[
|
[
|
||||||
|
|
@ -166,7 +168,7 @@ class TradingAgentsGraph:
|
||||||
self.toolkit.get_simfin_balance_sheet,
|
self.toolkit.get_simfin_balance_sheet,
|
||||||
self.toolkit.get_simfin_cashflow,
|
self.toolkit.get_simfin_cashflow,
|
||||||
self.toolkit.get_simfin_income_stmt,
|
self.toolkit.get_simfin_income_stmt,
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -177,7 +179,7 @@ class TradingAgentsGraph:
|
||||||
|
|
||||||
# Initialize state
|
# Initialize state
|
||||||
init_agent_state = self.propagator.create_initial_state(
|
init_agent_state = self.propagator.create_initial_state(
|
||||||
company_name, trade_date
|
company_name, trade_date,
|
||||||
)
|
)
|
||||||
args = self.propagator.get_graph_args()
|
args = self.propagator.get_graph_args()
|
||||||
|
|
||||||
|
|
@ -250,19 +252,19 @@ class TradingAgentsGraph:
|
||||||
def reflect_and_remember(self, returns_losses):
|
def reflect_and_remember(self, returns_losses):
|
||||||
"""Reflect on decisions and update memory based on returns."""
|
"""Reflect on decisions and update memory based on returns."""
|
||||||
self.reflector.reflect_bull_researcher(
|
self.reflector.reflect_bull_researcher(
|
||||||
self.curr_state, returns_losses, self.bull_memory
|
self.curr_state, returns_losses, self.bull_memory,
|
||||||
)
|
)
|
||||||
self.reflector.reflect_bear_researcher(
|
self.reflector.reflect_bear_researcher(
|
||||||
self.curr_state, returns_losses, self.bear_memory
|
self.curr_state, returns_losses, self.bear_memory,
|
||||||
)
|
)
|
||||||
self.reflector.reflect_trader(
|
self.reflector.reflect_trader(
|
||||||
self.curr_state, returns_losses, self.trader_memory
|
self.curr_state, returns_losses, self.trader_memory,
|
||||||
)
|
)
|
||||||
self.reflector.reflect_invest_judge(
|
self.reflector.reflect_invest_judge(
|
||||||
self.curr_state, returns_losses, self.invest_judge_memory
|
self.curr_state, returns_losses, self.invest_judge_memory,
|
||||||
)
|
)
|
||||||
self.reflector.reflect_risk_manager(
|
self.reflector.reflect_risk_manager(
|
||||||
self.curr_state, returns_losses, self.risk_manager_memory
|
self.curr_state, returns_losses, self.risk_manager_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_signal(self, full_signal):
|
def process_signal(self, full_signal):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue