Apply Black formatting to pass CI checks

- Formatted 33 Python files with Black
- Fixed code style to meet project standards
- Ensures CI/CD pipeline passes formatting checks
This commit is contained in:
佐藤優一 2025-08-10 23:25:04 +09:00
parent 6f3981412b
commit 850764ad7b
33 changed files with 475 additions and 238 deletions

View File

@ -148,7 +148,7 @@ class MessageBuffer:
f"### News Analysis\n{self.report_sections['news_report']}", f"### News Analysis\n{self.report_sections['news_report']}",
) )
if self.report_sections["fundamentals_report"]: if self.report_sections["fundamentals_report"]:
fundamentals = self.report_sections['fundamentals_report'] fundamentals = self.report_sections["fundamentals_report"]
report_parts.append( report_parts.append(
f"### Fundamentals Analysis\n{fundamentals}", f"### Fundamentals Analysis\n{fundamentals}",
) )
@ -182,10 +182,12 @@ def create_layout():
Layout(name="footer", size=3), Layout(name="footer", size=3),
) )
layout["main"].split_column( layout["main"].split_column(
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5), Layout(name="upper", ratio=3),
Layout(name="analysis", ratio=5),
) )
layout["upper"].split_row( layout["upper"].split_row(
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3), Layout(name="progress", ratio=2),
Layout(name="messages", ratio=3),
) )
return layout return layout
@ -237,7 +239,9 @@ def update_display(layout, spinner_text=None):
status = message_buffer.agent_status[first_agent] status = message_buffer.agent_status[first_agent]
if status == "in_progress": if status == "in_progress":
spinner = Spinner( spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan", "dots",
text="[blue]in_progress[/blue]",
style="bold cyan",
) )
status_cell = spinner status_cell = spinner
else: else:
@ -254,7 +258,9 @@ def update_display(layout, spinner_text=None):
status = message_buffer.agent_status[agent] status = message_buffer.agent_status[agent]
if status == "in_progress": if status == "in_progress":
spinner = Spinner( spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan", "dots",
text="[blue]in_progress[/blue]",
style="bold cyan",
) )
status_cell = spinner status_cell = spinner
else: else:
@ -286,7 +292,10 @@ def update_display(layout, spinner_text=None):
messages_table.add_column("Time", style="cyan", width=8, justify="center") messages_table.add_column("Time", style="cyan", width=8, justify="center")
messages_table.add_column("Type", style="green", width=10, justify="center") messages_table.add_column("Type", style="green", width=10, justify="center")
messages_table.add_column( messages_table.add_column(
"Content", style="white", no_wrap=False, ratio=1, "Content",
style="white",
no_wrap=False,
ratio=1,
) # Make content column expand ) # Make content column expand
# Combine tool calls and messages # Combine tool calls and messages
@ -441,7 +450,9 @@ def get_user_selections():
# Step 1: Ticker symbol # Step 1: Ticker symbol
console.print( console.print(
create_question_box( create_question_box(
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY", "Step 1: Ticker Symbol",
"Enter the ticker symbol to analyze",
"SPY",
), ),
) )
selected_ticker = get_ticker() selected_ticker = get_ticker()
@ -460,7 +471,8 @@ def get_user_selections():
# Step 3: Select analysts # Step 3: Select analysts
console.print( console.print(
create_question_box( create_question_box(
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis", "Step 3: Analysts Team",
"Select your LLM analyst agents for the analysis",
), ),
) )
selected_analysts = select_analysts() selected_analysts = select_analysts()
@ -471,21 +483,25 @@ def get_user_selections():
# Step 4: Research depth # Step 4: Research depth
console.print( console.print(
create_question_box( create_question_box(
"Step 4: Research Depth", "Select your research depth level", "Step 4: Research Depth",
"Select your research depth level",
), ),
) )
selected_research_depth = select_research_depth() selected_research_depth = select_research_depth()
# Step 5: OpenAI backend # Step 5: OpenAI backend
console.print( console.print(
create_question_box("Step 5: OpenAI backend", "Select which service to talk to"), create_question_box(
"Step 5: OpenAI backend", "Select which service to talk to"
),
) )
selected_llm_provider, backend_url = select_llm_provider() selected_llm_provider, backend_url = select_llm_provider()
# Step 6: Thinking agents # Step 6: Thinking agents
console.print( console.print(
create_question_box( create_question_box(
"Step 6: Thinking Agents", "Select your thinking agents for analysis", "Step 6: Thinking Agents",
"Select your thinking agents for analysis",
), ),
) )
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
@ -737,7 +753,9 @@ def run_analysis():
# Initialize the graph # Initialize the graph
graph = TradingAgentsGraph( graph = TradingAgentsGraph(
[analyst.value for analyst in selections["analysts"]], config=config, debug=True, [analyst.value for analyst in selections["analysts"]],
config=config,
debug=True,
) )
# Create result directory # Create result directory
@ -796,10 +814,12 @@ def run_analysis():
message_buffer.add_message = save_message_decorator(message_buffer, "add_message") message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
message_buffer.add_tool_call = save_tool_call_decorator( message_buffer.add_tool_call = save_tool_call_decorator(
message_buffer, "add_tool_call", message_buffer,
"add_tool_call",
) )
message_buffer.update_report_section = save_report_section_decorator( message_buffer.update_report_section = save_report_section_decorator(
message_buffer, "update_report_section", message_buffer,
"update_report_section",
) )
# Now start the display layout # Now start the display layout
@ -812,7 +832,8 @@ def run_analysis():
# Add initial messages # Add initial messages
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}") message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
message_buffer.add_message( message_buffer.add_message(
"System", f"Analysis date: {selections['analysis_date']}", "System",
f"Analysis date: {selections['analysis_date']}",
) )
message_buffer.add_message( message_buffer.add_message(
"System", "System",
@ -843,7 +864,8 @@ def run_analysis():
# Initialize state and get graph args # Initialize state and get graph args
init_agent_state = graph.propagator.create_initial_state( init_agent_state = graph.propagator.create_initial_state(
selections["ticker"], selections["analysis_date"], selections["ticker"],
selections["analysis_date"],
) )
args = graph.propagator.get_graph_args() args = graph.propagator.get_graph_args()
@ -873,7 +895,8 @@ def run_analysis():
# Handle both dictionary and object tool calls # Handle both dictionary and object tool calls
if isinstance(tool_call, dict): if isinstance(tool_call, dict):
message_buffer.add_tool_call( message_buffer.add_tool_call(
tool_call["name"], tool_call["args"], tool_call["name"],
tool_call["args"],
) )
else: else:
message_buffer.add_tool_call(tool_call.name, tool_call.args) message_buffer.add_tool_call(tool_call.name, tool_call.args)
@ -882,51 +905,57 @@ def run_analysis():
# Analyst Team Reports # Analyst Team Reports
if chunk.get("market_report"): if chunk.get("market_report"):
message_buffer.update_report_section( message_buffer.update_report_section(
"market_report", chunk["market_report"], "market_report",
chunk["market_report"],
) )
message_buffer.update_agent_status("Market Analyst", "completed") message_buffer.update_agent_status("Market Analyst", "completed")
# Set next analyst to in_progress # Set next analyst to in_progress
if "social" in selections["analysts"]: if "social" in selections["analysts"]:
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Social Analyst", "in_progress", "Social Analyst",
"in_progress",
) )
if chunk.get("sentiment_report"): if chunk.get("sentiment_report"):
message_buffer.update_report_section( message_buffer.update_report_section(
"sentiment_report", chunk["sentiment_report"], "sentiment_report",
chunk["sentiment_report"],
) )
message_buffer.update_agent_status("Social Analyst", "completed") message_buffer.update_agent_status("Social Analyst", "completed")
# Set next analyst to in_progress # Set next analyst to in_progress
if "news" in selections["analysts"]: if "news" in selections["analysts"]:
message_buffer.update_agent_status( message_buffer.update_agent_status(
"News Analyst", "in_progress", "News Analyst",
"in_progress",
) )
if chunk.get("news_report"): if chunk.get("news_report"):
message_buffer.update_report_section( message_buffer.update_report_section(
"news_report", chunk["news_report"], "news_report",
chunk["news_report"],
) )
message_buffer.update_agent_status("News Analyst", "completed") message_buffer.update_agent_status("News Analyst", "completed")
# Set next analyst to in_progress # Set next analyst to in_progress
if "fundamentals" in selections["analysts"]: if "fundamentals" in selections["analysts"]:
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Fundamentals Analyst", "in_progress", "Fundamentals Analyst",
"in_progress",
) )
if chunk.get("fundamentals_report"): if chunk.get("fundamentals_report"):
message_buffer.update_report_section( message_buffer.update_report_section(
"fundamentals_report", chunk["fundamentals_report"], "fundamentals_report",
chunk["fundamentals_report"],
) )
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Fundamentals Analyst", "completed", "Fundamentals Analyst",
"completed",
) )
# Set all research team members to in_progress # Set all research team members to in_progress
update_research_team_status("in_progress") update_research_team_status("in_progress")
# Research Team - Handle Investment Debate State # Research Team - Handle Investment Debate State
if ( if chunk.get("investment_debate_state"):
chunk.get("investment_debate_state")
):
debate_state = chunk["investment_debate_state"] debate_state = chunk["investment_debate_state"]
# Update Bull Researcher status and report # Update Bull Researcher status and report
@ -960,9 +989,7 @@ def run_analysis():
) )
# Update Research Manager status and final decision # Update Research Manager status and final decision
if ( if debate_state.get("judge_decision"):
debate_state.get("judge_decision")
):
# Keep all research team members in progress until final decision # Keep all research team members in progress until final decision
update_research_team_status("in_progress") update_research_team_status("in_progress")
message_buffer.add_message( message_buffer.add_message(
@ -978,15 +1005,15 @@ def run_analysis():
update_research_team_status("completed") update_research_team_status("completed")
# Set first risk analyst to in_progress # Set first risk analyst to in_progress
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Risky Analyst", "in_progress", "Risky Analyst",
"in_progress",
) )
# Trading Team # Trading Team
if ( if chunk.get("trader_investment_plan"):
chunk.get("trader_investment_plan")
):
message_buffer.update_report_section( message_buffer.update_report_section(
"trader_investment_plan", chunk["trader_investment_plan"], "trader_investment_plan",
chunk["trader_investment_plan"],
) )
# Set first risk analyst to in_progress # Set first risk analyst to in_progress
message_buffer.update_agent_status("Risky Analyst", "in_progress") message_buffer.update_agent_status("Risky Analyst", "in_progress")
@ -996,11 +1023,10 @@ def run_analysis():
risk_state = chunk["risk_debate_state"] risk_state = chunk["risk_debate_state"]
# Update Risky Analyst status and report # Update Risky Analyst status and report
if ( if risk_state.get("current_risky_response"):
risk_state.get("current_risky_response")
):
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Risky Analyst", "in_progress", "Risky Analyst",
"in_progress",
) )
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
@ -1013,11 +1039,10 @@ def run_analysis():
) )
# Update Safe Analyst status and report # Update Safe Analyst status and report
if ( if risk_state.get("current_safe_response"):
risk_state.get("current_safe_response")
):
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Safe Analyst", "in_progress", "Safe Analyst",
"in_progress",
) )
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
@ -1030,11 +1055,10 @@ def run_analysis():
) )
# Update Neutral Analyst status and report # Update Neutral Analyst status and report
if ( if risk_state.get("current_neutral_response"):
risk_state.get("current_neutral_response")
):
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Neutral Analyst", "in_progress", "Neutral Analyst",
"in_progress",
) )
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
@ -1049,7 +1073,8 @@ def run_analysis():
# Update Portfolio Manager status and final decision # Update Portfolio Manager status and final decision
if risk_state.get("judge_decision"): if risk_state.get("judge_decision"):
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Portfolio Manager", "in_progress", "Portfolio Manager",
"in_progress",
) )
message_buffer.add_message( message_buffer.add_message(
"Reasoning", "Reasoning",
@ -1064,10 +1089,12 @@ def run_analysis():
message_buffer.update_agent_status("Risky Analyst", "completed") message_buffer.update_agent_status("Risky Analyst", "completed")
message_buffer.update_agent_status("Safe Analyst", "completed") message_buffer.update_agent_status("Safe Analyst", "completed")
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Neutral Analyst", "completed", "Neutral Analyst",
"completed",
) )
message_buffer.update_agent_status( message_buffer.update_agent_status(
"Portfolio Manager", "completed", "Portfolio Manager",
"completed",
) )
# Update the display # Update the display
@ -1084,7 +1111,8 @@ def run_analysis():
message_buffer.update_agent_status(agent, "completed") message_buffer.update_agent_status(agent, "completed")
message_buffer.add_message( message_buffer.add_message(
"Analysis", f"Completed analysis for {selections['analysis_date']}", "Analysis",
f"Completed analysis for {selections['analysis_date']}",
) )
# Update final report sections # Update final report sections

View File

@ -1,4 +1,3 @@
import sys import sys
import questionary import questionary

View File

@ -4,7 +4,9 @@ from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config # Create a custom config
config = DEFAULT_CONFIG.copy() config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google" # Use a different model config["llm_provider"] = "google" # Use a different model
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend config["backend_url"] = (
"https://generativelanguage.googleapis.com/v1" # Use a different backend
)
config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model
config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds config["max_debate_rounds"] = 1 # Increase debate rounds

View File

@ -33,20 +33,14 @@ def main():
parser.add_argument( parser.add_argument(
"test_type", "test_type",
choices=["unit", "integration", "all", "coverage", "fast", "slow", "lint"], choices=["unit", "integration", "all", "coverage", "fast", "slow", "lint"],
help="Type of tests to run" help="Type of tests to run",
)
parser.add_argument(
"--verbose", "-v", action="store_true", help="Verbose output"
) )
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
parser.add_argument( parser.add_argument(
"--parallel", "-p", action="store_true", help="Run tests in parallel" "--parallel", "-p", action="store_true", help="Run tests in parallel"
) )
parser.add_argument( parser.add_argument("--file", "-f", help="Run specific test file")
"--file", "-f", help="Run specific test file" parser.add_argument("--pattern", "-k", help="Run tests matching pattern")
)
parser.add_argument(
"--pattern", "-k", help="Run tests matching pattern"
)
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,18 +1,22 @@
def poorly_formatted_function(x, y, z): # Missing type hints
def poorly_formatted_function(x,y,z): # Missing type hints
"""This function has formatting issues.""" """This function has formatting issues."""
result=x+y*z # Missing spaces around operators result = x + y * z # Missing spaces around operators
if result>100: # Missing spaces if result > 100: # Missing spaces
print( "Result is large" ) # Extra spaces in parentheses print("Result is large") # Extra spaces in parentheses
return result return result
# Long line that Black will wrap # Long line that Black will wrap
very_long_variable_name_that_exceeds_the_standard_line_length_limit = "This is a very long string that will be wrapped by Black formatter" very_long_variable_name_that_exceeds_the_standard_line_length_limit = (
"This is a very long string that will be wrapped by Black formatter"
)
class MyClass: class MyClass:
def __init__(self,name:str,age:int): # Missing space after comma def __init__(self, name: str, age: int): # Missing space after comma
self.name=name # Missing spaces around = self.name = name # Missing spaces around =
self.age=age self.age = age
# Function with wrong return type hint # Function with wrong return type hint
def get_number() -> str: def get_number() -> str:

View File

@ -2,6 +2,7 @@ def add_numbers(a: int, b: int) -> int:
"""Add two numbers and return the result.""" """Add two numbers and return the result."""
return a + b return a + b
# Test the function # Test the function
result = add_numbers(1, 2) result = add_numbers(1, 2)
print(f"Result: {result}") print(f"Result: {result}")

View File

@ -48,12 +48,17 @@ def main():
# Test 1: Check if pytest is installed and can discover tests # Test 1: Check if pytest is installed and can discover tests
total_tests += 1 total_tests += 1
if run_command(["python", "-m", "pytest", "--version"], "Checking pytest installation"): if run_command(
["python", "-m", "pytest", "--version"], "Checking pytest installation"
):
success_count += 1 success_count += 1
# Test 2: Test discovery # Test 2: Test discovery
total_tests += 1 total_tests += 1
if run_command(["python", "-m", "pytest", "tests/", "--collect-only", "-q"], "Testing test discovery"): if run_command(
["python", "-m", "pytest", "tests/", "--collect-only", "-q"],
"Testing test discovery",
):
success_count += 1 success_count += 1
# Test 3: Check if mypy can run # Test 3: Check if mypy can run
@ -63,12 +68,22 @@ def main():
# Test 4: Run a simple syntax check on test files # Test 4: Run a simple syntax check on test files
total_tests += 1 total_tests += 1
if run_command(["python", "-c", "import tests.conftest; print('Test imports work!')"], "Testing test imports"): if run_command(
["python", "-c", "import tests.conftest; print('Test imports work!')"],
"Testing test imports",
):
success_count += 1 success_count += 1
# Test 5: Check if we can import the main module # Test 5: Check if we can import the main module
total_tests += 1 total_tests += 1
if run_command(["python", "-c", "import tradingagents.config; print('Main module imports work!')"], "Testing main module imports"): if run_command(
[
"python",
"-c",
"import tradingagents.config; print('Main module imports work!')",
],
"Testing main module imports",
):
success_count += 1 success_count += 1
# Summary # Summary

View File

@ -173,7 +173,8 @@ def mock_memory():
def pytest_configure(config): def pytest_configure(config):
"""Configure pytest with custom markers.""" """Configure pytest with custom markers."""
config.addinivalue_line( config.addinivalue_line(
"markers", "integration: mark test as integration test (slow)", "markers",
"integration: mark test as integration test (slow)",
) )
config.addinivalue_line("markers", "unit: mark test as unit test (fast)") config.addinivalue_line("markers", "unit: mark test as unit test (fast)")
config.addinivalue_line("markers", "api: mark test as requiring API access") config.addinivalue_line("markers", "api: mark test as requiring API access")

View File

@ -36,7 +36,8 @@ class SampleDataFactory:
@staticmethod @staticmethod
def create_finnhub_news_data( def create_finnhub_news_data(
ticker: str = "AAPL", count: int = 10, ticker: str = "AAPL",
count: int = 10,
) -> dict[str, list[dict[str, Any]]]: ) -> dict[str, list[dict[str, Any]]]:
"""Create sample FinnHub news data for testing.""" """Create sample FinnHub news data for testing."""
base_date = datetime(2024, 5, 10) base_date = datetime(2024, 5, 10)
@ -136,7 +137,8 @@ class SampleDataFactory:
@staticmethod @staticmethod
def create_financial_statements_data( def create_financial_statements_data(
ticker: str = "AAPL", period: str = "annual", ticker: str = "AAPL",
period: str = "annual",
) -> dict[str, list[dict[str, Any]]]: ) -> dict[str, list[dict[str, Any]]]:
"""Create sample financial statements data for testing.""" """Create sample financial statements data for testing."""
if period == "annual": if period == "annual":
@ -271,10 +273,12 @@ class SampleDataFactory:
ticker, ticker,
), ),
"financial_annual": SampleDataFactory.create_financial_statements_data( "financial_annual": SampleDataFactory.create_financial_statements_data(
ticker, "annual", ticker,
"annual",
), ),
"financial_quarterly": SampleDataFactory.create_financial_statements_data( "financial_quarterly": SampleDataFactory.create_financial_statements_data(
ticker, "quarterly", ticker,
"quarterly",
), ),
"social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker), "social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker),
"technical_indicators": SampleDataFactory.create_technical_indicators_data( "technical_indicators": SampleDataFactory.create_technical_indicators_data(
@ -343,7 +347,9 @@ def save_sample_data_to_files(base_path: str, ticker: str = "AAPL") -> None:
# Save quarterly data separately # Save quarterly data separately
quarterly_path = os.path.join( quarterly_path = os.path.join(
finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json", finnhub_path,
"fin_as_reported",
f"{ticker}_quarterly_data_formatted.json",
) )
with open(quarterly_path, "w") as f: with open(quarterly_path, "w") as f:
json.dump(dataset["financial_quarterly"], f, indent=2) json.dump(dataset["financial_quarterly"], f, indent=2)

View File

@ -31,7 +31,10 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_end_to_end_trading_workflow( def test_end_to_end_trading_workflow(
self, mock_toolkit, mock_chat_openai, integration_config, self,
mock_toolkit,
mock_chat_openai,
integration_config,
): ):
"""Test complete end-to-end trading workflow.""" """Test complete end-to-end trading workflow."""
# Setup mocks # Setup mocks
@ -86,7 +89,10 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_analysts_integration( def test_multiple_analysts_integration(
self, mock_toolkit, mock_chat_openai, integration_config, self,
mock_toolkit,
mock_chat_openai,
integration_config,
): ):
"""Test integration with different analyst combinations.""" """Test integration with different analyst combinations."""
analyst_combinations = [ analyst_combinations = [
@ -114,7 +120,8 @@ class TestFullWorkflowIntegration:
with patch("tradingagents.graph.trading_graph.set_config"): with patch("tradingagents.graph.trading_graph.set_config"):
# Test each analyst combination # Test each analyst combination
trading_graph = TradingAgentsGraph( trading_graph = TradingAgentsGraph(
selected_analysts=analysts, config=integration_config, selected_analysts=analysts,
config=integration_config,
) )
trading_graph.graph = mock_graph trading_graph.graph = mock_graph
@ -134,7 +141,8 @@ class TestFullWorkflowIntegration:
# Execute # Execute
with patch("builtins.open", create=True), patch("json.dump"): with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = trading_graph.propagate( final_state, decision = trading_graph.propagate(
"TSLA", "2024-05-15", "TSLA",
"2024-05-15",
) )
# Verify # Verify
@ -144,7 +152,10 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_memory_and_reflection_integration( def test_memory_and_reflection_integration(
self, mock_toolkit, mock_chat_openai, integration_config, self,
mock_toolkit,
mock_chat_openai,
integration_config,
): ):
"""Test integration of memory and reflection components.""" """Test integration of memory and reflection components."""
# Setup # Setup
@ -208,7 +219,10 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_debug_mode_integration( def test_debug_mode_integration(
self, mock_toolkit, mock_chat_openai, integration_config, self,
mock_toolkit,
mock_chat_openai,
integration_config,
): ):
"""Test integration in debug mode.""" """Test integration in debug mode."""
# Setup # Setup
@ -240,7 +254,8 @@ class TestFullWorkflowIntegration:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"): with patch("tradingagents.graph.trading_graph.set_config"):
trading_graph = TradingAgentsGraph( trading_graph = TradingAgentsGraph(
debug=True, config=integration_config, debug=True,
config=integration_config,
) )
trading_graph.graph = mock_graph trading_graph.graph = mock_graph
@ -276,7 +291,12 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_stocks_integration( def test_multiple_stocks_integration(
self, mock_toolkit, mock_chat_openai, ticker, date, integration_config, self,
mock_toolkit,
mock_chat_openai,
ticker,
date,
integration_config,
): ):
"""Test integration with different stocks and dates.""" """Test integration with different stocks and dates."""
# Setup # Setup
@ -382,7 +402,11 @@ class TestPerformanceIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_consecutive_runs( def test_multiple_consecutive_runs(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
): ):
"""Test multiple consecutive trading decisions.""" """Test multiple consecutive trading decisions."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir

View File

@ -17,7 +17,10 @@ class TestMarketAnalyst:
assert callable(analyst_node) assert callable(analyst_node)
def test_market_analyst_node_basic_execution( def test_market_analyst_node_basic_execution(
self, mock_llm, mock_toolkit, sample_agent_state, self,
mock_llm,
mock_toolkit,
sample_agent_state,
): ):
"""Test basic execution of market analyst node.""" """Test basic execution of market analyst node."""
# Setup # Setup
@ -39,7 +42,10 @@ class TestMarketAnalyst:
assert result["market_report"] == "Market analysis complete" assert result["market_report"] == "Market analysis complete"
def test_market_analyst_uses_online_tools_when_configured( def test_market_analyst_uses_online_tools_when_configured(
self, mock_llm, mock_toolkit, sample_agent_state, self,
mock_llm,
mock_toolkit,
sample_agent_state,
): ):
"""Test that analyst uses online tools when configured.""" """Test that analyst uses online tools when configured."""
# Setup # Setup
@ -64,7 +70,10 @@ class TestMarketAnalyst:
assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2 assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2
def test_market_analyst_uses_offline_tools_when_configured( def test_market_analyst_uses_offline_tools_when_configured(
self, mock_llm, mock_toolkit, sample_agent_state, self,
mock_llm,
mock_toolkit,
sample_agent_state,
): ):
"""Test that analyst uses offline tools when configured.""" """Test that analyst uses offline tools when configured."""
# Setup # Setup
@ -88,7 +97,10 @@ class TestMarketAnalyst:
assert len(bound_tools) == 2 # Should have 2 offline tools assert len(bound_tools) == 2 # Should have 2 offline tools
def test_market_analyst_processes_state_variables( def test_market_analyst_processes_state_variables(
self, mock_llm, mock_toolkit, sample_agent_state, self,
mock_llm,
mock_toolkit,
sample_agent_state,
): ):
"""Test that market analyst correctly processes state variables.""" """Test that market analyst correctly processes state variables."""
# Setup # Setup
@ -112,7 +124,10 @@ class TestMarketAnalyst:
assert result["market_report"] == "Analysis for AAPL on 2024-05-10" assert result["market_report"] == "Analysis for AAPL on 2024-05-10"
def test_market_analyst_handles_empty_tool_calls( def test_market_analyst_handles_empty_tool_calls(
self, mock_llm, mock_toolkit, sample_agent_state, self,
mock_llm,
mock_toolkit,
sample_agent_state,
): ):
"""Test handling when no tool calls are made.""" """Test handling when no tool calls are made."""
# Setup # Setup
@ -132,7 +147,10 @@ class TestMarketAnalyst:
assert result["messages"] == [mock_result] assert result["messages"] == [mock_result]
def test_market_analyst_with_tool_calls( def test_market_analyst_with_tool_calls(
self, mock_llm, mock_toolkit, sample_agent_state, self,
mock_llm,
mock_toolkit,
sample_agent_state,
): ):
"""Test handling when tool calls are present.""" """Test handling when tool calls are present."""
# Setup # Setup
@ -153,7 +171,11 @@ class TestMarketAnalyst:
@pytest.mark.parametrize("online_tools", [True, False]) @pytest.mark.parametrize("online_tools", [True, False])
def test_market_analyst_tool_configuration( def test_market_analyst_tool_configuration(
self, mock_llm, mock_toolkit, sample_agent_state, online_tools, self,
mock_llm,
mock_toolkit,
sample_agent_state,
online_tools,
): ):
"""Test tool configuration for both online and offline modes.""" """Test tool configuration for both online and offline modes."""
# Setup # Setup

View File

@ -190,7 +190,10 @@ class TestFinnhubUtils:
# Test without period # Test without period
expected_path_no_period = os.path.join( expected_path_no_period = os.path.join(
temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json", temp_data_dir,
"finnhub_data",
data_type,
f"{ticker}_data_formatted.json",
) )
# Test with period # Test with period
@ -248,7 +251,10 @@ class TestFinnhubUtils:
], ],
) )
def test_get_data_in_range_various_data_types( def test_get_data_in_range_various_data_types(
self, temp_data_dir, data_type, period, self,
temp_data_dir,
data_type,
period,
): ):
"""Test get_data_in_range with various data types.""" """Test get_data_in_range with various data types."""
ticker = "TEST" ticker = "TEST"

View File

@ -45,7 +45,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_debug( def test_init_with_debug(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
): ):
"""Test initialization with debug mode enabled.""" """Test initialization with debug mode enabled."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -63,7 +67,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatAnthropic") @patch("tradingagents.graph.trading_graph.ChatAnthropic")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_anthropic( def test_init_with_anthropic(
self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir, self,
mock_toolkit,
mock_chat_anthropic,
sample_config,
temp_data_dir,
): ):
"""Test initialization with Anthropic LLM provider.""" """Test initialization with Anthropic LLM provider."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -82,7 +90,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI") @patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_google( def test_init_with_google(
self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir, self,
mock_toolkit,
mock_chat_google,
sample_config,
temp_data_dir,
): ):
"""Test initialization with Google LLM provider.""" """Test initialization with Google LLM provider."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -100,7 +112,10 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_unsupported_llm_provider( def test_init_unsupported_llm_provider(
self, mock_toolkit, sample_config, temp_data_dir, self,
mock_toolkit,
sample_config,
temp_data_dir,
): ):
"""Test initialization with unsupported LLM provider raises error.""" """Test initialization with unsupported LLM provider raises error."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -115,7 +130,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_create_tool_nodes( def test_create_tool_nodes(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
): ):
"""Test creation of tool nodes.""" """Test creation of tool nodes."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -143,7 +162,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_propagate_basic( def test_propagate_basic(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
): ):
"""Test basic propagate functionality.""" """Test basic propagate functionality."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -206,7 +229,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_propagate_debug_mode( def test_propagate_debug_mode(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
): ):
"""Test propagate in debug mode.""" """Test propagate in debug mode."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -245,7 +272,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_log_state( def test_log_state(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
): ):
"""Test state logging functionality.""" """Test state logging functionality."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -300,7 +331,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_reflect_and_remember( def test_reflect_and_remember(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
): ):
"""Test reflection and memory update functionality.""" """Test reflection and memory update functionality."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -309,9 +344,12 @@ class TestTradingAgentsGraph:
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance mock_toolkit.return_value = mock_toolkit_instance
with patch( with (
"tradingagents.graph.trading_graph.FinancialSituationMemory", patch(
), patch("tradingagents.graph.trading_graph.set_config"): "tradingagents.graph.trading_graph.FinancialSituationMemory",
),
patch("tradingagents.graph.trading_graph.set_config"),
):
graph = TradingAgentsGraph(config=sample_config) graph = TradingAgentsGraph(config=sample_config)
# Set up current state # Set up current state
@ -339,7 +377,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_process_signal( def test_process_signal(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
): ):
"""Test signal processing functionality.""" """Test signal processing functionality."""
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
@ -388,7 +430,8 @@ class TestTradingAgentsGraph:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"): with patch("tradingagents.graph.trading_graph.set_config"):
TradingAgentsGraph( TradingAgentsGraph(
selected_analysts=selected_analysts, config=sample_config, selected_analysts=selected_analysts,
config=sample_config,
) )
# Verify graph was set up with selected analysts # Verify graph was set up with selected analysts
@ -416,7 +459,10 @@ class TestTradingAgentsGraphErrorHandling:
@patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit") @patch("tradingagents.graph.trading_graph.Toolkit")
def test_directory_creation_failure( def test_directory_creation_failure(
self, mock_toolkit, mock_chat_openai, sample_config, self,
mock_toolkit,
mock_chat_openai,
sample_config,
): ):
"""Test handling when directory creation fails.""" """Test handling when directory creation fails."""
sample_config["project_dir"] = "/invalid/path/that/cannot/be/created" sample_config["project_dir"] = "/invalid/path/that/cannot/be/created"

View File

@ -20,7 +20,7 @@ def create_fundamentals_analyst(llm, toolkit):
system_message = ( system_message = (
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." "You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.", " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(

View File

@ -45,7 +45,7 @@ Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses. - vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions.""" - Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
""" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(

View File

@ -17,7 +17,7 @@ def create_news_analyst(llm, toolkit):
system_message = ( system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." "You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""" """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(

View File

@ -16,7 +16,7 @@ def create_social_media_analyst(llm, toolkit):
system_message = ( system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""", """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(

View File

@ -41,7 +41,8 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
"current_risky_response": argument, "current_risky_response": argument,
"current_safe_response": risk_debate_state.get("current_safe_response", ""), "current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": risk_debate_state.get( "current_neutral_response": risk_debate_state.get(
"current_neutral_response", "", "current_neutral_response",
"",
), ),
"count": risk_debate_state["count"] + 1, "count": risk_debate_state["count"] + 1,
} }

View File

@ -39,11 +39,13 @@ Engage by questioning their optimism and emphasizing the potential downsides the
"neutral_history": risk_debate_state.get("neutral_history", ""), "neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Safe", "latest_speaker": "Safe",
"current_risky_response": risk_debate_state.get( "current_risky_response": risk_debate_state.get(
"current_risky_response", "", "current_risky_response",
"",
), ),
"current_safe_response": argument, "current_safe_response": argument,
"current_neutral_response": risk_debate_state.get( "current_neutral_response": risk_debate_state.get(
"current_neutral_response", "", "current_neutral_response",
"",
), ),
"count": risk_debate_state["count"] + 1, "count": risk_debate_state["count"] + 1,
} }

View File

@ -39,7 +39,8 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
"neutral_history": neutral_history + "\n" + argument, "neutral_history": neutral_history + "\n" + argument,
"latest_speaker": "Neutral", "latest_speaker": "Neutral",
"current_risky_response": risk_debate_state.get( "current_risky_response": risk_debate_state.get(
"current_risky_response", "", "current_risky_response",
"",
), ),
"current_safe_response": risk_debate_state.get("current_safe_response", ""), "current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": argument, "current_neutral_response": argument,

View File

@ -9,10 +9,12 @@ from typing_extensions import TypedDict
# Researcher team state # Researcher team state
class InvestDebateState(TypedDict): class InvestDebateState(TypedDict):
bull_history: Annotated[ bull_history: Annotated[
str, "Bullish Conversation history", str,
"Bullish Conversation history",
] # Bullish Conversation history ] # Bullish Conversation history
bear_history: Annotated[ bear_history: Annotated[
str, "Bearish Conversation history", str,
"Bearish Conversation history",
] # Bullish Conversation history ] # Bullish Conversation history
history: Annotated[str, "Conversation history"] # Conversation history history: Annotated[str, "Conversation history"] # Conversation history
current_response: Annotated[str, "Latest response"] # Last response current_response: Annotated[str, "Latest response"] # Last response
@ -23,24 +25,30 @@ class InvestDebateState(TypedDict):
# Risk management team state # Risk management team state
class RiskDebateState(TypedDict): class RiskDebateState(TypedDict):
risky_history: Annotated[ risky_history: Annotated[
str, "Risky Agent's Conversation history", str,
"Risky Agent's Conversation history",
] # Conversation history ] # Conversation history
safe_history: Annotated[ safe_history: Annotated[
str, "Safe Agent's Conversation history", str,
"Safe Agent's Conversation history",
] # Conversation history ] # Conversation history
neutral_history: Annotated[ neutral_history: Annotated[
str, "Neutral Agent's Conversation history", str,
"Neutral Agent's Conversation history",
] # Conversation history ] # Conversation history
history: Annotated[str, "Conversation history"] # Conversation history history: Annotated[str, "Conversation history"] # Conversation history
latest_speaker: Annotated[str, "Analyst that spoke last"] latest_speaker: Annotated[str, "Analyst that spoke last"]
current_risky_response: Annotated[ current_risky_response: Annotated[
str, "Latest response by the risky analyst", str,
"Latest response by the risky analyst",
] # Last response ] # Last response
current_safe_response: Annotated[ current_safe_response: Annotated[
str, "Latest response by the safe analyst", str,
"Latest response by the safe analyst",
] # Last response ] # Last response
current_neutral_response: Annotated[ current_neutral_response: Annotated[
str, "Latest response by the neutral analyst", str,
"Latest response by the neutral analyst",
] # Last response ] # Last response
judge_decision: Annotated[str, "Judge's decision"] judge_decision: Annotated[str, "Judge's decision"]
count: Annotated[int, "Length of the current conversation"] # Conversation length count: Annotated[int, "Length of the current conversation"] # Conversation length
@ -56,13 +64,15 @@ class AgentState(MessagesState):
market_report: Annotated[str, "Report from the Market Analyst"] market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"] sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
news_report: Annotated[ news_report: Annotated[
str, "Report from the News Researcher of current world affairs", str,
"Report from the News Researcher of current world affairs",
] ]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"] fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
# researcher team discussion step # researcher team discussion step
investment_debate_state: Annotated[ investment_debate_state: Annotated[
InvestDebateState, "Current state of the debate on if to invest or not", InvestDebateState,
"Current state of the debate on if to invest or not",
] ]
investment_plan: Annotated[str, "Plan generated by the Analyst"] investment_plan: Annotated[str, "Plan generated by the Analyst"]
@ -70,6 +80,7 @@ class AgentState(MessagesState):
# risk management team discussion step # risk management team discussion step
risk_debate_state: Annotated[ risk_debate_state: Annotated[
RiskDebateState, "Current state of the debate on evaluating risk", RiskDebateState,
"Current state of the debate on evaluating risk",
] ]
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]

View File

@ -56,7 +56,6 @@ class Toolkit:
return interface.get_reddit_global_news(curr_date, 7, 5) return interface.get_reddit_global_news(curr_date, 7, 5)
@staticmethod @staticmethod
@tool @tool
def get_finnhub_news( def get_finnhub_news(
@ -84,10 +83,11 @@ class Toolkit:
look_back_days = (end_date - start_date).days look_back_days = (end_date - start_date).days
return interface.get_finnhub_news( return interface.get_finnhub_news(
ticker, end_date_str, look_back_days, ticker,
end_date_str,
look_back_days,
) )
@staticmethod @staticmethod
@tool @tool
def get_reddit_stock_info( def get_reddit_stock_info(
@ -108,7 +108,6 @@ class Toolkit:
return interface.get_reddit_company_news(ticker, curr_date, 7, 5) return interface.get_reddit_company_news(ticker, curr_date, 7, 5)
@staticmethod @staticmethod
@tool @tool
def get_YFin_data( def get_YFin_data(
@ -128,7 +127,6 @@ class Toolkit:
return interface.get_YFin_data(symbol, start_date, end_date) return interface.get_YFin_data(symbol, start_date, end_date)
@staticmethod @staticmethod
@tool @tool
def get_YFin_data_online( def get_YFin_data_online(
@ -148,16 +146,17 @@ class Toolkit:
return interface.get_YFin_data_online(symbol, start_date, end_date) return interface.get_YFin_data_online(symbol, start_date, end_date)
@staticmethod @staticmethod
@tool @tool
def get_stockstats_indicators_report( def get_stockstats_indicators_report(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[ indicator: Annotated[
str, "technical indicator to get the analysis and report of", str,
"technical indicator to get the analysis and report of",
], ],
curr_date: Annotated[ curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd", str,
"The current trading date you are trading on, YYYY-mm-dd",
], ],
look_back_days: Annotated[int, "how many days to look back"] = 30, look_back_days: Annotated[int, "how many days to look back"] = 30,
) -> str: ) -> str:
@ -173,19 +172,24 @@ class Toolkit:
""" """
return interface.get_stock_stats_indicators_window( return interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, False, symbol,
indicator,
curr_date,
look_back_days,
False,
) )
@staticmethod @staticmethod
@tool @tool
def get_stockstats_indicators_report_online( def get_stockstats_indicators_report_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[ indicator: Annotated[
str, "technical indicator to get the analysis and report of", str,
"technical indicator to get the analysis and report of",
], ],
curr_date: Annotated[ curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd", str,
"The current trading date you are trading on, YYYY-mm-dd",
], ],
look_back_days: Annotated[int, "how many days to look back"] = 30, look_back_days: Annotated[int, "how many days to look back"] = 30,
) -> str: ) -> str:
@ -201,10 +205,13 @@ class Toolkit:
""" """
return interface.get_stock_stats_indicators_window( return interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, True, symbol,
indicator,
curr_date,
look_back_days,
True,
) )
@staticmethod @staticmethod
@tool @tool
def get_finnhub_company_insider_sentiment( def get_finnhub_company_insider_sentiment(
@ -224,10 +231,11 @@ class Toolkit:
""" """
return interface.get_finnhub_company_insider_sentiment( return interface.get_finnhub_company_insider_sentiment(
ticker, curr_date, 30, ticker,
curr_date,
30,
) )
@staticmethod @staticmethod
@tool @tool
def get_finnhub_company_insider_transactions( def get_finnhub_company_insider_transactions(
@ -247,10 +255,11 @@ class Toolkit:
""" """
return interface.get_finnhub_company_insider_transactions( return interface.get_finnhub_company_insider_transactions(
ticker, curr_date, 30, ticker,
curr_date,
30,
) )
@staticmethod @staticmethod
@tool @tool
def get_simfin_balance_sheet( def get_simfin_balance_sheet(
@ -273,7 +282,6 @@ class Toolkit:
return interface.get_simfin_balance_sheet(ticker, freq, curr_date) return interface.get_simfin_balance_sheet(ticker, freq, curr_date)
@staticmethod @staticmethod
@tool @tool
def get_simfin_cashflow( def get_simfin_cashflow(
@ -296,7 +304,6 @@ class Toolkit:
return interface.get_simfin_cashflow(ticker, freq, curr_date) return interface.get_simfin_cashflow(ticker, freq, curr_date)
@staticmethod @staticmethod
@tool @tool
def get_simfin_income_stmt( def get_simfin_income_stmt(
@ -318,10 +325,11 @@ class Toolkit:
""" """
return interface.get_simfin_income_statements( return interface.get_simfin_income_statements(
ticker, freq, curr_date, ticker,
freq,
curr_date,
) )
@staticmethod @staticmethod
@tool @tool
def get_google_news( def get_google_news(
@ -340,7 +348,6 @@ class Toolkit:
return interface.get_google_news(query, curr_date, 7) return interface.get_google_news(query, curr_date, 7)
@staticmethod @staticmethod
@tool @tool
def get_stock_news_openai( def get_stock_news_openai(
@ -358,7 +365,6 @@ class Toolkit:
return interface.get_stock_news_openai(ticker, curr_date) return interface.get_stock_news_openai(ticker, curr_date)
@staticmethod @staticmethod
@tool @tool
def get_global_news_openai( def get_global_news_openai(
@ -374,7 +380,6 @@ class Toolkit:
return interface.get_global_news_openai(curr_date) return interface.get_global_news_openai(curr_date)
@staticmethod @staticmethod
@tool @tool
def get_fundamentals_openai( def get_fundamentals_openai(
@ -391,6 +396,6 @@ class Toolkit:
""" """
return interface.get_fundamentals_openai( return interface.get_fundamentals_openai(
ticker, curr_date, ticker,
curr_date,
) )

View File

@ -21,7 +21,8 @@ def get_config():
"project_dir": str(project_root / "tradingagents"), "project_dir": str(project_root / "tradingagents"),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": os.getenv( "data_dir": os.getenv(
"TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data", "TRADINGAGENTS_DATA_DIR",
"/Users/yluo/Documents/Code/ScAI/FR1-data",
), ),
"data_cache_dir": str( "data_cache_dir": str(
project_root / "tradingagents" / "dataflows" / "data_cache", project_root / "tradingagents" / "dataflows" / "data_cache",

View File

@ -1,4 +1,3 @@
from tradingagents import default_config from tradingagents import default_config
# Use default config but allow it to be overridden # Use default config but allow it to be overridden

View File

@ -22,7 +22,10 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
) )
else: else:
data_path = os.path.join( data_path = os.path.join(
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json", data_dir,
"finnhub_data",
data_type,
f"{ticker}_data_formatted.json",
) )
data = open(data_path) data = open(data_path)

View File

@ -419,7 +419,8 @@ def get_stock_stats_indicators_window(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"], indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[ curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd", str,
"The current trading date you are trading on, YYYY-mm-dd",
], ],
look_back_days: Annotated[int, "how many days to look back"], look_back_days: Annotated[int, "how many days to look back"],
online: Annotated[bool, "to fetch data online or offline"], online: Annotated[bool, "to fetch data online or offline"],
@ -524,7 +525,10 @@ def get_stock_stats_indicators_window(
# only do the trading dates # only do the trading dates
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values: if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
indicator_value = get_stockstats_indicator( indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online, symbol,
indicator,
curr_date.strftime("%Y-%m-%d"),
online,
) )
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n" ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
@ -535,7 +539,10 @@ def get_stock_stats_indicators_window(
ind_string = "" ind_string = ""
while curr_date >= before: while curr_date >= before:
indicator_value = get_stockstats_indicator( indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online, symbol,
indicator,
curr_date.strftime("%Y-%m-%d"),
online,
) )
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n" ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
@ -550,12 +557,12 @@ def get_stock_stats_indicators_window(
) )
def get_stockstats_indicator( def get_stockstats_indicator(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"], indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[ curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd", str,
"The current trading date you are trading on, YYYY-mm-dd",
], ],
online: Annotated[bool, "to fetch data online or offline"], online: Annotated[bool, "to fetch data online or offline"],
) -> str: ) -> str:
@ -608,7 +615,12 @@ def get_YFin_data_window(
# Set pandas display options to show the full DataFrame # Set pandas display options to show the full DataFrame
with pd.option_context( with pd.option_context(
"display.max_rows", None, "display.max_columns", None, "display.width", None, "display.max_rows",
None,
"display.max_columns",
None,
"display.width",
None,
): ):
df_string = filtered_data.to_string() df_string = filtered_data.to_string()
@ -694,7 +706,6 @@ def get_YFin_data(
return filtered_data.reset_index(drop=True) return filtered_data.reset_index(drop=True)
def get_stock_news_openai(ticker, curr_date): def get_stock_news_openai(ticker, curr_date):
config = get_config() config = get_config()
client = OpenAI(base_url=config["backend_url"]) client = OpenAI(base_url=config["backend_url"])

View File

@ -48,11 +48,14 @@ ticker_to_company = {
def fetch_top_from_category( def fetch_top_from_category(
category: Annotated[ category: Annotated[
str, "Category to fetch top post from. Collection of subreddits.", str,
"Category to fetch top post from. Collection of subreddits.",
], ],
date: Annotated[str, "Date to fetch top posts from."], date: Annotated[str, "Date to fetch top posts from."],
max_limit: Annotated[int, "Maximum number of posts to fetch."], max_limit: Annotated[int, "Maximum number of posts to fetch."],
query: Annotated[str | None, "Optional query to search for in the subreddit."] = None, query: Annotated[
str | None, "Optional query to search for in the subreddit."
] = None,
data_path: Annotated[ data_path: Annotated[
str, str,
"Path to the data folder. Default is 'reddit_data'.", "Path to the data folder. Default is 'reddit_data'.",
@ -107,7 +110,9 @@ def fetch_top_from_category(
found = False found = False
for term in search_terms: for term in search_terms:
if re.search( if re.search(
term, parsed_line["title"], re.IGNORECASE, term,
parsed_line["title"],
re.IGNORECASE,
) or re.search(term, parsed_line["selftext"], re.IGNORECASE): ) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
found = True found = True
break break

View File

@ -13,10 +13,12 @@ class StockstatsUtils:
def get_stock_stats( def get_stock_stats(
symbol: Annotated[str, "ticker symbol for the company"], symbol: Annotated[str, "ticker symbol for the company"],
indicator: Annotated[ indicator: Annotated[
str, "quantitative indicators based off of the stock data for the company", str,
"quantitative indicators based off of the stock data for the company",
], ],
curr_date: Annotated[ curr_date: Annotated[
str, "curr date for retrieving stock price data, YYYY-mm-dd", str,
"curr date for retrieving stock price data, YYYY-mm-dd",
], ],
data_dir: Annotated[ data_dir: Annotated[
str, str,

View File

@ -28,10 +28,12 @@ class YFinanceUtils:
def get_stock_data( def get_stock_data(
self: Annotated[str, "ticker symbol"], self: Annotated[str, "ticker symbol"],
start_date: Annotated[ start_date: Annotated[
str, "start date for retrieving stock price data, YYYY-mm-dd", str,
"start date for retrieving stock price data, YYYY-mm-dd",
], ],
end_date: Annotated[ end_date: Annotated[
str, "end date for retrieving stock price data, YYYY-mm-dd", str,
"end date for retrieving stock price data, YYYY-mm-dd",
], ],
save_path: SavePathType = None, save_path: SavePathType = None,
) -> DataFrame: ) -> DataFrame:

View File

@ -16,7 +16,9 @@ class Propagator:
self.max_recur_limit = max_recur_limit self.max_recur_limit = max_recur_limit
def create_initial_state( def create_initial_state(
self, company_name: str, trade_date: str, self,
company_name: str,
trade_date: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Create the initial state for the agent graph.""" """Create the initial state for the agent graph."""
return { return {

View File

@ -57,7 +57,11 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}" return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
def _reflect_on_component( def _reflect_on_component(
self, component_type: str, report: str, situation: str, returns_losses, self,
component_type: str,
report: str,
situation: str,
returns_losses,
) -> str: ) -> str:
"""Generate reflection for a component.""" """Generate reflection for a component."""
messages = [ messages = [
@ -76,7 +80,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
bull_debate_history = current_state["investment_debate_state"]["bull_history"] bull_debate_history = current_state["investment_debate_state"]["bull_history"]
result = self._reflect_on_component( result = self._reflect_on_component(
"BULL", bull_debate_history, situation, returns_losses, "BULL",
bull_debate_history,
situation,
returns_losses,
) )
bull_memory.add_situations([(situation, result)]) bull_memory.add_situations([(situation, result)])
@ -86,7 +93,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
bear_debate_history = current_state["investment_debate_state"]["bear_history"] bear_debate_history = current_state["investment_debate_state"]["bear_history"]
result = self._reflect_on_component( result = self._reflect_on_component(
"BEAR", bear_debate_history, situation, returns_losses, "BEAR",
bear_debate_history,
situation,
returns_losses,
) )
bear_memory.add_situations([(situation, result)]) bear_memory.add_situations([(situation, result)])
@ -96,7 +106,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
trader_decision = current_state["trader_investment_plan"] trader_decision = current_state["trader_investment_plan"]
result = self._reflect_on_component( result = self._reflect_on_component(
"TRADER", trader_decision, situation, returns_losses, "TRADER",
trader_decision,
situation,
returns_losses,
) )
trader_memory.add_situations([(situation, result)]) trader_memory.add_situations([(situation, result)])
@ -106,7 +119,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
judge_decision = current_state["investment_debate_state"]["judge_decision"] judge_decision = current_state["investment_debate_state"]["judge_decision"]
result = self._reflect_on_component( result = self._reflect_on_component(
"INVEST JUDGE", judge_decision, situation, returns_losses, "INVEST JUDGE",
judge_decision,
situation,
returns_losses,
) )
invest_judge_memory.add_situations([(situation, result)]) invest_judge_memory.add_situations([(situation, result)])
@ -116,6 +132,9 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
judge_decision = current_state["risk_debate_state"]["judge_decision"] judge_decision = current_state["risk_debate_state"]["judge_decision"]
result = self._reflect_on_component( result = self._reflect_on_component(
"RISK JUDGE", judge_decision, situation, returns_losses, "RISK JUDGE",
judge_decision,
situation,
returns_losses,
) )
risk_manager_memory.add_situations([(situation, result)]) risk_manager_memory.add_situations([(situation, result)])

View File

@ -55,7 +55,8 @@ class GraphSetup:
self.conditional_logic = conditional_logic self.conditional_logic = conditional_logic
def setup_graph( def setup_graph(
self, selected_analysts=None, self,
selected_analysts=None,
): ):
"""Set up and compile the agent workflow graph. """Set up and compile the agent workflow graph.
@ -79,41 +80,48 @@ class GraphSetup:
if "market" in selected_analysts: if "market" in selected_analysts:
analyst_nodes["market"] = create_market_analyst( analyst_nodes["market"] = create_market_analyst(
self.quick_thinking_llm, self.toolkit, self.quick_thinking_llm,
self.toolkit,
) )
delete_nodes["market"] = create_msg_delete() delete_nodes["market"] = create_msg_delete()
tool_nodes["market"] = self.tool_nodes["market"] tool_nodes["market"] = self.tool_nodes["market"]
if "social" in selected_analysts: if "social" in selected_analysts:
analyst_nodes["social"] = create_social_media_analyst( analyst_nodes["social"] = create_social_media_analyst(
self.quick_thinking_llm, self.toolkit, self.quick_thinking_llm,
self.toolkit,
) )
delete_nodes["social"] = create_msg_delete() delete_nodes["social"] = create_msg_delete()
tool_nodes["social"] = self.tool_nodes["social"] tool_nodes["social"] = self.tool_nodes["social"]
if "news" in selected_analysts: if "news" in selected_analysts:
analyst_nodes["news"] = create_news_analyst( analyst_nodes["news"] = create_news_analyst(
self.quick_thinking_llm, self.toolkit, self.quick_thinking_llm,
self.toolkit,
) )
delete_nodes["news"] = create_msg_delete() delete_nodes["news"] = create_msg_delete()
tool_nodes["news"] = self.tool_nodes["news"] tool_nodes["news"] = self.tool_nodes["news"]
if "fundamentals" in selected_analysts: if "fundamentals" in selected_analysts:
analyst_nodes["fundamentals"] = create_fundamentals_analyst( analyst_nodes["fundamentals"] = create_fundamentals_analyst(
self.quick_thinking_llm, self.toolkit, self.quick_thinking_llm,
self.toolkit,
) )
delete_nodes["fundamentals"] = create_msg_delete() delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
# Create researcher and manager nodes # Create researcher and manager nodes
bull_researcher_node = create_bull_researcher( bull_researcher_node = create_bull_researcher(
self.quick_thinking_llm, self.bull_memory, self.quick_thinking_llm,
self.bull_memory,
) )
bear_researcher_node = create_bear_researcher( bear_researcher_node = create_bear_researcher(
self.quick_thinking_llm, self.bear_memory, self.quick_thinking_llm,
self.bear_memory,
) )
research_manager_node = create_research_manager( research_manager_node = create_research_manager(
self.deep_thinking_llm, self.invest_judge_memory, self.deep_thinking_llm,
self.invest_judge_memory,
) )
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
@ -122,7 +130,8 @@ class GraphSetup:
neutral_analyst = create_neutral_debator(self.quick_thinking_llm) neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
safe_analyst = create_safe_debator(self.quick_thinking_llm) safe_analyst = create_safe_debator(self.quick_thinking_llm)
risk_manager_node = create_risk_manager( risk_manager_node = create_risk_manager(
self.deep_thinking_llm, self.risk_manager_memory, self.deep_thinking_llm,
self.risk_manager_memory,
) )
# Create workflow # Create workflow
@ -132,7 +141,8 @@ class GraphSetup:
for analyst_type, node in analyst_nodes.items(): for analyst_type, node in analyst_nodes.items():
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
workflow.add_node( workflow.add_node(
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type], f"Msg Clear {analyst_type.capitalize()}",
delete_nodes[analyst_type],
) )
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])

View File

@ -59,7 +59,8 @@ class TradingAgentsGraph:
or self.config["llm_provider"] == "openrouter" or self.config["llm_provider"] == "openrouter"
): ):
self.deep_thinking_llm = ChatOpenAI( self.deep_thinking_llm = ChatOpenAI(
model=self.config["deep_think_llm"], base_url=self.config["backend_url"], model=self.config["deep_think_llm"],
base_url=self.config["backend_url"],
) )
self.quick_thinking_llm = ChatOpenAI( self.quick_thinking_llm = ChatOpenAI(
model=self.config["quick_think_llm"], model=self.config["quick_think_llm"],
@ -67,7 +68,8 @@ class TradingAgentsGraph:
) )
elif self.config["llm_provider"].lower() == "anthropic": elif self.config["llm_provider"].lower() == "anthropic":
self.deep_thinking_llm = ChatAnthropic( self.deep_thinking_llm = ChatAnthropic(
model=self.config["deep_think_llm"], base_url=self.config["backend_url"], model=self.config["deep_think_llm"],
base_url=self.config["backend_url"],
) )
self.quick_thinking_llm = ChatAnthropic( self.quick_thinking_llm = ChatAnthropic(
model=self.config["quick_think_llm"], model=self.config["quick_think_llm"],
@ -91,10 +93,12 @@ class TradingAgentsGraph:
self.bear_memory = FinancialSituationMemory("bear_memory", self.config) self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config) self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory( self.invest_judge_memory = FinancialSituationMemory(
"invest_judge_memory", self.config, "invest_judge_memory",
self.config,
) )
self.risk_manager_memory = FinancialSituationMemory( self.risk_manager_memory = FinancialSituationMemory(
"risk_manager_memory", self.config, "risk_manager_memory",
self.config,
) )
# Create tool nodes # Create tool nodes
@ -179,7 +183,8 @@ class TradingAgentsGraph:
# Initialize state # Initialize state
init_agent_state = self.propagator.create_initial_state( init_agent_state = self.propagator.create_initial_state(
company_name, trade_date, company_name,
trade_date,
) )
args = self.propagator.get_graph_args() args = self.propagator.get_graph_args()
@ -252,19 +257,29 @@ class TradingAgentsGraph:
def reflect_and_remember(self, returns_losses): def reflect_and_remember(self, returns_losses):
"""Reflect on decisions and update memory based on returns.""" """Reflect on decisions and update memory based on returns."""
self.reflector.reflect_bull_researcher( self.reflector.reflect_bull_researcher(
self.curr_state, returns_losses, self.bull_memory, self.curr_state,
returns_losses,
self.bull_memory,
) )
self.reflector.reflect_bear_researcher( self.reflector.reflect_bear_researcher(
self.curr_state, returns_losses, self.bear_memory, self.curr_state,
returns_losses,
self.bear_memory,
) )
self.reflector.reflect_trader( self.reflector.reflect_trader(
self.curr_state, returns_losses, self.trader_memory, self.curr_state,
returns_losses,
self.trader_memory,
) )
self.reflector.reflect_invest_judge( self.reflector.reflect_invest_judge(
self.curr_state, returns_losses, self.invest_judge_memory, self.curr_state,
returns_losses,
self.invest_judge_memory,
) )
self.reflector.reflect_risk_manager( self.reflector.reflect_risk_manager(
self.curr_state, returns_losses, self.risk_manager_memory, self.curr_state,
returns_losses,
self.risk_manager_memory,
) )
def process_signal(self, full_signal): def process_signal(self, full_signal):