Apply Black formatting to pass CI checks

- Formatted 33 Python files with Black
- Fixed code style to meet project standards
- Ensures CI/CD pipeline passes formatting checks
This commit is contained in:
佐藤優一 2025-08-10 23:25:04 +09:00
parent 6f3981412b
commit 850764ad7b
33 changed files with 475 additions and 238 deletions

View File

@ -148,7 +148,7 @@ class MessageBuffer:
f"### News Analysis\n{self.report_sections['news_report']}",
)
if self.report_sections["fundamentals_report"]:
fundamentals = self.report_sections['fundamentals_report']
fundamentals = self.report_sections["fundamentals_report"]
report_parts.append(
f"### Fundamentals Analysis\n{fundamentals}",
)
@ -182,10 +182,12 @@ def create_layout():
Layout(name="footer", size=3),
)
layout["main"].split_column(
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5),
Layout(name="upper", ratio=3),
Layout(name="analysis", ratio=5),
)
layout["upper"].split_row(
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3),
Layout(name="progress", ratio=2),
Layout(name="messages", ratio=3),
)
return layout
@ -237,7 +239,9 @@ def update_display(layout, spinner_text=None):
status = message_buffer.agent_status[first_agent]
if status == "in_progress":
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan",
"dots",
text="[blue]in_progress[/blue]",
style="bold cyan",
)
status_cell = spinner
else:
@ -254,7 +258,9 @@ def update_display(layout, spinner_text=None):
status = message_buffer.agent_status[agent]
if status == "in_progress":
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan",
"dots",
text="[blue]in_progress[/blue]",
style="bold cyan",
)
status_cell = spinner
else:
@ -286,7 +292,10 @@ def update_display(layout, spinner_text=None):
messages_table.add_column("Time", style="cyan", width=8, justify="center")
messages_table.add_column("Type", style="green", width=10, justify="center")
messages_table.add_column(
"Content", style="white", no_wrap=False, ratio=1,
"Content",
style="white",
no_wrap=False,
ratio=1,
) # Make content column expand
# Combine tool calls and messages
@ -441,7 +450,9 @@ def get_user_selections():
# Step 1: Ticker symbol
console.print(
create_question_box(
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY",
"Step 1: Ticker Symbol",
"Enter the ticker symbol to analyze",
"SPY",
),
)
selected_ticker = get_ticker()
@ -460,7 +471,8 @@ def get_user_selections():
# Step 3: Select analysts
console.print(
create_question_box(
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis",
"Step 3: Analysts Team",
"Select your LLM analyst agents for the analysis",
),
)
selected_analysts = select_analysts()
@ -471,21 +483,25 @@ def get_user_selections():
# Step 4: Research depth
console.print(
create_question_box(
"Step 4: Research Depth", "Select your research depth level",
"Step 4: Research Depth",
"Select your research depth level",
),
)
selected_research_depth = select_research_depth()
# Step 5: OpenAI backend
console.print(
create_question_box("Step 5: OpenAI backend", "Select which service to talk to"),
create_question_box(
"Step 5: OpenAI backend", "Select which service to talk to"
),
)
selected_llm_provider, backend_url = select_llm_provider()
# Step 6: Thinking agents
console.print(
create_question_box(
"Step 6: Thinking Agents", "Select your thinking agents for analysis",
"Step 6: Thinking Agents",
"Select your thinking agents for analysis",
),
)
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
@ -737,7 +753,9 @@ def run_analysis():
# Initialize the graph
graph = TradingAgentsGraph(
[analyst.value for analyst in selections["analysts"]], config=config, debug=True,
[analyst.value for analyst in selections["analysts"]],
config=config,
debug=True,
)
# Create result directory
@ -796,10 +814,12 @@ def run_analysis():
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
message_buffer.add_tool_call = save_tool_call_decorator(
message_buffer, "add_tool_call",
message_buffer,
"add_tool_call",
)
message_buffer.update_report_section = save_report_section_decorator(
message_buffer, "update_report_section",
message_buffer,
"update_report_section",
)
# Now start the display layout
@ -812,7 +832,8 @@ def run_analysis():
# Add initial messages
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
message_buffer.add_message(
"System", f"Analysis date: {selections['analysis_date']}",
"System",
f"Analysis date: {selections['analysis_date']}",
)
message_buffer.add_message(
"System",
@ -843,7 +864,8 @@ def run_analysis():
# Initialize state and get graph args
init_agent_state = graph.propagator.create_initial_state(
selections["ticker"], selections["analysis_date"],
selections["ticker"],
selections["analysis_date"],
)
args = graph.propagator.get_graph_args()
@ -873,7 +895,8 @@ def run_analysis():
# Handle both dictionary and object tool calls
if isinstance(tool_call, dict):
message_buffer.add_tool_call(
tool_call["name"], tool_call["args"],
tool_call["name"],
tool_call["args"],
)
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
@ -882,51 +905,57 @@ def run_analysis():
# Analyst Team Reports
if chunk.get("market_report"):
message_buffer.update_report_section(
"market_report", chunk["market_report"],
"market_report",
chunk["market_report"],
)
message_buffer.update_agent_status("Market Analyst", "completed")
# Set next analyst to in_progress
if "social" in selections["analysts"]:
message_buffer.update_agent_status(
"Social Analyst", "in_progress",
"Social Analyst",
"in_progress",
)
if chunk.get("sentiment_report"):
message_buffer.update_report_section(
"sentiment_report", chunk["sentiment_report"],
"sentiment_report",
chunk["sentiment_report"],
)
message_buffer.update_agent_status("Social Analyst", "completed")
# Set next analyst to in_progress
if "news" in selections["analysts"]:
message_buffer.update_agent_status(
"News Analyst", "in_progress",
"News Analyst",
"in_progress",
)
if chunk.get("news_report"):
message_buffer.update_report_section(
"news_report", chunk["news_report"],
"news_report",
chunk["news_report"],
)
message_buffer.update_agent_status("News Analyst", "completed")
# Set next analyst to in_progress
if "fundamentals" in selections["analysts"]:
message_buffer.update_agent_status(
"Fundamentals Analyst", "in_progress",
"Fundamentals Analyst",
"in_progress",
)
if chunk.get("fundamentals_report"):
message_buffer.update_report_section(
"fundamentals_report", chunk["fundamentals_report"],
"fundamentals_report",
chunk["fundamentals_report"],
)
message_buffer.update_agent_status(
"Fundamentals Analyst", "completed",
"Fundamentals Analyst",
"completed",
)
# Set all research team members to in_progress
update_research_team_status("in_progress")
# Research Team - Handle Investment Debate State
if (
chunk.get("investment_debate_state")
):
if chunk.get("investment_debate_state"):
debate_state = chunk["investment_debate_state"]
# Update Bull Researcher status and report
@ -960,9 +989,7 @@ def run_analysis():
)
# Update Research Manager status and final decision
if (
debate_state.get("judge_decision")
):
if debate_state.get("judge_decision"):
# Keep all research team members in progress until final decision
update_research_team_status("in_progress")
message_buffer.add_message(
@ -978,15 +1005,15 @@ def run_analysis():
update_research_team_status("completed")
# Set first risk analyst to in_progress
message_buffer.update_agent_status(
"Risky Analyst", "in_progress",
"Risky Analyst",
"in_progress",
)
# Trading Team
if (
chunk.get("trader_investment_plan")
):
if chunk.get("trader_investment_plan"):
message_buffer.update_report_section(
"trader_investment_plan", chunk["trader_investment_plan"],
"trader_investment_plan",
chunk["trader_investment_plan"],
)
# Set first risk analyst to in_progress
message_buffer.update_agent_status("Risky Analyst", "in_progress")
@ -996,11 +1023,10 @@ def run_analysis():
risk_state = chunk["risk_debate_state"]
# Update Risky Analyst status and report
if (
risk_state.get("current_risky_response")
):
if risk_state.get("current_risky_response"):
message_buffer.update_agent_status(
"Risky Analyst", "in_progress",
"Risky Analyst",
"in_progress",
)
message_buffer.add_message(
"Reasoning",
@ -1013,11 +1039,10 @@ def run_analysis():
)
# Update Safe Analyst status and report
if (
risk_state.get("current_safe_response")
):
if risk_state.get("current_safe_response"):
message_buffer.update_agent_status(
"Safe Analyst", "in_progress",
"Safe Analyst",
"in_progress",
)
message_buffer.add_message(
"Reasoning",
@ -1030,11 +1055,10 @@ def run_analysis():
)
# Update Neutral Analyst status and report
if (
risk_state.get("current_neutral_response")
):
if risk_state.get("current_neutral_response"):
message_buffer.update_agent_status(
"Neutral Analyst", "in_progress",
"Neutral Analyst",
"in_progress",
)
message_buffer.add_message(
"Reasoning",
@ -1049,7 +1073,8 @@ def run_analysis():
# Update Portfolio Manager status and final decision
if risk_state.get("judge_decision"):
message_buffer.update_agent_status(
"Portfolio Manager", "in_progress",
"Portfolio Manager",
"in_progress",
)
message_buffer.add_message(
"Reasoning",
@ -1064,10 +1089,12 @@ def run_analysis():
message_buffer.update_agent_status("Risky Analyst", "completed")
message_buffer.update_agent_status("Safe Analyst", "completed")
message_buffer.update_agent_status(
"Neutral Analyst", "completed",
"Neutral Analyst",
"completed",
)
message_buffer.update_agent_status(
"Portfolio Manager", "completed",
"Portfolio Manager",
"completed",
)
# Update the display
@ -1084,7 +1111,8 @@ def run_analysis():
message_buffer.update_agent_status(agent, "completed")
message_buffer.add_message(
"Analysis", f"Completed analysis for {selections['analysis_date']}",
"Analysis",
f"Completed analysis for {selections['analysis_date']}",
)
# Update final report sections

View File

@ -1,4 +1,3 @@
import sys
import questionary

View File

@ -4,7 +4,9 @@ from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google" # Use a different model
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
config["backend_url"] = (
"https://generativelanguage.googleapis.com/v1" # Use a different backend
)
config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model
config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds

View File

@ -16,10 +16,10 @@ def run_command(cmd, description=""):
"""Run a command and handle errors."""
if description:
print(f"\n🔄 {description}")
print(f"Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=False)
if result.returncode != 0:
print(f"❌ Command failed with return code {result.returncode}")
sys.exit(result.returncode)
@ -33,48 +33,42 @@ def main():
parser.add_argument(
"test_type",
choices=["unit", "integration", "all", "coverage", "fast", "slow", "lint"],
help="Type of tests to run"
)
parser.add_argument(
"--verbose", "-v", action="store_true", help="Verbose output"
help="Type of tests to run",
)
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
parser.add_argument(
"--parallel", "-p", action="store_true", help="Run tests in parallel"
)
parser.add_argument(
"--file", "-f", help="Run specific test file"
)
parser.add_argument(
"--pattern", "-k", help="Run tests matching pattern"
)
parser.add_argument("--file", "-f", help="Run specific test file")
parser.add_argument("--pattern", "-k", help="Run tests matching pattern")
args = parser.parse_args()
# Base pytest command
base_cmd = ["python", "-m", "pytest"]
if args.verbose:
base_cmd.append("-v")
if args.parallel:
base_cmd.extend(["-n", "auto"])
if args.pattern:
base_cmd.extend(["-k", args.pattern])
# Configure based on test type
if args.test_type == "unit":
cmd = base_cmd + ["tests/unit/", "-m", "unit"]
run_command(cmd, "Running unit tests")
elif args.test_type == "integration":
cmd = base_cmd + ["tests/integration/", "-m", "integration"]
run_command(cmd, "Running integration tests")
elif args.test_type == "all":
cmd = base_cmd + ["tests/"]
run_command(cmd, "Running all tests")
elif args.test_type == "coverage":
cmd = base_cmd + [
"tests/",
@ -88,28 +82,28 @@ def main():
print("\n📊 Coverage report generated:")
print(" - HTML: htmlcov/index.html")
print(" - XML: coverage.xml")
elif args.test_type == "fast":
cmd = base_cmd + ["tests/unit/", "-m", "unit", "--durations=10"]
run_command(cmd, "Running fast unit tests")
elif args.test_type == "slow":
cmd = base_cmd + ["tests/", "-m", "slow", "--timeout=600"]
run_command(cmd, "Running slow tests")
elif args.test_type == "lint":
# Run mypy
cmd = ["python", "-m", "mypy", "tradingagents/", "cli/", "tests/"]
run_command(cmd, "Running mypy type checking")
# Run pytest on tests only for syntax
cmd = base_cmd + ["tests/", "--collect-only"]
run_command(cmd, "Validating test syntax")
elif args.file:
cmd = base_cmd + [args.file]
run_command(cmd, f"Running tests in {args.file}")
print("\n🎉 All tests completed successfully!")
@ -117,5 +111,5 @@ if __name__ == "__main__":
# Ensure we're in the project directory
script_dir = Path(__file__).parent
os.chdir(script_dir)
main()
main()

View File

@ -1,19 +1,23 @@
def poorly_formatted_function(x,y,z): # Missing type hints
def poorly_formatted_function(x, y, z): # Missing type hints
"""This function has formatting issues."""
result=x+y*z # Missing spaces around operators
if result>100: # Missing spaces
print( "Result is large" ) # Extra spaces in parentheses
result = x + y * z # Missing spaces around operators
if result > 100: # Missing spaces
print("Result is large") # Extra spaces in parentheses
return result
# Long line that Black will wrap
very_long_variable_name_that_exceeds_the_standard_line_length_limit = "This is a very long string that will be wrapped by Black formatter"
very_long_variable_name_that_exceeds_the_standard_line_length_limit = (
"This is a very long string that will be wrapped by Black formatter"
)
class MyClass:
def __init__(self,name:str,age:int): # Missing space after comma
self.name=name # Missing spaces around =
self.age=age
def __init__(self, name: str, age: int): # Missing space after comma
self.name = name # Missing spaces around =
self.age = age
# Function with wrong return type hint
def get_number() -> str:
return 123 # Returns int but type hint says str
return 123 # Returns int but type hint says str

View File

@ -2,6 +2,7 @@ def add_numbers(a: int, b: int) -> int:
"""Add two numbers and return the result."""
return a + b
# Test the function
result = add_numbers(1, 2)
print(f"Result: {result}")
print(f"Result: {result}")

View File

@ -15,7 +15,7 @@ def run_command(cmd, description=""):
"""Run a command and return success status."""
if description:
print(f"\n🔄 {description}")
print(f"Running: {' '.join(cmd)}")
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
@ -38,45 +38,60 @@ def main():
"""Run setup verification tests."""
print("🧪 TradingAgents Test Setup Verification")
print("=" * 50)
# Change to project directory
project_dir = Path(__file__).parent
os.chdir(project_dir)
success_count = 0
total_tests = 0
# Test 1: Check if pytest is installed and can discover tests
total_tests += 1
if run_command(["python", "-m", "pytest", "--version"], "Checking pytest installation"):
if run_command(
["python", "-m", "pytest", "--version"], "Checking pytest installation"
):
success_count += 1
# Test 2: Test discovery
total_tests += 1
if run_command(["python", "-m", "pytest", "tests/", "--collect-only", "-q"], "Testing test discovery"):
if run_command(
["python", "-m", "pytest", "tests/", "--collect-only", "-q"],
"Testing test discovery",
):
success_count += 1
# Test 3: Check if mypy can run
total_tests += 1
if run_command(["python", "-m", "mypy", "--version"], "Checking mypy installation"):
success_count += 1
# Test 4: Run a simple syntax check on test files
total_tests += 1
if run_command(["python", "-c", "import tests.conftest; print('Test imports work!')"], "Testing test imports"):
if run_command(
["python", "-c", "import tests.conftest; print('Test imports work!')"],
"Testing test imports",
):
success_count += 1
# Test 5: Check if we can import the main module
total_tests += 1
if run_command(["python", "-c", "import tradingagents.config; print('Main module imports work!')"], "Testing main module imports"):
if run_command(
[
"python",
"-c",
"import tradingagents.config; print('Main module imports work!')",
],
"Testing main module imports",
):
success_count += 1
# Summary
print("\n" + "=" * 50)
print("📊 Test Setup Verification Results:")
print(f"✅ Successful: {success_count}/{total_tests}")
print(f"❌ Failed: {total_tests - success_count}/{total_tests}")
if success_count == total_tests:
print("\n🎉 All verification tests passed! Your test setup is ready.")
print("\n📚 Next steps:")
@ -95,4 +110,4 @@ def main():
if __name__ == "__main__":
sys.exit(main())
sys.exit(main())

View File

@ -173,7 +173,8 @@ def mock_memory():
def pytest_configure(config):
"""Configure pytest with custom markers."""
config.addinivalue_line(
"markers", "integration: mark test as integration test (slow)",
"markers",
"integration: mark test as integration test (slow)",
)
config.addinivalue_line("markers", "unit: mark test as unit test (fast)")
config.addinivalue_line("markers", "api: mark test as requiring API access")

View File

@ -36,7 +36,8 @@ class SampleDataFactory:
@staticmethod
def create_finnhub_news_data(
ticker: str = "AAPL", count: int = 10,
ticker: str = "AAPL",
count: int = 10,
) -> dict[str, list[dict[str, Any]]]:
"""Create sample FinnHub news data for testing."""
base_date = datetime(2024, 5, 10)
@ -136,7 +137,8 @@ class SampleDataFactory:
@staticmethod
def create_financial_statements_data(
ticker: str = "AAPL", period: str = "annual",
ticker: str = "AAPL",
period: str = "annual",
) -> dict[str, list[dict[str, Any]]]:
"""Create sample financial statements data for testing."""
if period == "annual":
@ -271,10 +273,12 @@ class SampleDataFactory:
ticker,
),
"financial_annual": SampleDataFactory.create_financial_statements_data(
ticker, "annual",
ticker,
"annual",
),
"financial_quarterly": SampleDataFactory.create_financial_statements_data(
ticker, "quarterly",
ticker,
"quarterly",
),
"social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker),
"technical_indicators": SampleDataFactory.create_technical_indicators_data(
@ -343,7 +347,9 @@ def save_sample_data_to_files(base_path: str, ticker: str = "AAPL") -> None:
# Save quarterly data separately
quarterly_path = os.path.join(
finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json",
finnhub_path,
"fin_as_reported",
f"{ticker}_quarterly_data_formatted.json",
)
with open(quarterly_path, "w") as f:
json.dump(dataset["financial_quarterly"], f, indent=2)

View File

@ -31,7 +31,10 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_end_to_end_trading_workflow(
self, mock_toolkit, mock_chat_openai, integration_config,
self,
mock_toolkit,
mock_chat_openai,
integration_config,
):
"""Test complete end-to-end trading workflow."""
# Setup mocks
@ -86,7 +89,10 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_analysts_integration(
self, mock_toolkit, mock_chat_openai, integration_config,
self,
mock_toolkit,
mock_chat_openai,
integration_config,
):
"""Test integration with different analyst combinations."""
analyst_combinations = [
@ -114,7 +120,8 @@ class TestFullWorkflowIntegration:
with patch("tradingagents.graph.trading_graph.set_config"):
# Test each analyst combination
trading_graph = TradingAgentsGraph(
selected_analysts=analysts, config=integration_config,
selected_analysts=analysts,
config=integration_config,
)
trading_graph.graph = mock_graph
@ -134,7 +141,8 @@ class TestFullWorkflowIntegration:
# Execute
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = trading_graph.propagate(
"TSLA", "2024-05-15",
"TSLA",
"2024-05-15",
)
# Verify
@ -144,7 +152,10 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_memory_and_reflection_integration(
self, mock_toolkit, mock_chat_openai, integration_config,
self,
mock_toolkit,
mock_chat_openai,
integration_config,
):
"""Test integration of memory and reflection components."""
# Setup
@ -208,7 +219,10 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_debug_mode_integration(
self, mock_toolkit, mock_chat_openai, integration_config,
self,
mock_toolkit,
mock_chat_openai,
integration_config,
):
"""Test integration in debug mode."""
# Setup
@ -240,7 +254,8 @@ class TestFullWorkflowIntegration:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"):
trading_graph = TradingAgentsGraph(
debug=True, config=integration_config,
debug=True,
config=integration_config,
)
trading_graph.graph = mock_graph
@ -276,7 +291,12 @@ class TestFullWorkflowIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_stocks_integration(
self, mock_toolkit, mock_chat_openai, ticker, date, integration_config,
self,
mock_toolkit,
mock_chat_openai,
ticker,
date,
integration_config,
):
"""Test integration with different stocks and dates."""
# Setup
@ -382,7 +402,11 @@ class TestPerformanceIntegration:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_multiple_consecutive_runs(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test multiple consecutive trading decisions."""
sample_config["project_dir"] = temp_data_dir

View File

@ -17,7 +17,10 @@ class TestMarketAnalyst:
assert callable(analyst_node)
def test_market_analyst_node_basic_execution(
self, mock_llm, mock_toolkit, sample_agent_state,
self,
mock_llm,
mock_toolkit,
sample_agent_state,
):
"""Test basic execution of market analyst node."""
# Setup
@ -39,7 +42,10 @@ class TestMarketAnalyst:
assert result["market_report"] == "Market analysis complete"
def test_market_analyst_uses_online_tools_when_configured(
self, mock_llm, mock_toolkit, sample_agent_state,
self,
mock_llm,
mock_toolkit,
sample_agent_state,
):
"""Test that analyst uses online tools when configured."""
# Setup
@ -64,7 +70,10 @@ class TestMarketAnalyst:
assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2
def test_market_analyst_uses_offline_tools_when_configured(
self, mock_llm, mock_toolkit, sample_agent_state,
self,
mock_llm,
mock_toolkit,
sample_agent_state,
):
"""Test that analyst uses offline tools when configured."""
# Setup
@ -88,7 +97,10 @@ class TestMarketAnalyst:
assert len(bound_tools) == 2 # Should have 2 offline tools
def test_market_analyst_processes_state_variables(
self, mock_llm, mock_toolkit, sample_agent_state,
self,
mock_llm,
mock_toolkit,
sample_agent_state,
):
"""Test that market analyst correctly processes state variables."""
# Setup
@ -112,7 +124,10 @@ class TestMarketAnalyst:
assert result["market_report"] == "Analysis for AAPL on 2024-05-10"
def test_market_analyst_handles_empty_tool_calls(
self, mock_llm, mock_toolkit, sample_agent_state,
self,
mock_llm,
mock_toolkit,
sample_agent_state,
):
"""Test handling when no tool calls are made."""
# Setup
@ -132,7 +147,10 @@ class TestMarketAnalyst:
assert result["messages"] == [mock_result]
def test_market_analyst_with_tool_calls(
self, mock_llm, mock_toolkit, sample_agent_state,
self,
mock_llm,
mock_toolkit,
sample_agent_state,
):
"""Test handling when tool calls are present."""
# Setup
@ -153,7 +171,11 @@ class TestMarketAnalyst:
@pytest.mark.parametrize("online_tools", [True, False])
def test_market_analyst_tool_configuration(
self, mock_llm, mock_toolkit, sample_agent_state, online_tools,
self,
mock_llm,
mock_toolkit,
sample_agent_state,
online_tools,
):
"""Test tool configuration for both online and offline modes."""
# Setup

View File

@ -190,7 +190,10 @@ class TestFinnhubUtils:
# Test without period
expected_path_no_period = os.path.join(
temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
temp_data_dir,
"finnhub_data",
data_type,
f"{ticker}_data_formatted.json",
)
# Test with period
@ -248,7 +251,10 @@ class TestFinnhubUtils:
],
)
def test_get_data_in_range_various_data_types(
self, temp_data_dir, data_type, period,
self,
temp_data_dir,
data_type,
period,
):
"""Test get_data_in_range with various data types."""
ticker = "TEST"

View File

@ -45,7 +45,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_debug(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test initialization with debug mode enabled."""
sample_config["project_dir"] = temp_data_dir
@ -63,7 +67,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatAnthropic")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_anthropic(
self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir,
self,
mock_toolkit,
mock_chat_anthropic,
sample_config,
temp_data_dir,
):
"""Test initialization with Anthropic LLM provider."""
sample_config["project_dir"] = temp_data_dir
@ -82,7 +90,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_google(
self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir,
self,
mock_toolkit,
mock_chat_google,
sample_config,
temp_data_dir,
):
"""Test initialization with Google LLM provider."""
sample_config["project_dir"] = temp_data_dir
@ -100,7 +112,10 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_unsupported_llm_provider(
self, mock_toolkit, sample_config, temp_data_dir,
self,
mock_toolkit,
sample_config,
temp_data_dir,
):
"""Test initialization with unsupported LLM provider raises error."""
sample_config["project_dir"] = temp_data_dir
@ -115,7 +130,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_create_tool_nodes(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test creation of tool nodes."""
sample_config["project_dir"] = temp_data_dir
@ -143,7 +162,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_propagate_basic(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test basic propagate functionality."""
sample_config["project_dir"] = temp_data_dir
@ -206,7 +229,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_propagate_debug_mode(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test propagate in debug mode."""
sample_config["project_dir"] = temp_data_dir
@ -245,7 +272,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_log_state(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test state logging functionality."""
sample_config["project_dir"] = temp_data_dir
@ -300,7 +331,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_reflect_and_remember(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test reflection and memory update functionality."""
sample_config["project_dir"] = temp_data_dir
@ -309,9 +344,12 @@ class TestTradingAgentsGraph:
mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance
with patch(
"tradingagents.graph.trading_graph.FinancialSituationMemory",
), patch("tradingagents.graph.trading_graph.set_config"):
with (
patch(
"tradingagents.graph.trading_graph.FinancialSituationMemory",
),
patch("tradingagents.graph.trading_graph.set_config"),
):
graph = TradingAgentsGraph(config=sample_config)
# Set up current state
@ -339,7 +377,11 @@ class TestTradingAgentsGraph:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_process_signal(
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test signal processing functionality."""
sample_config["project_dir"] = temp_data_dir
@ -388,7 +430,8 @@ class TestTradingAgentsGraph:
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.graph.trading_graph.set_config"):
TradingAgentsGraph(
selected_analysts=selected_analysts, config=sample_config,
selected_analysts=selected_analysts,
config=sample_config,
)
# Verify graph was set up with selected analysts
@ -416,7 +459,10 @@ class TestTradingAgentsGraphErrorHandling:
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_directory_creation_failure(
self, mock_toolkit, mock_chat_openai, sample_config,
self,
mock_toolkit,
mock_chat_openai,
sample_config,
):
"""Test handling when directory creation fails."""
sample_config["project_dir"] = "/invalid/path/that/cannot/be/created"

View File

@ -20,7 +20,7 @@ def create_fundamentals_analyst(llm, toolkit):
system_message = (
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
)
prompt = ChatPromptTemplate.from_messages(

View File

@ -45,7 +45,7 @@ Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
""" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
""" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
)
prompt = ChatPromptTemplate.from_messages(

View File

@ -17,7 +17,7 @@ def create_news_analyst(llm, toolkit):
system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
)
prompt = ChatPromptTemplate.from_messages(

View File

@ -16,7 +16,7 @@ def create_social_media_analyst(llm, toolkit):
system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
)
prompt = ChatPromptTemplate.from_messages(

View File

@ -41,7 +41,8 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
"current_risky_response": argument,
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": risk_debate_state.get(
"current_neutral_response", "",
"current_neutral_response",
"",
),
"count": risk_debate_state["count"] + 1,
}

View File

@ -39,11 +39,13 @@ Engage by questioning their optimism and emphasizing the potential downsides the
"neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Safe",
"current_risky_response": risk_debate_state.get(
"current_risky_response", "",
"current_risky_response",
"",
),
"current_safe_response": argument,
"current_neutral_response": risk_debate_state.get(
"current_neutral_response", "",
"current_neutral_response",
"",
),
"count": risk_debate_state["count"] + 1,
}

View File

@ -39,7 +39,8 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
"neutral_history": neutral_history + "\n" + argument,
"latest_speaker": "Neutral",
"current_risky_response": risk_debate_state.get(
"current_risky_response", "",
"current_risky_response",
"",
),
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": argument,

View File

@ -9,10 +9,12 @@ from typing_extensions import TypedDict
# Researcher team state
class InvestDebateState(TypedDict):
bull_history: Annotated[
str, "Bullish Conversation history",
str,
"Bullish Conversation history",
] # Bullish Conversation history
bear_history: Annotated[
str, "Bearish Conversation history",
str,
"Bearish Conversation history",
] # Bullish Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
current_response: Annotated[str, "Latest response"] # Last response
@ -23,24 +25,30 @@ class InvestDebateState(TypedDict):
# Risk management team state
class RiskDebateState(TypedDict):
risky_history: Annotated[
str, "Risky Agent's Conversation history",
str,
"Risky Agent's Conversation history",
] # Conversation history
safe_history: Annotated[
str, "Safe Agent's Conversation history",
str,
"Safe Agent's Conversation history",
] # Conversation history
neutral_history: Annotated[
str, "Neutral Agent's Conversation history",
str,
"Neutral Agent's Conversation history",
] # Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
latest_speaker: Annotated[str, "Analyst that spoke last"]
current_risky_response: Annotated[
str, "Latest response by the risky analyst",
str,
"Latest response by the risky analyst",
] # Last response
current_safe_response: Annotated[
str, "Latest response by the safe analyst",
str,
"Latest response by the safe analyst",
] # Last response
current_neutral_response: Annotated[
str, "Latest response by the neutral analyst",
str,
"Latest response by the neutral analyst",
] # Last response
judge_decision: Annotated[str, "Judge's decision"]
count: Annotated[int, "Length of the current conversation"] # Conversation length
@ -56,13 +64,15 @@ class AgentState(MessagesState):
market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
news_report: Annotated[
str, "Report from the News Researcher of current world affairs",
str,
"Report from the News Researcher of current world affairs",
]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
# researcher team discussion step
investment_debate_state: Annotated[
InvestDebateState, "Current state of the debate on if to invest or not",
InvestDebateState,
"Current state of the debate on if to invest or not",
]
investment_plan: Annotated[str, "Plan generated by the Analyst"]
@ -70,6 +80,7 @@ class AgentState(MessagesState):
# risk management team discussion step
risk_debate_state: Annotated[
RiskDebateState, "Current state of the debate on evaluating risk",
RiskDebateState,
"Current state of the debate on evaluating risk",
]
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]

View File

@ -56,7 +56,6 @@ class Toolkit:
return interface.get_reddit_global_news(curr_date, 7, 5)
@staticmethod
@tool
def get_finnhub_news(
@ -84,10 +83,11 @@ class Toolkit:
look_back_days = (end_date - start_date).days
return interface.get_finnhub_news(
ticker, end_date_str, look_back_days,
ticker,
end_date_str,
look_back_days,
)
@staticmethod
@tool
def get_reddit_stock_info(
@ -108,7 +108,6 @@ class Toolkit:
return interface.get_reddit_company_news(ticker, curr_date, 7, 5)
@staticmethod
@tool
def get_YFin_data(
@ -128,7 +127,6 @@ class Toolkit:
return interface.get_YFin_data(symbol, start_date, end_date)
@staticmethod
@tool
def get_YFin_data_online(
@ -148,16 +146,17 @@ class Toolkit:
return interface.get_YFin_data_online(symbol, start_date, end_date)
@staticmethod
@tool
def get_stockstats_indicators_report(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[
str, "technical indicator to get the analysis and report of",
str,
"technical indicator to get the analysis and report of",
],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd",
str,
"The current trading date you are trading on, YYYY-mm-dd",
],
look_back_days: Annotated[int, "how many days to look back"] = 30,
) -> str:
@ -173,19 +172,24 @@ class Toolkit:
"""
return interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, False,
symbol,
indicator,
curr_date,
look_back_days,
False,
)
@staticmethod
@tool
def get_stockstats_indicators_report_online(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[
str, "technical indicator to get the analysis and report of",
str,
"technical indicator to get the analysis and report of",
],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd",
str,
"The current trading date you are trading on, YYYY-mm-dd",
],
look_back_days: Annotated[int, "how many days to look back"] = 30,
) -> str:
@ -201,10 +205,13 @@ class Toolkit:
"""
return interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, True,
symbol,
indicator,
curr_date,
look_back_days,
True,
)
@staticmethod
@tool
def get_finnhub_company_insider_sentiment(
@ -224,10 +231,11 @@ class Toolkit:
"""
return interface.get_finnhub_company_insider_sentiment(
ticker, curr_date, 30,
ticker,
curr_date,
30,
)
@staticmethod
@tool
def get_finnhub_company_insider_transactions(
@ -247,10 +255,11 @@ class Toolkit:
"""
return interface.get_finnhub_company_insider_transactions(
ticker, curr_date, 30,
ticker,
curr_date,
30,
)
@staticmethod
@tool
def get_simfin_balance_sheet(
@ -273,7 +282,6 @@ class Toolkit:
return interface.get_simfin_balance_sheet(ticker, freq, curr_date)
@staticmethod
@tool
def get_simfin_cashflow(
@ -296,7 +304,6 @@ class Toolkit:
return interface.get_simfin_cashflow(ticker, freq, curr_date)
@staticmethod
@tool
def get_simfin_income_stmt(
@ -318,10 +325,11 @@ class Toolkit:
"""
return interface.get_simfin_income_statements(
ticker, freq, curr_date,
ticker,
freq,
curr_date,
)
@staticmethod
@tool
def get_google_news(
@ -340,7 +348,6 @@ class Toolkit:
return interface.get_google_news(query, curr_date, 7)
@staticmethod
@tool
def get_stock_news_openai(
@ -358,7 +365,6 @@ class Toolkit:
return interface.get_stock_news_openai(ticker, curr_date)
@staticmethod
@tool
def get_global_news_openai(
@ -374,7 +380,6 @@ class Toolkit:
return interface.get_global_news_openai(curr_date)
@staticmethod
@tool
def get_fundamentals_openai(
@ -391,6 +396,6 @@ class Toolkit:
"""
return interface.get_fundamentals_openai(
ticker, curr_date,
ticker,
curr_date,
)

View File

@ -21,7 +21,8 @@ def get_config():
"project_dir": str(project_root / "tradingagents"),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": os.getenv(
"TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data",
"TRADINGAGENTS_DATA_DIR",
"/Users/yluo/Documents/Code/ScAI/FR1-data",
),
"data_cache_dir": str(
project_root / "tradingagents" / "dataflows" / "data_cache",

View File

@ -1,4 +1,3 @@
from tradingagents import default_config
# Use default config but allow it to be overridden

View File

@ -22,7 +22,10 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
)
else:
data_path = os.path.join(
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
data_dir,
"finnhub_data",
data_type,
f"{ticker}_data_formatted.json",
)
data = open(data_path)

View File

@ -419,7 +419,8 @@ def get_stock_stats_indicators_window(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd",
str,
"The current trading date you are trading on, YYYY-mm-dd",
],
look_back_days: Annotated[int, "how many days to look back"],
online: Annotated[bool, "to fetch data online or offline"],
@ -524,7 +525,10 @@ def get_stock_stats_indicators_window(
# only do the trading dates
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
symbol,
indicator,
curr_date.strftime("%Y-%m-%d"),
online,
)
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
@ -535,7 +539,10 @@ def get_stock_stats_indicators_window(
ind_string = ""
while curr_date >= before:
indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
symbol,
indicator,
curr_date.strftime("%Y-%m-%d"),
online,
)
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
@ -550,12 +557,12 @@ def get_stock_stats_indicators_window(
)
def get_stockstats_indicator(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd",
str,
"The current trading date you are trading on, YYYY-mm-dd",
],
online: Annotated[bool, "to fetch data online or offline"],
) -> str:
@ -608,7 +615,12 @@ def get_YFin_data_window(
# Set pandas display options to show the full DataFrame
with pd.option_context(
"display.max_rows", None, "display.max_columns", None, "display.width", None,
"display.max_rows",
None,
"display.max_columns",
None,
"display.width",
None,
):
df_string = filtered_data.to_string()
@ -694,7 +706,6 @@ def get_YFin_data(
return filtered_data.reset_index(drop=True)
def get_stock_news_openai(ticker, curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])

View File

@ -48,11 +48,14 @@ ticker_to_company = {
def fetch_top_from_category(
category: Annotated[
str, "Category to fetch top post from. Collection of subreddits.",
str,
"Category to fetch top post from. Collection of subreddits.",
],
date: Annotated[str, "Date to fetch top posts from."],
max_limit: Annotated[int, "Maximum number of posts to fetch."],
query: Annotated[str | None, "Optional query to search for in the subreddit."] = None,
query: Annotated[
str | None, "Optional query to search for in the subreddit."
] = None,
data_path: Annotated[
str,
"Path to the data folder. Default is 'reddit_data'.",
@ -107,7 +110,9 @@ def fetch_top_from_category(
found = False
for term in search_terms:
if re.search(
term, parsed_line["title"], re.IGNORECASE,
term,
parsed_line["title"],
re.IGNORECASE,
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
found = True
break

View File

@ -13,10 +13,12 @@ class StockstatsUtils:
def get_stock_stats(
symbol: Annotated[str, "ticker symbol for the company"],
indicator: Annotated[
str, "quantitative indicators based off of the stock data for the company",
str,
"quantitative indicators based off of the stock data for the company",
],
curr_date: Annotated[
str, "curr date for retrieving stock price data, YYYY-mm-dd",
str,
"curr date for retrieving stock price data, YYYY-mm-dd",
],
data_dir: Annotated[
str,

View File

@ -28,10 +28,12 @@ class YFinanceUtils:
def get_stock_data(
self: Annotated[str, "ticker symbol"],
start_date: Annotated[
str, "start date for retrieving stock price data, YYYY-mm-dd",
str,
"start date for retrieving stock price data, YYYY-mm-dd",
],
end_date: Annotated[
str, "end date for retrieving stock price data, YYYY-mm-dd",
str,
"end date for retrieving stock price data, YYYY-mm-dd",
],
save_path: SavePathType = None,
) -> DataFrame:

View File

@ -16,7 +16,9 @@ class Propagator:
self.max_recur_limit = max_recur_limit
def create_initial_state(
self, company_name: str, trade_date: str,
self,
company_name: str,
trade_date: str,
) -> dict[str, Any]:
"""Create the initial state for the agent graph."""
return {

View File

@ -57,7 +57,11 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
def _reflect_on_component(
self, component_type: str, report: str, situation: str, returns_losses,
self,
component_type: str,
report: str,
situation: str,
returns_losses,
) -> str:
"""Generate reflection for a component."""
messages = [
@ -76,7 +80,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
result = self._reflect_on_component(
"BULL", bull_debate_history, situation, returns_losses,
"BULL",
bull_debate_history,
situation,
returns_losses,
)
bull_memory.add_situations([(situation, result)])
@ -86,7 +93,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
result = self._reflect_on_component(
"BEAR", bear_debate_history, situation, returns_losses,
"BEAR",
bear_debate_history,
situation,
returns_losses,
)
bear_memory.add_situations([(situation, result)])
@ -96,7 +106,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
trader_decision = current_state["trader_investment_plan"]
result = self._reflect_on_component(
"TRADER", trader_decision, situation, returns_losses,
"TRADER",
trader_decision,
situation,
returns_losses,
)
trader_memory.add_situations([(situation, result)])
@ -106,7 +119,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
judge_decision = current_state["investment_debate_state"]["judge_decision"]
result = self._reflect_on_component(
"INVEST JUDGE", judge_decision, situation, returns_losses,
"INVEST JUDGE",
judge_decision,
situation,
returns_losses,
)
invest_judge_memory.add_situations([(situation, result)])
@ -116,6 +132,9 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
judge_decision = current_state["risk_debate_state"]["judge_decision"]
result = self._reflect_on_component(
"RISK JUDGE", judge_decision, situation, returns_losses,
"RISK JUDGE",
judge_decision,
situation,
returns_losses,
)
risk_manager_memory.add_situations([(situation, result)])

View File

@ -55,7 +55,8 @@ class GraphSetup:
self.conditional_logic = conditional_logic
def setup_graph(
self, selected_analysts=None,
self,
selected_analysts=None,
):
"""Set up and compile the agent workflow graph.
@ -79,41 +80,48 @@ class GraphSetup:
if "market" in selected_analysts:
analyst_nodes["market"] = create_market_analyst(
self.quick_thinking_llm, self.toolkit,
self.quick_thinking_llm,
self.toolkit,
)
delete_nodes["market"] = create_msg_delete()
tool_nodes["market"] = self.tool_nodes["market"]
if "social" in selected_analysts:
analyst_nodes["social"] = create_social_media_analyst(
self.quick_thinking_llm, self.toolkit,
self.quick_thinking_llm,
self.toolkit,
)
delete_nodes["social"] = create_msg_delete()
tool_nodes["social"] = self.tool_nodes["social"]
if "news" in selected_analysts:
analyst_nodes["news"] = create_news_analyst(
self.quick_thinking_llm, self.toolkit,
self.quick_thinking_llm,
self.toolkit,
)
delete_nodes["news"] = create_msg_delete()
tool_nodes["news"] = self.tool_nodes["news"]
if "fundamentals" in selected_analysts:
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
self.quick_thinking_llm, self.toolkit,
self.quick_thinking_llm,
self.toolkit,
)
delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
# Create researcher and manager nodes
bull_researcher_node = create_bull_researcher(
self.quick_thinking_llm, self.bull_memory,
self.quick_thinking_llm,
self.bull_memory,
)
bear_researcher_node = create_bear_researcher(
self.quick_thinking_llm, self.bear_memory,
self.quick_thinking_llm,
self.bear_memory,
)
research_manager_node = create_research_manager(
self.deep_thinking_llm, self.invest_judge_memory,
self.deep_thinking_llm,
self.invest_judge_memory,
)
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
@ -122,7 +130,8 @@ class GraphSetup:
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
safe_analyst = create_safe_debator(self.quick_thinking_llm)
risk_manager_node = create_risk_manager(
self.deep_thinking_llm, self.risk_manager_memory,
self.deep_thinking_llm,
self.risk_manager_memory,
)
# Create workflow
@ -132,7 +141,8 @@ class GraphSetup:
for analyst_type, node in analyst_nodes.items():
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
workflow.add_node(
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type],
f"Msg Clear {analyst_type.capitalize()}",
delete_nodes[analyst_type],
)
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])

View File

@ -59,7 +59,8 @@ class TradingAgentsGraph:
or self.config["llm_provider"] == "openrouter"
):
self.deep_thinking_llm = ChatOpenAI(
model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
model=self.config["deep_think_llm"],
base_url=self.config["backend_url"],
)
self.quick_thinking_llm = ChatOpenAI(
model=self.config["quick_think_llm"],
@ -67,7 +68,8 @@ class TradingAgentsGraph:
)
elif self.config["llm_provider"].lower() == "anthropic":
self.deep_thinking_llm = ChatAnthropic(
model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
model=self.config["deep_think_llm"],
base_url=self.config["backend_url"],
)
self.quick_thinking_llm = ChatAnthropic(
model=self.config["quick_think_llm"],
@ -91,10 +93,12 @@ class TradingAgentsGraph:
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory(
"invest_judge_memory", self.config,
"invest_judge_memory",
self.config,
)
self.risk_manager_memory = FinancialSituationMemory(
"risk_manager_memory", self.config,
"risk_manager_memory",
self.config,
)
# Create tool nodes
@ -179,7 +183,8 @@ class TradingAgentsGraph:
# Initialize state
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date,
company_name,
trade_date,
)
args = self.propagator.get_graph_args()
@ -252,19 +257,29 @@ class TradingAgentsGraph:
def reflect_and_remember(self, returns_losses):
"""Reflect on decisions and update memory based on returns."""
self.reflector.reflect_bull_researcher(
self.curr_state, returns_losses, self.bull_memory,
self.curr_state,
returns_losses,
self.bull_memory,
)
self.reflector.reflect_bear_researcher(
self.curr_state, returns_losses, self.bear_memory,
self.curr_state,
returns_losses,
self.bear_memory,
)
self.reflector.reflect_trader(
self.curr_state, returns_losses, self.trader_memory,
self.curr_state,
returns_losses,
self.trader_memory,
)
self.reflector.reflect_invest_judge(
self.curr_state, returns_losses, self.invest_judge_memory,
self.curr_state,
returns_losses,
self.invest_judge_memory,
)
self.reflector.reflect_risk_manager(
self.curr_state, returns_losses, self.risk_manager_memory,
self.curr_state,
returns_losses,
self.risk_manager_memory,
)
def process_signal(self, full_signal):