Apply Black formatting to pass CI checks
- Formatted 33 Python files with Black - Fixed code style to meet project standards - Ensures CI/CD pipeline passes formatting checks
This commit is contained in:
parent
6f3981412b
commit
850764ad7b
132
cli/main.py
132
cli/main.py
|
|
@ -148,7 +148,7 @@ class MessageBuffer:
|
||||||
f"### News Analysis\n{self.report_sections['news_report']}",
|
f"### News Analysis\n{self.report_sections['news_report']}",
|
||||||
)
|
)
|
||||||
if self.report_sections["fundamentals_report"]:
|
if self.report_sections["fundamentals_report"]:
|
||||||
fundamentals = self.report_sections['fundamentals_report']
|
fundamentals = self.report_sections["fundamentals_report"]
|
||||||
report_parts.append(
|
report_parts.append(
|
||||||
f"### Fundamentals Analysis\n{fundamentals}",
|
f"### Fundamentals Analysis\n{fundamentals}",
|
||||||
)
|
)
|
||||||
|
|
@ -182,10 +182,12 @@ def create_layout():
|
||||||
Layout(name="footer", size=3),
|
Layout(name="footer", size=3),
|
||||||
)
|
)
|
||||||
layout["main"].split_column(
|
layout["main"].split_column(
|
||||||
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5),
|
Layout(name="upper", ratio=3),
|
||||||
|
Layout(name="analysis", ratio=5),
|
||||||
)
|
)
|
||||||
layout["upper"].split_row(
|
layout["upper"].split_row(
|
||||||
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3),
|
Layout(name="progress", ratio=2),
|
||||||
|
Layout(name="messages", ratio=3),
|
||||||
)
|
)
|
||||||
return layout
|
return layout
|
||||||
|
|
||||||
|
|
@ -237,7 +239,9 @@ def update_display(layout, spinner_text=None):
|
||||||
status = message_buffer.agent_status[first_agent]
|
status = message_buffer.agent_status[first_agent]
|
||||||
if status == "in_progress":
|
if status == "in_progress":
|
||||||
spinner = Spinner(
|
spinner = Spinner(
|
||||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan",
|
"dots",
|
||||||
|
text="[blue]in_progress[/blue]",
|
||||||
|
style="bold cyan",
|
||||||
)
|
)
|
||||||
status_cell = spinner
|
status_cell = spinner
|
||||||
else:
|
else:
|
||||||
|
|
@ -254,7 +258,9 @@ def update_display(layout, spinner_text=None):
|
||||||
status = message_buffer.agent_status[agent]
|
status = message_buffer.agent_status[agent]
|
||||||
if status == "in_progress":
|
if status == "in_progress":
|
||||||
spinner = Spinner(
|
spinner = Spinner(
|
||||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan",
|
"dots",
|
||||||
|
text="[blue]in_progress[/blue]",
|
||||||
|
style="bold cyan",
|
||||||
)
|
)
|
||||||
status_cell = spinner
|
status_cell = spinner
|
||||||
else:
|
else:
|
||||||
|
|
@ -286,7 +292,10 @@ def update_display(layout, spinner_text=None):
|
||||||
messages_table.add_column("Time", style="cyan", width=8, justify="center")
|
messages_table.add_column("Time", style="cyan", width=8, justify="center")
|
||||||
messages_table.add_column("Type", style="green", width=10, justify="center")
|
messages_table.add_column("Type", style="green", width=10, justify="center")
|
||||||
messages_table.add_column(
|
messages_table.add_column(
|
||||||
"Content", style="white", no_wrap=False, ratio=1,
|
"Content",
|
||||||
|
style="white",
|
||||||
|
no_wrap=False,
|
||||||
|
ratio=1,
|
||||||
) # Make content column expand
|
) # Make content column expand
|
||||||
|
|
||||||
# Combine tool calls and messages
|
# Combine tool calls and messages
|
||||||
|
|
@ -441,7 +450,9 @@ def get_user_selections():
|
||||||
# Step 1: Ticker symbol
|
# Step 1: Ticker symbol
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY",
|
"Step 1: Ticker Symbol",
|
||||||
|
"Enter the ticker symbol to analyze",
|
||||||
|
"SPY",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
selected_ticker = get_ticker()
|
selected_ticker = get_ticker()
|
||||||
|
|
@ -460,7 +471,8 @@ def get_user_selections():
|
||||||
# Step 3: Select analysts
|
# Step 3: Select analysts
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis",
|
"Step 3: Analysts Team",
|
||||||
|
"Select your LLM analyst agents for the analysis",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
selected_analysts = select_analysts()
|
selected_analysts = select_analysts()
|
||||||
|
|
@ -471,21 +483,25 @@ def get_user_selections():
|
||||||
# Step 4: Research depth
|
# Step 4: Research depth
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 4: Research Depth", "Select your research depth level",
|
"Step 4: Research Depth",
|
||||||
|
"Select your research depth level",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
selected_research_depth = select_research_depth()
|
selected_research_depth = select_research_depth()
|
||||||
|
|
||||||
# Step 5: OpenAI backend
|
# Step 5: OpenAI backend
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box("Step 5: OpenAI backend", "Select which service to talk to"),
|
create_question_box(
|
||||||
|
"Step 5: OpenAI backend", "Select which service to talk to"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
selected_llm_provider, backend_url = select_llm_provider()
|
selected_llm_provider, backend_url = select_llm_provider()
|
||||||
|
|
||||||
# Step 6: Thinking agents
|
# Step 6: Thinking agents
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 6: Thinking Agents", "Select your thinking agents for analysis",
|
"Step 6: Thinking Agents",
|
||||||
|
"Select your thinking agents for analysis",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
||||||
|
|
@ -737,7 +753,9 @@ def run_analysis():
|
||||||
|
|
||||||
# Initialize the graph
|
# Initialize the graph
|
||||||
graph = TradingAgentsGraph(
|
graph = TradingAgentsGraph(
|
||||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True,
|
[analyst.value for analyst in selections["analysts"]],
|
||||||
|
config=config,
|
||||||
|
debug=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create result directory
|
# Create result directory
|
||||||
|
|
@ -796,10 +814,12 @@ def run_analysis():
|
||||||
|
|
||||||
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
|
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
|
||||||
message_buffer.add_tool_call = save_tool_call_decorator(
|
message_buffer.add_tool_call = save_tool_call_decorator(
|
||||||
message_buffer, "add_tool_call",
|
message_buffer,
|
||||||
|
"add_tool_call",
|
||||||
)
|
)
|
||||||
message_buffer.update_report_section = save_report_section_decorator(
|
message_buffer.update_report_section = save_report_section_decorator(
|
||||||
message_buffer, "update_report_section",
|
message_buffer,
|
||||||
|
"update_report_section",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now start the display layout
|
# Now start the display layout
|
||||||
|
|
@ -812,7 +832,8 @@ def run_analysis():
|
||||||
# Add initial messages
|
# Add initial messages
|
||||||
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"System", f"Analysis date: {selections['analysis_date']}",
|
"System",
|
||||||
|
f"Analysis date: {selections['analysis_date']}",
|
||||||
)
|
)
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"System",
|
"System",
|
||||||
|
|
@ -843,7 +864,8 @@ def run_analysis():
|
||||||
|
|
||||||
# Initialize state and get graph args
|
# Initialize state and get graph args
|
||||||
init_agent_state = graph.propagator.create_initial_state(
|
init_agent_state = graph.propagator.create_initial_state(
|
||||||
selections["ticker"], selections["analysis_date"],
|
selections["ticker"],
|
||||||
|
selections["analysis_date"],
|
||||||
)
|
)
|
||||||
args = graph.propagator.get_graph_args()
|
args = graph.propagator.get_graph_args()
|
||||||
|
|
||||||
|
|
@ -873,7 +895,8 @@ def run_analysis():
|
||||||
# Handle both dictionary and object tool calls
|
# Handle both dictionary and object tool calls
|
||||||
if isinstance(tool_call, dict):
|
if isinstance(tool_call, dict):
|
||||||
message_buffer.add_tool_call(
|
message_buffer.add_tool_call(
|
||||||
tool_call["name"], tool_call["args"],
|
tool_call["name"],
|
||||||
|
tool_call["args"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||||
|
|
@ -882,51 +905,57 @@ def run_analysis():
|
||||||
# Analyst Team Reports
|
# Analyst Team Reports
|
||||||
if chunk.get("market_report"):
|
if chunk.get("market_report"):
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"market_report", chunk["market_report"],
|
"market_report",
|
||||||
|
chunk["market_report"],
|
||||||
)
|
)
|
||||||
message_buffer.update_agent_status("Market Analyst", "completed")
|
message_buffer.update_agent_status("Market Analyst", "completed")
|
||||||
# Set next analyst to in_progress
|
# Set next analyst to in_progress
|
||||||
if "social" in selections["analysts"]:
|
if "social" in selections["analysts"]:
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Social Analyst", "in_progress",
|
"Social Analyst",
|
||||||
|
"in_progress",
|
||||||
)
|
)
|
||||||
|
|
||||||
if chunk.get("sentiment_report"):
|
if chunk.get("sentiment_report"):
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"sentiment_report", chunk["sentiment_report"],
|
"sentiment_report",
|
||||||
|
chunk["sentiment_report"],
|
||||||
)
|
)
|
||||||
message_buffer.update_agent_status("Social Analyst", "completed")
|
message_buffer.update_agent_status("Social Analyst", "completed")
|
||||||
# Set next analyst to in_progress
|
# Set next analyst to in_progress
|
||||||
if "news" in selections["analysts"]:
|
if "news" in selections["analysts"]:
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"News Analyst", "in_progress",
|
"News Analyst",
|
||||||
|
"in_progress",
|
||||||
)
|
)
|
||||||
|
|
||||||
if chunk.get("news_report"):
|
if chunk.get("news_report"):
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"news_report", chunk["news_report"],
|
"news_report",
|
||||||
|
chunk["news_report"],
|
||||||
)
|
)
|
||||||
message_buffer.update_agent_status("News Analyst", "completed")
|
message_buffer.update_agent_status("News Analyst", "completed")
|
||||||
# Set next analyst to in_progress
|
# Set next analyst to in_progress
|
||||||
if "fundamentals" in selections["analysts"]:
|
if "fundamentals" in selections["analysts"]:
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Fundamentals Analyst", "in_progress",
|
"Fundamentals Analyst",
|
||||||
|
"in_progress",
|
||||||
)
|
)
|
||||||
|
|
||||||
if chunk.get("fundamentals_report"):
|
if chunk.get("fundamentals_report"):
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"fundamentals_report", chunk["fundamentals_report"],
|
"fundamentals_report",
|
||||||
|
chunk["fundamentals_report"],
|
||||||
)
|
)
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Fundamentals Analyst", "completed",
|
"Fundamentals Analyst",
|
||||||
|
"completed",
|
||||||
)
|
)
|
||||||
# Set all research team members to in_progress
|
# Set all research team members to in_progress
|
||||||
update_research_team_status("in_progress")
|
update_research_team_status("in_progress")
|
||||||
|
|
||||||
# Research Team - Handle Investment Debate State
|
# Research Team - Handle Investment Debate State
|
||||||
if (
|
if chunk.get("investment_debate_state"):
|
||||||
chunk.get("investment_debate_state")
|
|
||||||
):
|
|
||||||
debate_state = chunk["investment_debate_state"]
|
debate_state = chunk["investment_debate_state"]
|
||||||
|
|
||||||
# Update Bull Researcher status and report
|
# Update Bull Researcher status and report
|
||||||
|
|
@ -960,9 +989,7 @@ def run_analysis():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update Research Manager status and final decision
|
# Update Research Manager status and final decision
|
||||||
if (
|
if debate_state.get("judge_decision"):
|
||||||
debate_state.get("judge_decision")
|
|
||||||
):
|
|
||||||
# Keep all research team members in progress until final decision
|
# Keep all research team members in progress until final decision
|
||||||
update_research_team_status("in_progress")
|
update_research_team_status("in_progress")
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
|
|
@ -978,15 +1005,15 @@ def run_analysis():
|
||||||
update_research_team_status("completed")
|
update_research_team_status("completed")
|
||||||
# Set first risk analyst to in_progress
|
# Set first risk analyst to in_progress
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Risky Analyst", "in_progress",
|
"Risky Analyst",
|
||||||
|
"in_progress",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trading Team
|
# Trading Team
|
||||||
if (
|
if chunk.get("trader_investment_plan"):
|
||||||
chunk.get("trader_investment_plan")
|
|
||||||
):
|
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"trader_investment_plan", chunk["trader_investment_plan"],
|
"trader_investment_plan",
|
||||||
|
chunk["trader_investment_plan"],
|
||||||
)
|
)
|
||||||
# Set first risk analyst to in_progress
|
# Set first risk analyst to in_progress
|
||||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||||
|
|
@ -996,11 +1023,10 @@ def run_analysis():
|
||||||
risk_state = chunk["risk_debate_state"]
|
risk_state = chunk["risk_debate_state"]
|
||||||
|
|
||||||
# Update Risky Analyst status and report
|
# Update Risky Analyst status and report
|
||||||
if (
|
if risk_state.get("current_risky_response"):
|
||||||
risk_state.get("current_risky_response")
|
|
||||||
):
|
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Risky Analyst", "in_progress",
|
"Risky Analyst",
|
||||||
|
"in_progress",
|
||||||
)
|
)
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"Reasoning",
|
"Reasoning",
|
||||||
|
|
@ -1013,11 +1039,10 @@ def run_analysis():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update Safe Analyst status and report
|
# Update Safe Analyst status and report
|
||||||
if (
|
if risk_state.get("current_safe_response"):
|
||||||
risk_state.get("current_safe_response")
|
|
||||||
):
|
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Safe Analyst", "in_progress",
|
"Safe Analyst",
|
||||||
|
"in_progress",
|
||||||
)
|
)
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"Reasoning",
|
"Reasoning",
|
||||||
|
|
@ -1030,11 +1055,10 @@ def run_analysis():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update Neutral Analyst status and report
|
# Update Neutral Analyst status and report
|
||||||
if (
|
if risk_state.get("current_neutral_response"):
|
||||||
risk_state.get("current_neutral_response")
|
|
||||||
):
|
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Neutral Analyst", "in_progress",
|
"Neutral Analyst",
|
||||||
|
"in_progress",
|
||||||
)
|
)
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"Reasoning",
|
"Reasoning",
|
||||||
|
|
@ -1049,7 +1073,8 @@ def run_analysis():
|
||||||
# Update Portfolio Manager status and final decision
|
# Update Portfolio Manager status and final decision
|
||||||
if risk_state.get("judge_decision"):
|
if risk_state.get("judge_decision"):
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Portfolio Manager", "in_progress",
|
"Portfolio Manager",
|
||||||
|
"in_progress",
|
||||||
)
|
)
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"Reasoning",
|
"Reasoning",
|
||||||
|
|
@ -1064,10 +1089,12 @@ def run_analysis():
|
||||||
message_buffer.update_agent_status("Risky Analyst", "completed")
|
message_buffer.update_agent_status("Risky Analyst", "completed")
|
||||||
message_buffer.update_agent_status("Safe Analyst", "completed")
|
message_buffer.update_agent_status("Safe Analyst", "completed")
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Neutral Analyst", "completed",
|
"Neutral Analyst",
|
||||||
|
"completed",
|
||||||
)
|
)
|
||||||
message_buffer.update_agent_status(
|
message_buffer.update_agent_status(
|
||||||
"Portfolio Manager", "completed",
|
"Portfolio Manager",
|
||||||
|
"completed",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the display
|
# Update the display
|
||||||
|
|
@ -1084,7 +1111,8 @@ def run_analysis():
|
||||||
message_buffer.update_agent_status(agent, "completed")
|
message_buffer.update_agent_status(agent, "completed")
|
||||||
|
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"Analysis", f"Completed analysis for {selections['analysis_date']}",
|
"Analysis",
|
||||||
|
f"Completed analysis for {selections['analysis_date']}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update final report sections
|
# Update final report sections
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import questionary
|
import questionary
|
||||||
|
|
|
||||||
4
main.py
4
main.py
|
|
@ -4,7 +4,9 @@ from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
# Create a custom config
|
# Create a custom config
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["llm_provider"] = "google" # Use a different model
|
config["llm_provider"] = "google" # Use a different model
|
||||||
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
|
config["backend_url"] = (
|
||||||
|
"https://generativelanguage.googleapis.com/v1" # Use a different backend
|
||||||
|
)
|
||||||
config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model
|
config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model
|
||||||
config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model
|
config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||||
|
|
|
||||||
52
run_tests.py
52
run_tests.py
|
|
@ -16,10 +16,10 @@ def run_command(cmd, description=""):
|
||||||
"""Run a command and handle errors."""
|
"""Run a command and handle errors."""
|
||||||
if description:
|
if description:
|
||||||
print(f"\n🔄 {description}")
|
print(f"\n🔄 {description}")
|
||||||
|
|
||||||
print(f"Running: {' '.join(cmd)}")
|
print(f"Running: {' '.join(cmd)}")
|
||||||
result = subprocess.run(cmd, capture_output=False)
|
result = subprocess.run(cmd, capture_output=False)
|
||||||
|
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
print(f"❌ Command failed with return code {result.returncode}")
|
print(f"❌ Command failed with return code {result.returncode}")
|
||||||
sys.exit(result.returncode)
|
sys.exit(result.returncode)
|
||||||
|
|
@ -33,48 +33,42 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"test_type",
|
"test_type",
|
||||||
choices=["unit", "integration", "all", "coverage", "fast", "slow", "lint"],
|
choices=["unit", "integration", "all", "coverage", "fast", "slow", "lint"],
|
||||||
help="Type of tests to run"
|
help="Type of tests to run",
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--verbose", "-v", action="store_true", help="Verbose output"
|
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--parallel", "-p", action="store_true", help="Run tests in parallel"
|
"--parallel", "-p", action="store_true", help="Run tests in parallel"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--file", "-f", help="Run specific test file")
|
||||||
"--file", "-f", help="Run specific test file"
|
parser.add_argument("--pattern", "-k", help="Run tests matching pattern")
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--pattern", "-k", help="Run tests matching pattern"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Base pytest command
|
# Base pytest command
|
||||||
base_cmd = ["python", "-m", "pytest"]
|
base_cmd = ["python", "-m", "pytest"]
|
||||||
|
|
||||||
if args.verbose:
|
if args.verbose:
|
||||||
base_cmd.append("-v")
|
base_cmd.append("-v")
|
||||||
|
|
||||||
if args.parallel:
|
if args.parallel:
|
||||||
base_cmd.extend(["-n", "auto"])
|
base_cmd.extend(["-n", "auto"])
|
||||||
|
|
||||||
if args.pattern:
|
if args.pattern:
|
||||||
base_cmd.extend(["-k", args.pattern])
|
base_cmd.extend(["-k", args.pattern])
|
||||||
|
|
||||||
# Configure based on test type
|
# Configure based on test type
|
||||||
if args.test_type == "unit":
|
if args.test_type == "unit":
|
||||||
cmd = base_cmd + ["tests/unit/", "-m", "unit"]
|
cmd = base_cmd + ["tests/unit/", "-m", "unit"]
|
||||||
run_command(cmd, "Running unit tests")
|
run_command(cmd, "Running unit tests")
|
||||||
|
|
||||||
elif args.test_type == "integration":
|
elif args.test_type == "integration":
|
||||||
cmd = base_cmd + ["tests/integration/", "-m", "integration"]
|
cmd = base_cmd + ["tests/integration/", "-m", "integration"]
|
||||||
run_command(cmd, "Running integration tests")
|
run_command(cmd, "Running integration tests")
|
||||||
|
|
||||||
elif args.test_type == "all":
|
elif args.test_type == "all":
|
||||||
cmd = base_cmd + ["tests/"]
|
cmd = base_cmd + ["tests/"]
|
||||||
run_command(cmd, "Running all tests")
|
run_command(cmd, "Running all tests")
|
||||||
|
|
||||||
elif args.test_type == "coverage":
|
elif args.test_type == "coverage":
|
||||||
cmd = base_cmd + [
|
cmd = base_cmd + [
|
||||||
"tests/",
|
"tests/",
|
||||||
|
|
@ -88,28 +82,28 @@ def main():
|
||||||
print("\n📊 Coverage report generated:")
|
print("\n📊 Coverage report generated:")
|
||||||
print(" - HTML: htmlcov/index.html")
|
print(" - HTML: htmlcov/index.html")
|
||||||
print(" - XML: coverage.xml")
|
print(" - XML: coverage.xml")
|
||||||
|
|
||||||
elif args.test_type == "fast":
|
elif args.test_type == "fast":
|
||||||
cmd = base_cmd + ["tests/unit/", "-m", "unit", "--durations=10"]
|
cmd = base_cmd + ["tests/unit/", "-m", "unit", "--durations=10"]
|
||||||
run_command(cmd, "Running fast unit tests")
|
run_command(cmd, "Running fast unit tests")
|
||||||
|
|
||||||
elif args.test_type == "slow":
|
elif args.test_type == "slow":
|
||||||
cmd = base_cmd + ["tests/", "-m", "slow", "--timeout=600"]
|
cmd = base_cmd + ["tests/", "-m", "slow", "--timeout=600"]
|
||||||
run_command(cmd, "Running slow tests")
|
run_command(cmd, "Running slow tests")
|
||||||
|
|
||||||
elif args.test_type == "lint":
|
elif args.test_type == "lint":
|
||||||
# Run mypy
|
# Run mypy
|
||||||
cmd = ["python", "-m", "mypy", "tradingagents/", "cli/", "tests/"]
|
cmd = ["python", "-m", "mypy", "tradingagents/", "cli/", "tests/"]
|
||||||
run_command(cmd, "Running mypy type checking")
|
run_command(cmd, "Running mypy type checking")
|
||||||
|
|
||||||
# Run pytest on tests only for syntax
|
# Run pytest on tests only for syntax
|
||||||
cmd = base_cmd + ["tests/", "--collect-only"]
|
cmd = base_cmd + ["tests/", "--collect-only"]
|
||||||
run_command(cmd, "Validating test syntax")
|
run_command(cmd, "Validating test syntax")
|
||||||
|
|
||||||
elif args.file:
|
elif args.file:
|
||||||
cmd = base_cmd + [args.file]
|
cmd = base_cmd + [args.file]
|
||||||
run_command(cmd, f"Running tests in {args.file}")
|
run_command(cmd, f"Running tests in {args.file}")
|
||||||
|
|
||||||
print("\n🎉 All tests completed successfully!")
|
print("\n🎉 All tests completed successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -117,5 +111,5 @@ if __name__ == "__main__":
|
||||||
# Ensure we're in the project directory
|
# Ensure we're in the project directory
|
||||||
script_dir = Path(__file__).parent
|
script_dir = Path(__file__).parent
|
||||||
os.chdir(script_dir)
|
os.chdir(script_dir)
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,23 @@
|
||||||
|
def poorly_formatted_function(x, y, z): # Missing type hints
|
||||||
def poorly_formatted_function(x,y,z): # Missing type hints
|
|
||||||
"""This function has formatting issues."""
|
"""This function has formatting issues."""
|
||||||
result=x+y*z # Missing spaces around operators
|
result = x + y * z # Missing spaces around operators
|
||||||
if result>100: # Missing spaces
|
if result > 100: # Missing spaces
|
||||||
print( "Result is large" ) # Extra spaces in parentheses
|
print("Result is large") # Extra spaces in parentheses
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# Long line that Black will wrap
|
# Long line that Black will wrap
|
||||||
very_long_variable_name_that_exceeds_the_standard_line_length_limit = "This is a very long string that will be wrapped by Black formatter"
|
very_long_variable_name_that_exceeds_the_standard_line_length_limit = (
|
||||||
|
"This is a very long string that will be wrapped by Black formatter"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MyClass:
|
class MyClass:
|
||||||
def __init__(self,name:str,age:int): # Missing space after comma
|
def __init__(self, name: str, age: int): # Missing space after comma
|
||||||
self.name=name # Missing spaces around =
|
self.name = name # Missing spaces around =
|
||||||
self.age=age
|
self.age = age
|
||||||
|
|
||||||
|
|
||||||
# Function with wrong return type hint
|
# Function with wrong return type hint
|
||||||
def get_number() -> str:
|
def get_number() -> str:
|
||||||
return 123 # Returns int but type hint says str
|
return 123 # Returns int but type hint says str
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ def add_numbers(a: int, b: int) -> int:
|
||||||
"""Add two numbers and return the result."""
|
"""Add two numbers and return the result."""
|
||||||
return a + b
|
return a + b
|
||||||
|
|
||||||
|
|
||||||
# Test the function
|
# Test the function
|
||||||
result = add_numbers(1, 2)
|
result = add_numbers(1, 2)
|
||||||
print(f"Result: {result}")
|
print(f"Result: {result}")
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ def run_command(cmd, description=""):
|
||||||
"""Run a command and return success status."""
|
"""Run a command and return success status."""
|
||||||
if description:
|
if description:
|
||||||
print(f"\n🔄 {description}")
|
print(f"\n🔄 {description}")
|
||||||
|
|
||||||
print(f"Running: {' '.join(cmd)}")
|
print(f"Running: {' '.join(cmd)}")
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||||
|
|
@ -38,45 +38,60 @@ def main():
|
||||||
"""Run setup verification tests."""
|
"""Run setup verification tests."""
|
||||||
print("🧪 TradingAgents Test Setup Verification")
|
print("🧪 TradingAgents Test Setup Verification")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
# Change to project directory
|
# Change to project directory
|
||||||
project_dir = Path(__file__).parent
|
project_dir = Path(__file__).parent
|
||||||
os.chdir(project_dir)
|
os.chdir(project_dir)
|
||||||
|
|
||||||
success_count = 0
|
success_count = 0
|
||||||
total_tests = 0
|
total_tests = 0
|
||||||
|
|
||||||
# Test 1: Check if pytest is installed and can discover tests
|
# Test 1: Check if pytest is installed and can discover tests
|
||||||
total_tests += 1
|
total_tests += 1
|
||||||
if run_command(["python", "-m", "pytest", "--version"], "Checking pytest installation"):
|
if run_command(
|
||||||
|
["python", "-m", "pytest", "--version"], "Checking pytest installation"
|
||||||
|
):
|
||||||
success_count += 1
|
success_count += 1
|
||||||
|
|
||||||
# Test 2: Test discovery
|
# Test 2: Test discovery
|
||||||
total_tests += 1
|
total_tests += 1
|
||||||
if run_command(["python", "-m", "pytest", "tests/", "--collect-only", "-q"], "Testing test discovery"):
|
if run_command(
|
||||||
|
["python", "-m", "pytest", "tests/", "--collect-only", "-q"],
|
||||||
|
"Testing test discovery",
|
||||||
|
):
|
||||||
success_count += 1
|
success_count += 1
|
||||||
|
|
||||||
# Test 3: Check if mypy can run
|
# Test 3: Check if mypy can run
|
||||||
total_tests += 1
|
total_tests += 1
|
||||||
if run_command(["python", "-m", "mypy", "--version"], "Checking mypy installation"):
|
if run_command(["python", "-m", "mypy", "--version"], "Checking mypy installation"):
|
||||||
success_count += 1
|
success_count += 1
|
||||||
|
|
||||||
# Test 4: Run a simple syntax check on test files
|
# Test 4: Run a simple syntax check on test files
|
||||||
total_tests += 1
|
total_tests += 1
|
||||||
if run_command(["python", "-c", "import tests.conftest; print('Test imports work!')"], "Testing test imports"):
|
if run_command(
|
||||||
|
["python", "-c", "import tests.conftest; print('Test imports work!')"],
|
||||||
|
"Testing test imports",
|
||||||
|
):
|
||||||
success_count += 1
|
success_count += 1
|
||||||
|
|
||||||
# Test 5: Check if we can import the main module
|
# Test 5: Check if we can import the main module
|
||||||
total_tests += 1
|
total_tests += 1
|
||||||
if run_command(["python", "-c", "import tradingagents.config; print('Main module imports work!')"], "Testing main module imports"):
|
if run_command(
|
||||||
|
[
|
||||||
|
"python",
|
||||||
|
"-c",
|
||||||
|
"import tradingagents.config; print('Main module imports work!')",
|
||||||
|
],
|
||||||
|
"Testing main module imports",
|
||||||
|
):
|
||||||
success_count += 1
|
success_count += 1
|
||||||
|
|
||||||
# Summary
|
# Summary
|
||||||
print("\n" + "=" * 50)
|
print("\n" + "=" * 50)
|
||||||
print("📊 Test Setup Verification Results:")
|
print("📊 Test Setup Verification Results:")
|
||||||
print(f"✅ Successful: {success_count}/{total_tests}")
|
print(f"✅ Successful: {success_count}/{total_tests}")
|
||||||
print(f"❌ Failed: {total_tests - success_count}/{total_tests}")
|
print(f"❌ Failed: {total_tests - success_count}/{total_tests}")
|
||||||
|
|
||||||
if success_count == total_tests:
|
if success_count == total_tests:
|
||||||
print("\n🎉 All verification tests passed! Your test setup is ready.")
|
print("\n🎉 All verification tests passed! Your test setup is ready.")
|
||||||
print("\n📚 Next steps:")
|
print("\n📚 Next steps:")
|
||||||
|
|
@ -95,4 +110,4 @@ def main():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
|
|
|
||||||
|
|
@ -173,7 +173,8 @@ def mock_memory():
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
"""Configure pytest with custom markers."""
|
"""Configure pytest with custom markers."""
|
||||||
config.addinivalue_line(
|
config.addinivalue_line(
|
||||||
"markers", "integration: mark test as integration test (slow)",
|
"markers",
|
||||||
|
"integration: mark test as integration test (slow)",
|
||||||
)
|
)
|
||||||
config.addinivalue_line("markers", "unit: mark test as unit test (fast)")
|
config.addinivalue_line("markers", "unit: mark test as unit test (fast)")
|
||||||
config.addinivalue_line("markers", "api: mark test as requiring API access")
|
config.addinivalue_line("markers", "api: mark test as requiring API access")
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,8 @@ class SampleDataFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_finnhub_news_data(
|
def create_finnhub_news_data(
|
||||||
ticker: str = "AAPL", count: int = 10,
|
ticker: str = "AAPL",
|
||||||
|
count: int = 10,
|
||||||
) -> dict[str, list[dict[str, Any]]]:
|
) -> dict[str, list[dict[str, Any]]]:
|
||||||
"""Create sample FinnHub news data for testing."""
|
"""Create sample FinnHub news data for testing."""
|
||||||
base_date = datetime(2024, 5, 10)
|
base_date = datetime(2024, 5, 10)
|
||||||
|
|
@ -136,7 +137,8 @@ class SampleDataFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_financial_statements_data(
|
def create_financial_statements_data(
|
||||||
ticker: str = "AAPL", period: str = "annual",
|
ticker: str = "AAPL",
|
||||||
|
period: str = "annual",
|
||||||
) -> dict[str, list[dict[str, Any]]]:
|
) -> dict[str, list[dict[str, Any]]]:
|
||||||
"""Create sample financial statements data for testing."""
|
"""Create sample financial statements data for testing."""
|
||||||
if period == "annual":
|
if period == "annual":
|
||||||
|
|
@ -271,10 +273,12 @@ class SampleDataFactory:
|
||||||
ticker,
|
ticker,
|
||||||
),
|
),
|
||||||
"financial_annual": SampleDataFactory.create_financial_statements_data(
|
"financial_annual": SampleDataFactory.create_financial_statements_data(
|
||||||
ticker, "annual",
|
ticker,
|
||||||
|
"annual",
|
||||||
),
|
),
|
||||||
"financial_quarterly": SampleDataFactory.create_financial_statements_data(
|
"financial_quarterly": SampleDataFactory.create_financial_statements_data(
|
||||||
ticker, "quarterly",
|
ticker,
|
||||||
|
"quarterly",
|
||||||
),
|
),
|
||||||
"social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker),
|
"social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker),
|
||||||
"technical_indicators": SampleDataFactory.create_technical_indicators_data(
|
"technical_indicators": SampleDataFactory.create_technical_indicators_data(
|
||||||
|
|
@ -343,7 +347,9 @@ def save_sample_data_to_files(base_path: str, ticker: str = "AAPL") -> None:
|
||||||
|
|
||||||
# Save quarterly data separately
|
# Save quarterly data separately
|
||||||
quarterly_path = os.path.join(
|
quarterly_path = os.path.join(
|
||||||
finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json",
|
finnhub_path,
|
||||||
|
"fin_as_reported",
|
||||||
|
f"{ticker}_quarterly_data_formatted.json",
|
||||||
)
|
)
|
||||||
with open(quarterly_path, "w") as f:
|
with open(quarterly_path, "w") as f:
|
||||||
json.dump(dataset["financial_quarterly"], f, indent=2)
|
json.dump(dataset["financial_quarterly"], f, indent=2)
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,10 @@ class TestFullWorkflowIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_end_to_end_trading_workflow(
|
def test_end_to_end_trading_workflow(
|
||||||
self, mock_toolkit, mock_chat_openai, integration_config,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
integration_config,
|
||||||
):
|
):
|
||||||
"""Test complete end-to-end trading workflow."""
|
"""Test complete end-to-end trading workflow."""
|
||||||
# Setup mocks
|
# Setup mocks
|
||||||
|
|
@ -86,7 +89,10 @@ class TestFullWorkflowIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_multiple_analysts_integration(
|
def test_multiple_analysts_integration(
|
||||||
self, mock_toolkit, mock_chat_openai, integration_config,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
integration_config,
|
||||||
):
|
):
|
||||||
"""Test integration with different analyst combinations."""
|
"""Test integration with different analyst combinations."""
|
||||||
analyst_combinations = [
|
analyst_combinations = [
|
||||||
|
|
@ -114,7 +120,8 @@ class TestFullWorkflowIntegration:
|
||||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||||
# Test each analyst combination
|
# Test each analyst combination
|
||||||
trading_graph = TradingAgentsGraph(
|
trading_graph = TradingAgentsGraph(
|
||||||
selected_analysts=analysts, config=integration_config,
|
selected_analysts=analysts,
|
||||||
|
config=integration_config,
|
||||||
)
|
)
|
||||||
trading_graph.graph = mock_graph
|
trading_graph.graph = mock_graph
|
||||||
|
|
||||||
|
|
@ -134,7 +141,8 @@ class TestFullWorkflowIntegration:
|
||||||
# Execute
|
# Execute
|
||||||
with patch("builtins.open", create=True), patch("json.dump"):
|
with patch("builtins.open", create=True), patch("json.dump"):
|
||||||
final_state, decision = trading_graph.propagate(
|
final_state, decision = trading_graph.propagate(
|
||||||
"TSLA", "2024-05-15",
|
"TSLA",
|
||||||
|
"2024-05-15",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
|
|
@ -144,7 +152,10 @@ class TestFullWorkflowIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_memory_and_reflection_integration(
|
def test_memory_and_reflection_integration(
|
||||||
self, mock_toolkit, mock_chat_openai, integration_config,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
integration_config,
|
||||||
):
|
):
|
||||||
"""Test integration of memory and reflection components."""
|
"""Test integration of memory and reflection components."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -208,7 +219,10 @@ class TestFullWorkflowIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_debug_mode_integration(
|
def test_debug_mode_integration(
|
||||||
self, mock_toolkit, mock_chat_openai, integration_config,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
integration_config,
|
||||||
):
|
):
|
||||||
"""Test integration in debug mode."""
|
"""Test integration in debug mode."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -240,7 +254,8 @@ class TestFullWorkflowIntegration:
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||||
trading_graph = TradingAgentsGraph(
|
trading_graph = TradingAgentsGraph(
|
||||||
debug=True, config=integration_config,
|
debug=True,
|
||||||
|
config=integration_config,
|
||||||
)
|
)
|
||||||
trading_graph.graph = mock_graph
|
trading_graph.graph = mock_graph
|
||||||
|
|
||||||
|
|
@ -276,7 +291,12 @@ class TestFullWorkflowIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_multiple_stocks_integration(
|
def test_multiple_stocks_integration(
|
||||||
self, mock_toolkit, mock_chat_openai, ticker, date, integration_config,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
ticker,
|
||||||
|
date,
|
||||||
|
integration_config,
|
||||||
):
|
):
|
||||||
"""Test integration with different stocks and dates."""
|
"""Test integration with different stocks and dates."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -382,7 +402,11 @@ class TestPerformanceIntegration:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_multiple_consecutive_runs(
|
def test_multiple_consecutive_runs(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test multiple consecutive trading decisions."""
|
"""Test multiple consecutive trading decisions."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,10 @@ class TestMarketAnalyst:
|
||||||
assert callable(analyst_node)
|
assert callable(analyst_node)
|
||||||
|
|
||||||
def test_market_analyst_node_basic_execution(
|
def test_market_analyst_node_basic_execution(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
self,
|
||||||
|
mock_llm,
|
||||||
|
mock_toolkit,
|
||||||
|
sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test basic execution of market analyst node."""
|
"""Test basic execution of market analyst node."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -39,7 +42,10 @@ class TestMarketAnalyst:
|
||||||
assert result["market_report"] == "Market analysis complete"
|
assert result["market_report"] == "Market analysis complete"
|
||||||
|
|
||||||
def test_market_analyst_uses_online_tools_when_configured(
|
def test_market_analyst_uses_online_tools_when_configured(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
self,
|
||||||
|
mock_llm,
|
||||||
|
mock_toolkit,
|
||||||
|
sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test that analyst uses online tools when configured."""
|
"""Test that analyst uses online tools when configured."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -64,7 +70,10 @@ class TestMarketAnalyst:
|
||||||
assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2
|
assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2
|
||||||
|
|
||||||
def test_market_analyst_uses_offline_tools_when_configured(
|
def test_market_analyst_uses_offline_tools_when_configured(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
self,
|
||||||
|
mock_llm,
|
||||||
|
mock_toolkit,
|
||||||
|
sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test that analyst uses offline tools when configured."""
|
"""Test that analyst uses offline tools when configured."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -88,7 +97,10 @@ class TestMarketAnalyst:
|
||||||
assert len(bound_tools) == 2 # Should have 2 offline tools
|
assert len(bound_tools) == 2 # Should have 2 offline tools
|
||||||
|
|
||||||
def test_market_analyst_processes_state_variables(
|
def test_market_analyst_processes_state_variables(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
self,
|
||||||
|
mock_llm,
|
||||||
|
mock_toolkit,
|
||||||
|
sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test that market analyst correctly processes state variables."""
|
"""Test that market analyst correctly processes state variables."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -112,7 +124,10 @@ class TestMarketAnalyst:
|
||||||
assert result["market_report"] == "Analysis for AAPL on 2024-05-10"
|
assert result["market_report"] == "Analysis for AAPL on 2024-05-10"
|
||||||
|
|
||||||
def test_market_analyst_handles_empty_tool_calls(
|
def test_market_analyst_handles_empty_tool_calls(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
self,
|
||||||
|
mock_llm,
|
||||||
|
mock_toolkit,
|
||||||
|
sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test handling when no tool calls are made."""
|
"""Test handling when no tool calls are made."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -132,7 +147,10 @@ class TestMarketAnalyst:
|
||||||
assert result["messages"] == [mock_result]
|
assert result["messages"] == [mock_result]
|
||||||
|
|
||||||
def test_market_analyst_with_tool_calls(
|
def test_market_analyst_with_tool_calls(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
self,
|
||||||
|
mock_llm,
|
||||||
|
mock_toolkit,
|
||||||
|
sample_agent_state,
|
||||||
):
|
):
|
||||||
"""Test handling when tool calls are present."""
|
"""Test handling when tool calls are present."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
@ -153,7 +171,11 @@ class TestMarketAnalyst:
|
||||||
|
|
||||||
@pytest.mark.parametrize("online_tools", [True, False])
|
@pytest.mark.parametrize("online_tools", [True, False])
|
||||||
def test_market_analyst_tool_configuration(
|
def test_market_analyst_tool_configuration(
|
||||||
self, mock_llm, mock_toolkit, sample_agent_state, online_tools,
|
self,
|
||||||
|
mock_llm,
|
||||||
|
mock_toolkit,
|
||||||
|
sample_agent_state,
|
||||||
|
online_tools,
|
||||||
):
|
):
|
||||||
"""Test tool configuration for both online and offline modes."""
|
"""Test tool configuration for both online and offline modes."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,10 @@ class TestFinnhubUtils:
|
||||||
|
|
||||||
# Test without period
|
# Test without period
|
||||||
expected_path_no_period = os.path.join(
|
expected_path_no_period = os.path.join(
|
||||||
temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
|
temp_data_dir,
|
||||||
|
"finnhub_data",
|
||||||
|
data_type,
|
||||||
|
f"{ticker}_data_formatted.json",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test with period
|
# Test with period
|
||||||
|
|
@ -248,7 +251,10 @@ class TestFinnhubUtils:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_get_data_in_range_various_data_types(
|
def test_get_data_in_range_various_data_types(
|
||||||
self, temp_data_dir, data_type, period,
|
self,
|
||||||
|
temp_data_dir,
|
||||||
|
data_type,
|
||||||
|
period,
|
||||||
):
|
):
|
||||||
"""Test get_data_in_range with various data types."""
|
"""Test get_data_in_range with various data types."""
|
||||||
ticker = "TEST"
|
ticker = "TEST"
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,11 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_init_with_debug(
|
def test_init_with_debug(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test initialization with debug mode enabled."""
|
"""Test initialization with debug mode enabled."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -63,7 +67,11 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatAnthropic")
|
@patch("tradingagents.graph.trading_graph.ChatAnthropic")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_init_with_anthropic(
|
def test_init_with_anthropic(
|
||||||
self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_anthropic,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test initialization with Anthropic LLM provider."""
|
"""Test initialization with Anthropic LLM provider."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -82,7 +90,11 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI")
|
@patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_init_with_google(
|
def test_init_with_google(
|
||||||
self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_google,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test initialization with Google LLM provider."""
|
"""Test initialization with Google LLM provider."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -100,7 +112,10 @@ class TestTradingAgentsGraph:
|
||||||
|
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_init_unsupported_llm_provider(
|
def test_init_unsupported_llm_provider(
|
||||||
self, mock_toolkit, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test initialization with unsupported LLM provider raises error."""
|
"""Test initialization with unsupported LLM provider raises error."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -115,7 +130,11 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_create_tool_nodes(
|
def test_create_tool_nodes(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test creation of tool nodes."""
|
"""Test creation of tool nodes."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -143,7 +162,11 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_propagate_basic(
|
def test_propagate_basic(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test basic propagate functionality."""
|
"""Test basic propagate functionality."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -206,7 +229,11 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_propagate_debug_mode(
|
def test_propagate_debug_mode(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test propagate in debug mode."""
|
"""Test propagate in debug mode."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -245,7 +272,11 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_log_state(
|
def test_log_state(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test state logging functionality."""
|
"""Test state logging functionality."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -300,7 +331,11 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_reflect_and_remember(
|
def test_reflect_and_remember(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test reflection and memory update functionality."""
|
"""Test reflection and memory update functionality."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -309,9 +344,12 @@ class TestTradingAgentsGraph:
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
mock_toolkit.return_value = mock_toolkit_instance
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
patch(
|
||||||
), patch("tradingagents.graph.trading_graph.set_config"):
|
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
||||||
|
),
|
||||||
|
patch("tradingagents.graph.trading_graph.set_config"),
|
||||||
|
):
|
||||||
graph = TradingAgentsGraph(config=sample_config)
|
graph = TradingAgentsGraph(config=sample_config)
|
||||||
|
|
||||||
# Set up current state
|
# Set up current state
|
||||||
|
|
@ -339,7 +377,11 @@ class TestTradingAgentsGraph:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_process_signal(
|
def test_process_signal(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
sample_config,
|
||||||
|
temp_data_dir,
|
||||||
):
|
):
|
||||||
"""Test signal processing functionality."""
|
"""Test signal processing functionality."""
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
|
|
@ -388,7 +430,8 @@ class TestTradingAgentsGraph:
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||||
TradingAgentsGraph(
|
TradingAgentsGraph(
|
||||||
selected_analysts=selected_analysts, config=sample_config,
|
selected_analysts=selected_analysts,
|
||||||
|
config=sample_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify graph was set up with selected analysts
|
# Verify graph was set up with selected analysts
|
||||||
|
|
@ -416,7 +459,10 @@ class TestTradingAgentsGraphErrorHandling:
|
||||||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||||
def test_directory_creation_failure(
|
def test_directory_creation_failure(
|
||||||
self, mock_toolkit, mock_chat_openai, sample_config,
|
self,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_chat_openai,
|
||||||
|
sample_config,
|
||||||
):
|
):
|
||||||
"""Test handling when directory creation fails."""
|
"""Test handling when directory creation fails."""
|
||||||
sample_config["project_dir"] = "/invalid/path/that/cannot/be/created"
|
sample_config["project_dir"] = "/invalid/path/that/cannot/be/created"
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ def create_fundamentals_analyst(llm, toolkit):
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||||
" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
|
" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ Volume-Based Indicators:
|
||||||
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
||||||
|
|
||||||
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
|
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
|
||||||
""" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
""" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ def create_news_analyst(llm, toolkit):
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||||
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ def create_social_media_analyst(llm, toolkit):
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||||
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,8 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
|
||||||
"current_risky_response": argument,
|
"current_risky_response": argument,
|
||||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||||
"current_neutral_response": risk_debate_state.get(
|
"current_neutral_response": risk_debate_state.get(
|
||||||
"current_neutral_response", "",
|
"current_neutral_response",
|
||||||
|
"",
|
||||||
),
|
),
|
||||||
"count": risk_debate_state["count"] + 1,
|
"count": risk_debate_state["count"] + 1,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,11 +39,13 @@ Engage by questioning their optimism and emphasizing the potential downsides the
|
||||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||||
"latest_speaker": "Safe",
|
"latest_speaker": "Safe",
|
||||||
"current_risky_response": risk_debate_state.get(
|
"current_risky_response": risk_debate_state.get(
|
||||||
"current_risky_response", "",
|
"current_risky_response",
|
||||||
|
"",
|
||||||
),
|
),
|
||||||
"current_safe_response": argument,
|
"current_safe_response": argument,
|
||||||
"current_neutral_response": risk_debate_state.get(
|
"current_neutral_response": risk_debate_state.get(
|
||||||
"current_neutral_response", "",
|
"current_neutral_response",
|
||||||
|
"",
|
||||||
),
|
),
|
||||||
"count": risk_debate_state["count"] + 1,
|
"count": risk_debate_state["count"] + 1,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,8 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
|
||||||
"neutral_history": neutral_history + "\n" + argument,
|
"neutral_history": neutral_history + "\n" + argument,
|
||||||
"latest_speaker": "Neutral",
|
"latest_speaker": "Neutral",
|
||||||
"current_risky_response": risk_debate_state.get(
|
"current_risky_response": risk_debate_state.get(
|
||||||
"current_risky_response", "",
|
"current_risky_response",
|
||||||
|
"",
|
||||||
),
|
),
|
||||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||||
"current_neutral_response": argument,
|
"current_neutral_response": argument,
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,12 @@ from typing_extensions import TypedDict
|
||||||
# Researcher team state
|
# Researcher team state
|
||||||
class InvestDebateState(TypedDict):
|
class InvestDebateState(TypedDict):
|
||||||
bull_history: Annotated[
|
bull_history: Annotated[
|
||||||
str, "Bullish Conversation history",
|
str,
|
||||||
|
"Bullish Conversation history",
|
||||||
] # Bullish Conversation history
|
] # Bullish Conversation history
|
||||||
bear_history: Annotated[
|
bear_history: Annotated[
|
||||||
str, "Bearish Conversation history",
|
str,
|
||||||
|
"Bearish Conversation history",
|
||||||
] # Bullish Conversation history
|
] # Bullish Conversation history
|
||||||
history: Annotated[str, "Conversation history"] # Conversation history
|
history: Annotated[str, "Conversation history"] # Conversation history
|
||||||
current_response: Annotated[str, "Latest response"] # Last response
|
current_response: Annotated[str, "Latest response"] # Last response
|
||||||
|
|
@ -23,24 +25,30 @@ class InvestDebateState(TypedDict):
|
||||||
# Risk management team state
|
# Risk management team state
|
||||||
class RiskDebateState(TypedDict):
|
class RiskDebateState(TypedDict):
|
||||||
risky_history: Annotated[
|
risky_history: Annotated[
|
||||||
str, "Risky Agent's Conversation history",
|
str,
|
||||||
|
"Risky Agent's Conversation history",
|
||||||
] # Conversation history
|
] # Conversation history
|
||||||
safe_history: Annotated[
|
safe_history: Annotated[
|
||||||
str, "Safe Agent's Conversation history",
|
str,
|
||||||
|
"Safe Agent's Conversation history",
|
||||||
] # Conversation history
|
] # Conversation history
|
||||||
neutral_history: Annotated[
|
neutral_history: Annotated[
|
||||||
str, "Neutral Agent's Conversation history",
|
str,
|
||||||
|
"Neutral Agent's Conversation history",
|
||||||
] # Conversation history
|
] # Conversation history
|
||||||
history: Annotated[str, "Conversation history"] # Conversation history
|
history: Annotated[str, "Conversation history"] # Conversation history
|
||||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
||||||
current_risky_response: Annotated[
|
current_risky_response: Annotated[
|
||||||
str, "Latest response by the risky analyst",
|
str,
|
||||||
|
"Latest response by the risky analyst",
|
||||||
] # Last response
|
] # Last response
|
||||||
current_safe_response: Annotated[
|
current_safe_response: Annotated[
|
||||||
str, "Latest response by the safe analyst",
|
str,
|
||||||
|
"Latest response by the safe analyst",
|
||||||
] # Last response
|
] # Last response
|
||||||
current_neutral_response: Annotated[
|
current_neutral_response: Annotated[
|
||||||
str, "Latest response by the neutral analyst",
|
str,
|
||||||
|
"Latest response by the neutral analyst",
|
||||||
] # Last response
|
] # Last response
|
||||||
judge_decision: Annotated[str, "Judge's decision"]
|
judge_decision: Annotated[str, "Judge's decision"]
|
||||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||||
|
|
@ -56,13 +64,15 @@ class AgentState(MessagesState):
|
||||||
market_report: Annotated[str, "Report from the Market Analyst"]
|
market_report: Annotated[str, "Report from the Market Analyst"]
|
||||||
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
||||||
news_report: Annotated[
|
news_report: Annotated[
|
||||||
str, "Report from the News Researcher of current world affairs",
|
str,
|
||||||
|
"Report from the News Researcher of current world affairs",
|
||||||
]
|
]
|
||||||
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
||||||
|
|
||||||
# researcher team discussion step
|
# researcher team discussion step
|
||||||
investment_debate_state: Annotated[
|
investment_debate_state: Annotated[
|
||||||
InvestDebateState, "Current state of the debate on if to invest or not",
|
InvestDebateState,
|
||||||
|
"Current state of the debate on if to invest or not",
|
||||||
]
|
]
|
||||||
investment_plan: Annotated[str, "Plan generated by the Analyst"]
|
investment_plan: Annotated[str, "Plan generated by the Analyst"]
|
||||||
|
|
||||||
|
|
@ -70,6 +80,7 @@ class AgentState(MessagesState):
|
||||||
|
|
||||||
# risk management team discussion step
|
# risk management team discussion step
|
||||||
risk_debate_state: Annotated[
|
risk_debate_state: Annotated[
|
||||||
RiskDebateState, "Current state of the debate on evaluating risk",
|
RiskDebateState,
|
||||||
|
"Current state of the debate on evaluating risk",
|
||||||
]
|
]
|
||||||
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,6 @@ class Toolkit:
|
||||||
|
|
||||||
return interface.get_reddit_global_news(curr_date, 7, 5)
|
return interface.get_reddit_global_news(curr_date, 7, 5)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_finnhub_news(
|
def get_finnhub_news(
|
||||||
|
|
@ -84,10 +83,11 @@ class Toolkit:
|
||||||
look_back_days = (end_date - start_date).days
|
look_back_days = (end_date - start_date).days
|
||||||
|
|
||||||
return interface.get_finnhub_news(
|
return interface.get_finnhub_news(
|
||||||
ticker, end_date_str, look_back_days,
|
ticker,
|
||||||
|
end_date_str,
|
||||||
|
look_back_days,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_reddit_stock_info(
|
def get_reddit_stock_info(
|
||||||
|
|
@ -108,7 +108,6 @@ class Toolkit:
|
||||||
|
|
||||||
return interface.get_reddit_company_news(ticker, curr_date, 7, 5)
|
return interface.get_reddit_company_news(ticker, curr_date, 7, 5)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_YFin_data(
|
def get_YFin_data(
|
||||||
|
|
@ -128,7 +127,6 @@ class Toolkit:
|
||||||
|
|
||||||
return interface.get_YFin_data(symbol, start_date, end_date)
|
return interface.get_YFin_data(symbol, start_date, end_date)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_YFin_data_online(
|
def get_YFin_data_online(
|
||||||
|
|
@ -148,16 +146,17 @@ class Toolkit:
|
||||||
|
|
||||||
return interface.get_YFin_data_online(symbol, start_date, end_date)
|
return interface.get_YFin_data_online(symbol, start_date, end_date)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_stockstats_indicators_report(
|
def get_stockstats_indicators_report(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
indicator: Annotated[
|
indicator: Annotated[
|
||||||
str, "technical indicator to get the analysis and report of",
|
str,
|
||||||
|
"technical indicator to get the analysis and report of",
|
||||||
],
|
],
|
||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd",
|
str,
|
||||||
|
"The current trading date you are trading on, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
@ -173,19 +172,24 @@ class Toolkit:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return interface.get_stock_stats_indicators_window(
|
return interface.get_stock_stats_indicators_window(
|
||||||
symbol, indicator, curr_date, look_back_days, False,
|
symbol,
|
||||||
|
indicator,
|
||||||
|
curr_date,
|
||||||
|
look_back_days,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_stockstats_indicators_report_online(
|
def get_stockstats_indicators_report_online(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
indicator: Annotated[
|
indicator: Annotated[
|
||||||
str, "technical indicator to get the analysis and report of",
|
str,
|
||||||
|
"technical indicator to get the analysis and report of",
|
||||||
],
|
],
|
||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd",
|
str,
|
||||||
|
"The current trading date you are trading on, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
@ -201,10 +205,13 @@ class Toolkit:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return interface.get_stock_stats_indicators_window(
|
return interface.get_stock_stats_indicators_window(
|
||||||
symbol, indicator, curr_date, look_back_days, True,
|
symbol,
|
||||||
|
indicator,
|
||||||
|
curr_date,
|
||||||
|
look_back_days,
|
||||||
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_finnhub_company_insider_sentiment(
|
def get_finnhub_company_insider_sentiment(
|
||||||
|
|
@ -224,10 +231,11 @@ class Toolkit:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return interface.get_finnhub_company_insider_sentiment(
|
return interface.get_finnhub_company_insider_sentiment(
|
||||||
ticker, curr_date, 30,
|
ticker,
|
||||||
|
curr_date,
|
||||||
|
30,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_finnhub_company_insider_transactions(
|
def get_finnhub_company_insider_transactions(
|
||||||
|
|
@ -247,10 +255,11 @@ class Toolkit:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return interface.get_finnhub_company_insider_transactions(
|
return interface.get_finnhub_company_insider_transactions(
|
||||||
ticker, curr_date, 30,
|
ticker,
|
||||||
|
curr_date,
|
||||||
|
30,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_simfin_balance_sheet(
|
def get_simfin_balance_sheet(
|
||||||
|
|
@ -273,7 +282,6 @@ class Toolkit:
|
||||||
|
|
||||||
return interface.get_simfin_balance_sheet(ticker, freq, curr_date)
|
return interface.get_simfin_balance_sheet(ticker, freq, curr_date)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_simfin_cashflow(
|
def get_simfin_cashflow(
|
||||||
|
|
@ -296,7 +304,6 @@ class Toolkit:
|
||||||
|
|
||||||
return interface.get_simfin_cashflow(ticker, freq, curr_date)
|
return interface.get_simfin_cashflow(ticker, freq, curr_date)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_simfin_income_stmt(
|
def get_simfin_income_stmt(
|
||||||
|
|
@ -318,10 +325,11 @@ class Toolkit:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return interface.get_simfin_income_statements(
|
return interface.get_simfin_income_statements(
|
||||||
ticker, freq, curr_date,
|
ticker,
|
||||||
|
freq,
|
||||||
|
curr_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_google_news(
|
def get_google_news(
|
||||||
|
|
@ -340,7 +348,6 @@ class Toolkit:
|
||||||
|
|
||||||
return interface.get_google_news(query, curr_date, 7)
|
return interface.get_google_news(query, curr_date, 7)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_stock_news_openai(
|
def get_stock_news_openai(
|
||||||
|
|
@ -358,7 +365,6 @@ class Toolkit:
|
||||||
|
|
||||||
return interface.get_stock_news_openai(ticker, curr_date)
|
return interface.get_stock_news_openai(ticker, curr_date)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_global_news_openai(
|
def get_global_news_openai(
|
||||||
|
|
@ -374,7 +380,6 @@ class Toolkit:
|
||||||
|
|
||||||
return interface.get_global_news_openai(curr_date)
|
return interface.get_global_news_openai(curr_date)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tool
|
@tool
|
||||||
def get_fundamentals_openai(
|
def get_fundamentals_openai(
|
||||||
|
|
@ -391,6 +396,6 @@ class Toolkit:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return interface.get_fundamentals_openai(
|
return interface.get_fundamentals_openai(
|
||||||
ticker, curr_date,
|
ticker,
|
||||||
|
curr_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,8 @@ def get_config():
|
||||||
"project_dir": str(project_root / "tradingagents"),
|
"project_dir": str(project_root / "tradingagents"),
|
||||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||||
"data_dir": os.getenv(
|
"data_dir": os.getenv(
|
||||||
"TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
"TRADINGAGENTS_DATA_DIR",
|
||||||
|
"/Users/yluo/Documents/Code/ScAI/FR1-data",
|
||||||
),
|
),
|
||||||
"data_cache_dir": str(
|
"data_cache_dir": str(
|
||||||
project_root / "tradingagents" / "dataflows" / "data_cache",
|
project_root / "tradingagents" / "dataflows" / "data_cache",
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
from tradingagents import default_config
|
from tradingagents import default_config
|
||||||
|
|
||||||
# Use default config but allow it to be overridden
|
# Use default config but allow it to be overridden
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,10 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_path = os.path.join(
|
data_path = os.path.join(
|
||||||
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
|
data_dir,
|
||||||
|
"finnhub_data",
|
||||||
|
data_type,
|
||||||
|
f"{ticker}_data_formatted.json",
|
||||||
)
|
)
|
||||||
|
|
||||||
data = open(data_path)
|
data = open(data_path)
|
||||||
|
|
|
||||||
|
|
@ -419,7 +419,8 @@ def get_stock_stats_indicators_window(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd",
|
str,
|
||||||
|
"The current trading date you are trading on, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
look_back_days: Annotated[int, "how many days to look back"],
|
||||||
online: Annotated[bool, "to fetch data online or offline"],
|
online: Annotated[bool, "to fetch data online or offline"],
|
||||||
|
|
@ -524,7 +525,10 @@ def get_stock_stats_indicators_window(
|
||||||
# only do the trading dates
|
# only do the trading dates
|
||||||
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
|
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
|
||||||
indicator_value = get_stockstats_indicator(
|
indicator_value = get_stockstats_indicator(
|
||||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
|
symbol,
|
||||||
|
indicator,
|
||||||
|
curr_date.strftime("%Y-%m-%d"),
|
||||||
|
online,
|
||||||
)
|
)
|
||||||
|
|
||||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||||
|
|
@ -535,7 +539,10 @@ def get_stock_stats_indicators_window(
|
||||||
ind_string = ""
|
ind_string = ""
|
||||||
while curr_date >= before:
|
while curr_date >= before:
|
||||||
indicator_value = get_stockstats_indicator(
|
indicator_value = get_stockstats_indicator(
|
||||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
|
symbol,
|
||||||
|
indicator,
|
||||||
|
curr_date.strftime("%Y-%m-%d"),
|
||||||
|
online,
|
||||||
)
|
)
|
||||||
|
|
||||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||||
|
|
@ -550,12 +557,12 @@ def get_stock_stats_indicators_window(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_stockstats_indicator(
|
def get_stockstats_indicator(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd",
|
str,
|
||||||
|
"The current trading date you are trading on, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
online: Annotated[bool, "to fetch data online or offline"],
|
online: Annotated[bool, "to fetch data online or offline"],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
@ -608,7 +615,12 @@ def get_YFin_data_window(
|
||||||
|
|
||||||
# Set pandas display options to show the full DataFrame
|
# Set pandas display options to show the full DataFrame
|
||||||
with pd.option_context(
|
with pd.option_context(
|
||||||
"display.max_rows", None, "display.max_columns", None, "display.width", None,
|
"display.max_rows",
|
||||||
|
None,
|
||||||
|
"display.max_columns",
|
||||||
|
None,
|
||||||
|
"display.width",
|
||||||
|
None,
|
||||||
):
|
):
|
||||||
df_string = filtered_data.to_string()
|
df_string = filtered_data.to_string()
|
||||||
|
|
||||||
|
|
@ -694,7 +706,6 @@ def get_YFin_data(
|
||||||
return filtered_data.reset_index(drop=True)
|
return filtered_data.reset_index(drop=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_stock_news_openai(ticker, curr_date):
|
def get_stock_news_openai(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
client = OpenAI(base_url=config["backend_url"])
|
client = OpenAI(base_url=config["backend_url"])
|
||||||
|
|
|
||||||
|
|
@ -48,11 +48,14 @@ ticker_to_company = {
|
||||||
|
|
||||||
def fetch_top_from_category(
|
def fetch_top_from_category(
|
||||||
category: Annotated[
|
category: Annotated[
|
||||||
str, "Category to fetch top post from. Collection of subreddits.",
|
str,
|
||||||
|
"Category to fetch top post from. Collection of subreddits.",
|
||||||
],
|
],
|
||||||
date: Annotated[str, "Date to fetch top posts from."],
|
date: Annotated[str, "Date to fetch top posts from."],
|
||||||
max_limit: Annotated[int, "Maximum number of posts to fetch."],
|
max_limit: Annotated[int, "Maximum number of posts to fetch."],
|
||||||
query: Annotated[str | None, "Optional query to search for in the subreddit."] = None,
|
query: Annotated[
|
||||||
|
str | None, "Optional query to search for in the subreddit."
|
||||||
|
] = None,
|
||||||
data_path: Annotated[
|
data_path: Annotated[
|
||||||
str,
|
str,
|
||||||
"Path to the data folder. Default is 'reddit_data'.",
|
"Path to the data folder. Default is 'reddit_data'.",
|
||||||
|
|
@ -107,7 +110,9 @@ def fetch_top_from_category(
|
||||||
found = False
|
found = False
|
||||||
for term in search_terms:
|
for term in search_terms:
|
||||||
if re.search(
|
if re.search(
|
||||||
term, parsed_line["title"], re.IGNORECASE,
|
term,
|
||||||
|
parsed_line["title"],
|
||||||
|
re.IGNORECASE,
|
||||||
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
|
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
|
||||||
found = True
|
found = True
|
||||||
break
|
break
|
||||||
|
|
|
||||||
|
|
@ -13,10 +13,12 @@ class StockstatsUtils:
|
||||||
def get_stock_stats(
|
def get_stock_stats(
|
||||||
symbol: Annotated[str, "ticker symbol for the company"],
|
symbol: Annotated[str, "ticker symbol for the company"],
|
||||||
indicator: Annotated[
|
indicator: Annotated[
|
||||||
str, "quantitative indicators based off of the stock data for the company",
|
str,
|
||||||
|
"quantitative indicators based off of the stock data for the company",
|
||||||
],
|
],
|
||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "curr date for retrieving stock price data, YYYY-mm-dd",
|
str,
|
||||||
|
"curr date for retrieving stock price data, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
data_dir: Annotated[
|
data_dir: Annotated[
|
||||||
str,
|
str,
|
||||||
|
|
|
||||||
|
|
@ -28,10 +28,12 @@ class YFinanceUtils:
|
||||||
def get_stock_data(
|
def get_stock_data(
|
||||||
self: Annotated[str, "ticker symbol"],
|
self: Annotated[str, "ticker symbol"],
|
||||||
start_date: Annotated[
|
start_date: Annotated[
|
||||||
str, "start date for retrieving stock price data, YYYY-mm-dd",
|
str,
|
||||||
|
"start date for retrieving stock price data, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
end_date: Annotated[
|
end_date: Annotated[
|
||||||
str, "end date for retrieving stock price data, YYYY-mm-dd",
|
str,
|
||||||
|
"end date for retrieving stock price data, YYYY-mm-dd",
|
||||||
],
|
],
|
||||||
save_path: SavePathType = None,
|
save_path: SavePathType = None,
|
||||||
) -> DataFrame:
|
) -> DataFrame:
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,9 @@ class Propagator:
|
||||||
self.max_recur_limit = max_recur_limit
|
self.max_recur_limit = max_recur_limit
|
||||||
|
|
||||||
def create_initial_state(
|
def create_initial_state(
|
||||||
self, company_name: str, trade_date: str,
|
self,
|
||||||
|
company_name: str,
|
||||||
|
trade_date: str,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Create the initial state for the agent graph."""
|
"""Create the initial state for the agent graph."""
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,11 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
|
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
|
||||||
|
|
||||||
def _reflect_on_component(
|
def _reflect_on_component(
|
||||||
self, component_type: str, report: str, situation: str, returns_losses,
|
self,
|
||||||
|
component_type: str,
|
||||||
|
report: str,
|
||||||
|
situation: str,
|
||||||
|
returns_losses,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate reflection for a component."""
|
"""Generate reflection for a component."""
|
||||||
messages = [
|
messages = [
|
||||||
|
|
@ -76,7 +80,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
|
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
result = self._reflect_on_component(
|
||||||
"BULL", bull_debate_history, situation, returns_losses,
|
"BULL",
|
||||||
|
bull_debate_history,
|
||||||
|
situation,
|
||||||
|
returns_losses,
|
||||||
)
|
)
|
||||||
bull_memory.add_situations([(situation, result)])
|
bull_memory.add_situations([(situation, result)])
|
||||||
|
|
||||||
|
|
@ -86,7 +93,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
|
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
result = self._reflect_on_component(
|
||||||
"BEAR", bear_debate_history, situation, returns_losses,
|
"BEAR",
|
||||||
|
bear_debate_history,
|
||||||
|
situation,
|
||||||
|
returns_losses,
|
||||||
)
|
)
|
||||||
bear_memory.add_situations([(situation, result)])
|
bear_memory.add_situations([(situation, result)])
|
||||||
|
|
||||||
|
|
@ -96,7 +106,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
trader_decision = current_state["trader_investment_plan"]
|
trader_decision = current_state["trader_investment_plan"]
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
result = self._reflect_on_component(
|
||||||
"TRADER", trader_decision, situation, returns_losses,
|
"TRADER",
|
||||||
|
trader_decision,
|
||||||
|
situation,
|
||||||
|
returns_losses,
|
||||||
)
|
)
|
||||||
trader_memory.add_situations([(situation, result)])
|
trader_memory.add_situations([(situation, result)])
|
||||||
|
|
||||||
|
|
@ -106,7 +119,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
judge_decision = current_state["investment_debate_state"]["judge_decision"]
|
judge_decision = current_state["investment_debate_state"]["judge_decision"]
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
result = self._reflect_on_component(
|
||||||
"INVEST JUDGE", judge_decision, situation, returns_losses,
|
"INVEST JUDGE",
|
||||||
|
judge_decision,
|
||||||
|
situation,
|
||||||
|
returns_losses,
|
||||||
)
|
)
|
||||||
invest_judge_memory.add_situations([(situation, result)])
|
invest_judge_memory.add_situations([(situation, result)])
|
||||||
|
|
||||||
|
|
@ -116,6 +132,9 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
||||||
judge_decision = current_state["risk_debate_state"]["judge_decision"]
|
judge_decision = current_state["risk_debate_state"]["judge_decision"]
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
result = self._reflect_on_component(
|
||||||
"RISK JUDGE", judge_decision, situation, returns_losses,
|
"RISK JUDGE",
|
||||||
|
judge_decision,
|
||||||
|
situation,
|
||||||
|
returns_losses,
|
||||||
)
|
)
|
||||||
risk_manager_memory.add_situations([(situation, result)])
|
risk_manager_memory.add_situations([(situation, result)])
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,8 @@ class GraphSetup:
|
||||||
self.conditional_logic = conditional_logic
|
self.conditional_logic = conditional_logic
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
self, selected_analysts=None,
|
self,
|
||||||
|
selected_analysts=None,
|
||||||
):
|
):
|
||||||
"""Set up and compile the agent workflow graph.
|
"""Set up and compile the agent workflow graph.
|
||||||
|
|
||||||
|
|
@ -79,41 +80,48 @@ class GraphSetup:
|
||||||
|
|
||||||
if "market" in selected_analysts:
|
if "market" in selected_analysts:
|
||||||
analyst_nodes["market"] = create_market_analyst(
|
analyst_nodes["market"] = create_market_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit,
|
self.quick_thinking_llm,
|
||||||
|
self.toolkit,
|
||||||
)
|
)
|
||||||
delete_nodes["market"] = create_msg_delete()
|
delete_nodes["market"] = create_msg_delete()
|
||||||
tool_nodes["market"] = self.tool_nodes["market"]
|
tool_nodes["market"] = self.tool_nodes["market"]
|
||||||
|
|
||||||
if "social" in selected_analysts:
|
if "social" in selected_analysts:
|
||||||
analyst_nodes["social"] = create_social_media_analyst(
|
analyst_nodes["social"] = create_social_media_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit,
|
self.quick_thinking_llm,
|
||||||
|
self.toolkit,
|
||||||
)
|
)
|
||||||
delete_nodes["social"] = create_msg_delete()
|
delete_nodes["social"] = create_msg_delete()
|
||||||
tool_nodes["social"] = self.tool_nodes["social"]
|
tool_nodes["social"] = self.tool_nodes["social"]
|
||||||
|
|
||||||
if "news" in selected_analysts:
|
if "news" in selected_analysts:
|
||||||
analyst_nodes["news"] = create_news_analyst(
|
analyst_nodes["news"] = create_news_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit,
|
self.quick_thinking_llm,
|
||||||
|
self.toolkit,
|
||||||
)
|
)
|
||||||
delete_nodes["news"] = create_msg_delete()
|
delete_nodes["news"] = create_msg_delete()
|
||||||
tool_nodes["news"] = self.tool_nodes["news"]
|
tool_nodes["news"] = self.tool_nodes["news"]
|
||||||
|
|
||||||
if "fundamentals" in selected_analysts:
|
if "fundamentals" in selected_analysts:
|
||||||
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit,
|
self.quick_thinking_llm,
|
||||||
|
self.toolkit,
|
||||||
)
|
)
|
||||||
delete_nodes["fundamentals"] = create_msg_delete()
|
delete_nodes["fundamentals"] = create_msg_delete()
|
||||||
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
||||||
|
|
||||||
# Create researcher and manager nodes
|
# Create researcher and manager nodes
|
||||||
bull_researcher_node = create_bull_researcher(
|
bull_researcher_node = create_bull_researcher(
|
||||||
self.quick_thinking_llm, self.bull_memory,
|
self.quick_thinking_llm,
|
||||||
|
self.bull_memory,
|
||||||
)
|
)
|
||||||
bear_researcher_node = create_bear_researcher(
|
bear_researcher_node = create_bear_researcher(
|
||||||
self.quick_thinking_llm, self.bear_memory,
|
self.quick_thinking_llm,
|
||||||
|
self.bear_memory,
|
||||||
)
|
)
|
||||||
research_manager_node = create_research_manager(
|
research_manager_node = create_research_manager(
|
||||||
self.deep_thinking_llm, self.invest_judge_memory,
|
self.deep_thinking_llm,
|
||||||
|
self.invest_judge_memory,
|
||||||
)
|
)
|
||||||
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
||||||
|
|
||||||
|
|
@ -122,7 +130,8 @@ class GraphSetup:
|
||||||
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
||||||
safe_analyst = create_safe_debator(self.quick_thinking_llm)
|
safe_analyst = create_safe_debator(self.quick_thinking_llm)
|
||||||
risk_manager_node = create_risk_manager(
|
risk_manager_node = create_risk_manager(
|
||||||
self.deep_thinking_llm, self.risk_manager_memory,
|
self.deep_thinking_llm,
|
||||||
|
self.risk_manager_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create workflow
|
# Create workflow
|
||||||
|
|
@ -132,7 +141,8 @@ class GraphSetup:
|
||||||
for analyst_type, node in analyst_nodes.items():
|
for analyst_type, node in analyst_nodes.items():
|
||||||
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
||||||
workflow.add_node(
|
workflow.add_node(
|
||||||
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type],
|
f"Msg Clear {analyst_type.capitalize()}",
|
||||||
|
delete_nodes[analyst_type],
|
||||||
)
|
)
|
||||||
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,8 @@ class TradingAgentsGraph:
|
||||||
or self.config["llm_provider"] == "openrouter"
|
or self.config["llm_provider"] == "openrouter"
|
||||||
):
|
):
|
||||||
self.deep_thinking_llm = ChatOpenAI(
|
self.deep_thinking_llm = ChatOpenAI(
|
||||||
model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
|
model=self.config["deep_think_llm"],
|
||||||
|
base_url=self.config["backend_url"],
|
||||||
)
|
)
|
||||||
self.quick_thinking_llm = ChatOpenAI(
|
self.quick_thinking_llm = ChatOpenAI(
|
||||||
model=self.config["quick_think_llm"],
|
model=self.config["quick_think_llm"],
|
||||||
|
|
@ -67,7 +68,8 @@ class TradingAgentsGraph:
|
||||||
)
|
)
|
||||||
elif self.config["llm_provider"].lower() == "anthropic":
|
elif self.config["llm_provider"].lower() == "anthropic":
|
||||||
self.deep_thinking_llm = ChatAnthropic(
|
self.deep_thinking_llm = ChatAnthropic(
|
||||||
model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
|
model=self.config["deep_think_llm"],
|
||||||
|
base_url=self.config["backend_url"],
|
||||||
)
|
)
|
||||||
self.quick_thinking_llm = ChatAnthropic(
|
self.quick_thinking_llm = ChatAnthropic(
|
||||||
model=self.config["quick_think_llm"],
|
model=self.config["quick_think_llm"],
|
||||||
|
|
@ -91,10 +93,12 @@ class TradingAgentsGraph:
|
||||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||||
self.invest_judge_memory = FinancialSituationMemory(
|
self.invest_judge_memory = FinancialSituationMemory(
|
||||||
"invest_judge_memory", self.config,
|
"invest_judge_memory",
|
||||||
|
self.config,
|
||||||
)
|
)
|
||||||
self.risk_manager_memory = FinancialSituationMemory(
|
self.risk_manager_memory = FinancialSituationMemory(
|
||||||
"risk_manager_memory", self.config,
|
"risk_manager_memory",
|
||||||
|
self.config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create tool nodes
|
# Create tool nodes
|
||||||
|
|
@ -179,7 +183,8 @@ class TradingAgentsGraph:
|
||||||
|
|
||||||
# Initialize state
|
# Initialize state
|
||||||
init_agent_state = self.propagator.create_initial_state(
|
init_agent_state = self.propagator.create_initial_state(
|
||||||
company_name, trade_date,
|
company_name,
|
||||||
|
trade_date,
|
||||||
)
|
)
|
||||||
args = self.propagator.get_graph_args()
|
args = self.propagator.get_graph_args()
|
||||||
|
|
||||||
|
|
@ -252,19 +257,29 @@ class TradingAgentsGraph:
|
||||||
def reflect_and_remember(self, returns_losses):
|
def reflect_and_remember(self, returns_losses):
|
||||||
"""Reflect on decisions and update memory based on returns."""
|
"""Reflect on decisions and update memory based on returns."""
|
||||||
self.reflector.reflect_bull_researcher(
|
self.reflector.reflect_bull_researcher(
|
||||||
self.curr_state, returns_losses, self.bull_memory,
|
self.curr_state,
|
||||||
|
returns_losses,
|
||||||
|
self.bull_memory,
|
||||||
)
|
)
|
||||||
self.reflector.reflect_bear_researcher(
|
self.reflector.reflect_bear_researcher(
|
||||||
self.curr_state, returns_losses, self.bear_memory,
|
self.curr_state,
|
||||||
|
returns_losses,
|
||||||
|
self.bear_memory,
|
||||||
)
|
)
|
||||||
self.reflector.reflect_trader(
|
self.reflector.reflect_trader(
|
||||||
self.curr_state, returns_losses, self.trader_memory,
|
self.curr_state,
|
||||||
|
returns_losses,
|
||||||
|
self.trader_memory,
|
||||||
)
|
)
|
||||||
self.reflector.reflect_invest_judge(
|
self.reflector.reflect_invest_judge(
|
||||||
self.curr_state, returns_losses, self.invest_judge_memory,
|
self.curr_state,
|
||||||
|
returns_losses,
|
||||||
|
self.invest_judge_memory,
|
||||||
)
|
)
|
||||||
self.reflector.reflect_risk_manager(
|
self.reflector.reflect_risk_manager(
|
||||||
self.curr_state, returns_losses, self.risk_manager_memory,
|
self.curr_state,
|
||||||
|
returns_losses,
|
||||||
|
self.risk_manager_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_signal(self, full_signal):
|
def process_signal(self, full_signal):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue