Apply Black formatting to pass CI checks
- Formatted 33 Python files with Black - Fixed code style to meet project standards - Ensures CI/CD pipeline passes formatting checks
This commit is contained in:
parent
6f3981412b
commit
850764ad7b
132
cli/main.py
132
cli/main.py
|
|
@ -148,7 +148,7 @@ class MessageBuffer:
|
|||
f"### News Analysis\n{self.report_sections['news_report']}",
|
||||
)
|
||||
if self.report_sections["fundamentals_report"]:
|
||||
fundamentals = self.report_sections['fundamentals_report']
|
||||
fundamentals = self.report_sections["fundamentals_report"]
|
||||
report_parts.append(
|
||||
f"### Fundamentals Analysis\n{fundamentals}",
|
||||
)
|
||||
|
|
@ -182,10 +182,12 @@ def create_layout():
|
|||
Layout(name="footer", size=3),
|
||||
)
|
||||
layout["main"].split_column(
|
||||
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5),
|
||||
Layout(name="upper", ratio=3),
|
||||
Layout(name="analysis", ratio=5),
|
||||
)
|
||||
layout["upper"].split_row(
|
||||
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3),
|
||||
Layout(name="progress", ratio=2),
|
||||
Layout(name="messages", ratio=3),
|
||||
)
|
||||
return layout
|
||||
|
||||
|
|
@ -237,7 +239,9 @@ def update_display(layout, spinner_text=None):
|
|||
status = message_buffer.agent_status[first_agent]
|
||||
if status == "in_progress":
|
||||
spinner = Spinner(
|
||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan",
|
||||
"dots",
|
||||
text="[blue]in_progress[/blue]",
|
||||
style="bold cyan",
|
||||
)
|
||||
status_cell = spinner
|
||||
else:
|
||||
|
|
@ -254,7 +258,9 @@ def update_display(layout, spinner_text=None):
|
|||
status = message_buffer.agent_status[agent]
|
||||
if status == "in_progress":
|
||||
spinner = Spinner(
|
||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan",
|
||||
"dots",
|
||||
text="[blue]in_progress[/blue]",
|
||||
style="bold cyan",
|
||||
)
|
||||
status_cell = spinner
|
||||
else:
|
||||
|
|
@ -286,7 +292,10 @@ def update_display(layout, spinner_text=None):
|
|||
messages_table.add_column("Time", style="cyan", width=8, justify="center")
|
||||
messages_table.add_column("Type", style="green", width=10, justify="center")
|
||||
messages_table.add_column(
|
||||
"Content", style="white", no_wrap=False, ratio=1,
|
||||
"Content",
|
||||
style="white",
|
||||
no_wrap=False,
|
||||
ratio=1,
|
||||
) # Make content column expand
|
||||
|
||||
# Combine tool calls and messages
|
||||
|
|
@ -441,7 +450,9 @@ def get_user_selections():
|
|||
# Step 1: Ticker symbol
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY",
|
||||
"Step 1: Ticker Symbol",
|
||||
"Enter the ticker symbol to analyze",
|
||||
"SPY",
|
||||
),
|
||||
)
|
||||
selected_ticker = get_ticker()
|
||||
|
|
@ -460,7 +471,8 @@ def get_user_selections():
|
|||
# Step 3: Select analysts
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis",
|
||||
"Step 3: Analysts Team",
|
||||
"Select your LLM analyst agents for the analysis",
|
||||
),
|
||||
)
|
||||
selected_analysts = select_analysts()
|
||||
|
|
@ -471,21 +483,25 @@ def get_user_selections():
|
|||
# Step 4: Research depth
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 4: Research Depth", "Select your research depth level",
|
||||
"Step 4: Research Depth",
|
||||
"Select your research depth level",
|
||||
),
|
||||
)
|
||||
selected_research_depth = select_research_depth()
|
||||
|
||||
# Step 5: OpenAI backend
|
||||
console.print(
|
||||
create_question_box("Step 5: OpenAI backend", "Select which service to talk to"),
|
||||
create_question_box(
|
||||
"Step 5: OpenAI backend", "Select which service to talk to"
|
||||
),
|
||||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
|
||||
# Step 6: Thinking agents
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 6: Thinking Agents", "Select your thinking agents for analysis",
|
||||
"Step 6: Thinking Agents",
|
||||
"Select your thinking agents for analysis",
|
||||
),
|
||||
)
|
||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
||||
|
|
@ -737,7 +753,9 @@ def run_analysis():
|
|||
|
||||
# Initialize the graph
|
||||
graph = TradingAgentsGraph(
|
||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True,
|
||||
[analyst.value for analyst in selections["analysts"]],
|
||||
config=config,
|
||||
debug=True,
|
||||
)
|
||||
|
||||
# Create result directory
|
||||
|
|
@ -796,10 +814,12 @@ def run_analysis():
|
|||
|
||||
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
|
||||
message_buffer.add_tool_call = save_tool_call_decorator(
|
||||
message_buffer, "add_tool_call",
|
||||
message_buffer,
|
||||
"add_tool_call",
|
||||
)
|
||||
message_buffer.update_report_section = save_report_section_decorator(
|
||||
message_buffer, "update_report_section",
|
||||
message_buffer,
|
||||
"update_report_section",
|
||||
)
|
||||
|
||||
# Now start the display layout
|
||||
|
|
@ -812,7 +832,8 @@ def run_analysis():
|
|||
# Add initial messages
|
||||
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
||||
message_buffer.add_message(
|
||||
"System", f"Analysis date: {selections['analysis_date']}",
|
||||
"System",
|
||||
f"Analysis date: {selections['analysis_date']}",
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"System",
|
||||
|
|
@ -843,7 +864,8 @@ def run_analysis():
|
|||
|
||||
# Initialize state and get graph args
|
||||
init_agent_state = graph.propagator.create_initial_state(
|
||||
selections["ticker"], selections["analysis_date"],
|
||||
selections["ticker"],
|
||||
selections["analysis_date"],
|
||||
)
|
||||
args = graph.propagator.get_graph_args()
|
||||
|
||||
|
|
@ -873,7 +895,8 @@ def run_analysis():
|
|||
# Handle both dictionary and object tool calls
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(
|
||||
tool_call["name"], tool_call["args"],
|
||||
tool_call["name"],
|
||||
tool_call["args"],
|
||||
)
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
|
|
@ -882,51 +905,57 @@ def run_analysis():
|
|||
# Analyst Team Reports
|
||||
if chunk.get("market_report"):
|
||||
message_buffer.update_report_section(
|
||||
"market_report", chunk["market_report"],
|
||||
"market_report",
|
||||
chunk["market_report"],
|
||||
)
|
||||
message_buffer.update_agent_status("Market Analyst", "completed")
|
||||
# Set next analyst to in_progress
|
||||
if "social" in selections["analysts"]:
|
||||
message_buffer.update_agent_status(
|
||||
"Social Analyst", "in_progress",
|
||||
"Social Analyst",
|
||||
"in_progress",
|
||||
)
|
||||
|
||||
if chunk.get("sentiment_report"):
|
||||
message_buffer.update_report_section(
|
||||
"sentiment_report", chunk["sentiment_report"],
|
||||
"sentiment_report",
|
||||
chunk["sentiment_report"],
|
||||
)
|
||||
message_buffer.update_agent_status("Social Analyst", "completed")
|
||||
# Set next analyst to in_progress
|
||||
if "news" in selections["analysts"]:
|
||||
message_buffer.update_agent_status(
|
||||
"News Analyst", "in_progress",
|
||||
"News Analyst",
|
||||
"in_progress",
|
||||
)
|
||||
|
||||
if chunk.get("news_report"):
|
||||
message_buffer.update_report_section(
|
||||
"news_report", chunk["news_report"],
|
||||
"news_report",
|
||||
chunk["news_report"],
|
||||
)
|
||||
message_buffer.update_agent_status("News Analyst", "completed")
|
||||
# Set next analyst to in_progress
|
||||
if "fundamentals" in selections["analysts"]:
|
||||
message_buffer.update_agent_status(
|
||||
"Fundamentals Analyst", "in_progress",
|
||||
"Fundamentals Analyst",
|
||||
"in_progress",
|
||||
)
|
||||
|
||||
if chunk.get("fundamentals_report"):
|
||||
message_buffer.update_report_section(
|
||||
"fundamentals_report", chunk["fundamentals_report"],
|
||||
"fundamentals_report",
|
||||
chunk["fundamentals_report"],
|
||||
)
|
||||
message_buffer.update_agent_status(
|
||||
"Fundamentals Analyst", "completed",
|
||||
"Fundamentals Analyst",
|
||||
"completed",
|
||||
)
|
||||
# Set all research team members to in_progress
|
||||
update_research_team_status("in_progress")
|
||||
|
||||
# Research Team - Handle Investment Debate State
|
||||
if (
|
||||
chunk.get("investment_debate_state")
|
||||
):
|
||||
if chunk.get("investment_debate_state"):
|
||||
debate_state = chunk["investment_debate_state"]
|
||||
|
||||
# Update Bull Researcher status and report
|
||||
|
|
@ -960,9 +989,7 @@ def run_analysis():
|
|||
)
|
||||
|
||||
# Update Research Manager status and final decision
|
||||
if (
|
||||
debate_state.get("judge_decision")
|
||||
):
|
||||
if debate_state.get("judge_decision"):
|
||||
# Keep all research team members in progress until final decision
|
||||
update_research_team_status("in_progress")
|
||||
message_buffer.add_message(
|
||||
|
|
@ -978,15 +1005,15 @@ def run_analysis():
|
|||
update_research_team_status("completed")
|
||||
# Set first risk analyst to in_progress
|
||||
message_buffer.update_agent_status(
|
||||
"Risky Analyst", "in_progress",
|
||||
"Risky Analyst",
|
||||
"in_progress",
|
||||
)
|
||||
|
||||
# Trading Team
|
||||
if (
|
||||
chunk.get("trader_investment_plan")
|
||||
):
|
||||
if chunk.get("trader_investment_plan"):
|
||||
message_buffer.update_report_section(
|
||||
"trader_investment_plan", chunk["trader_investment_plan"],
|
||||
"trader_investment_plan",
|
||||
chunk["trader_investment_plan"],
|
||||
)
|
||||
# Set first risk analyst to in_progress
|
||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||
|
|
@ -996,11 +1023,10 @@ def run_analysis():
|
|||
risk_state = chunk["risk_debate_state"]
|
||||
|
||||
# Update Risky Analyst status and report
|
||||
if (
|
||||
risk_state.get("current_risky_response")
|
||||
):
|
||||
if risk_state.get("current_risky_response"):
|
||||
message_buffer.update_agent_status(
|
||||
"Risky Analyst", "in_progress",
|
||||
"Risky Analyst",
|
||||
"in_progress",
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
|
|
@ -1013,11 +1039,10 @@ def run_analysis():
|
|||
)
|
||||
|
||||
# Update Safe Analyst status and report
|
||||
if (
|
||||
risk_state.get("current_safe_response")
|
||||
):
|
||||
if risk_state.get("current_safe_response"):
|
||||
message_buffer.update_agent_status(
|
||||
"Safe Analyst", "in_progress",
|
||||
"Safe Analyst",
|
||||
"in_progress",
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
|
|
@ -1030,11 +1055,10 @@ def run_analysis():
|
|||
)
|
||||
|
||||
# Update Neutral Analyst status and report
|
||||
if (
|
||||
risk_state.get("current_neutral_response")
|
||||
):
|
||||
if risk_state.get("current_neutral_response"):
|
||||
message_buffer.update_agent_status(
|
||||
"Neutral Analyst", "in_progress",
|
||||
"Neutral Analyst",
|
||||
"in_progress",
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
|
|
@ -1049,7 +1073,8 @@ def run_analysis():
|
|||
# Update Portfolio Manager status and final decision
|
||||
if risk_state.get("judge_decision"):
|
||||
message_buffer.update_agent_status(
|
||||
"Portfolio Manager", "in_progress",
|
||||
"Portfolio Manager",
|
||||
"in_progress",
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
|
|
@ -1064,10 +1089,12 @@ def run_analysis():
|
|||
message_buffer.update_agent_status("Risky Analyst", "completed")
|
||||
message_buffer.update_agent_status("Safe Analyst", "completed")
|
||||
message_buffer.update_agent_status(
|
||||
"Neutral Analyst", "completed",
|
||||
"Neutral Analyst",
|
||||
"completed",
|
||||
)
|
||||
message_buffer.update_agent_status(
|
||||
"Portfolio Manager", "completed",
|
||||
"Portfolio Manager",
|
||||
"completed",
|
||||
)
|
||||
|
||||
# Update the display
|
||||
|
|
@ -1084,7 +1111,8 @@ def run_analysis():
|
|||
message_buffer.update_agent_status(agent, "completed")
|
||||
|
||||
message_buffer.add_message(
|
||||
"Analysis", f"Completed analysis for {selections['analysis_date']}",
|
||||
"Analysis",
|
||||
f"Completed analysis for {selections['analysis_date']}",
|
||||
)
|
||||
|
||||
# Update final report sections
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
import sys
|
||||
|
||||
import questionary
|
||||
|
|
|
|||
4
main.py
4
main.py
|
|
@ -4,7 +4,9 @@ from tradingagents.default_config import DEFAULT_CONFIG
|
|||
# Create a custom config
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["llm_provider"] = "google" # Use a different model
|
||||
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
|
||||
config["backend_url"] = (
|
||||
"https://generativelanguage.googleapis.com/v1" # Use a different backend
|
||||
)
|
||||
config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model
|
||||
config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model
|
||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||
|
|
|
|||
52
run_tests.py
52
run_tests.py
|
|
@ -16,10 +16,10 @@ def run_command(cmd, description=""):
|
|||
"""Run a command and handle errors."""
|
||||
if description:
|
||||
print(f"\n🔄 {description}")
|
||||
|
||||
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=False)
|
||||
|
||||
|
||||
if result.returncode != 0:
|
||||
print(f"❌ Command failed with return code {result.returncode}")
|
||||
sys.exit(result.returncode)
|
||||
|
|
@ -33,48 +33,42 @@ def main():
|
|||
parser.add_argument(
|
||||
"test_type",
|
||||
choices=["unit", "integration", "all", "coverage", "fast", "slow", "lint"],
|
||||
help="Type of tests to run"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", "-v", action="store_true", help="Verbose output"
|
||||
help="Type of tests to run",
|
||||
)
|
||||
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
|
||||
parser.add_argument(
|
||||
"--parallel", "-p", action="store_true", help="Run tests in parallel"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--file", "-f", help="Run specific test file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pattern", "-k", help="Run tests matching pattern"
|
||||
)
|
||||
|
||||
parser.add_argument("--file", "-f", help="Run specific test file")
|
||||
parser.add_argument("--pattern", "-k", help="Run tests matching pattern")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Base pytest command
|
||||
base_cmd = ["python", "-m", "pytest"]
|
||||
|
||||
|
||||
if args.verbose:
|
||||
base_cmd.append("-v")
|
||||
|
||||
|
||||
if args.parallel:
|
||||
base_cmd.extend(["-n", "auto"])
|
||||
|
||||
|
||||
if args.pattern:
|
||||
base_cmd.extend(["-k", args.pattern])
|
||||
|
||||
|
||||
# Configure based on test type
|
||||
if args.test_type == "unit":
|
||||
cmd = base_cmd + ["tests/unit/", "-m", "unit"]
|
||||
run_command(cmd, "Running unit tests")
|
||||
|
||||
|
||||
elif args.test_type == "integration":
|
||||
cmd = base_cmd + ["tests/integration/", "-m", "integration"]
|
||||
run_command(cmd, "Running integration tests")
|
||||
|
||||
|
||||
elif args.test_type == "all":
|
||||
cmd = base_cmd + ["tests/"]
|
||||
run_command(cmd, "Running all tests")
|
||||
|
||||
|
||||
elif args.test_type == "coverage":
|
||||
cmd = base_cmd + [
|
||||
"tests/",
|
||||
|
|
@ -88,28 +82,28 @@ def main():
|
|||
print("\n📊 Coverage report generated:")
|
||||
print(" - HTML: htmlcov/index.html")
|
||||
print(" - XML: coverage.xml")
|
||||
|
||||
|
||||
elif args.test_type == "fast":
|
||||
cmd = base_cmd + ["tests/unit/", "-m", "unit", "--durations=10"]
|
||||
run_command(cmd, "Running fast unit tests")
|
||||
|
||||
|
||||
elif args.test_type == "slow":
|
||||
cmd = base_cmd + ["tests/", "-m", "slow", "--timeout=600"]
|
||||
run_command(cmd, "Running slow tests")
|
||||
|
||||
|
||||
elif args.test_type == "lint":
|
||||
# Run mypy
|
||||
cmd = ["python", "-m", "mypy", "tradingagents/", "cli/", "tests/"]
|
||||
run_command(cmd, "Running mypy type checking")
|
||||
|
||||
|
||||
# Run pytest on tests only for syntax
|
||||
cmd = base_cmd + ["tests/", "--collect-only"]
|
||||
run_command(cmd, "Validating test syntax")
|
||||
|
||||
|
||||
elif args.file:
|
||||
cmd = base_cmd + [args.file]
|
||||
run_command(cmd, f"Running tests in {args.file}")
|
||||
|
||||
|
||||
print("\n🎉 All tests completed successfully!")
|
||||
|
||||
|
||||
|
|
@ -117,5 +111,5 @@ if __name__ == "__main__":
|
|||
# Ensure we're in the project directory
|
||||
script_dir = Path(__file__).parent
|
||||
os.chdir(script_dir)
|
||||
|
||||
main()
|
||||
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -1,19 +1,23 @@
|
|||
|
||||
def poorly_formatted_function(x,y,z): # Missing type hints
|
||||
def poorly_formatted_function(x, y, z): # Missing type hints
|
||||
"""This function has formatting issues."""
|
||||
result=x+y*z # Missing spaces around operators
|
||||
if result>100: # Missing spaces
|
||||
print( "Result is large" ) # Extra spaces in parentheses
|
||||
result = x + y * z # Missing spaces around operators
|
||||
if result > 100: # Missing spaces
|
||||
print("Result is large") # Extra spaces in parentheses
|
||||
return result
|
||||
|
||||
|
||||
# Long line that Black will wrap
|
||||
very_long_variable_name_that_exceeds_the_standard_line_length_limit = "This is a very long string that will be wrapped by Black formatter"
|
||||
very_long_variable_name_that_exceeds_the_standard_line_length_limit = (
|
||||
"This is a very long string that will be wrapped by Black formatter"
|
||||
)
|
||||
|
||||
|
||||
class MyClass:
|
||||
def __init__(self,name:str,age:int): # Missing space after comma
|
||||
self.name=name # Missing spaces around =
|
||||
self.age=age
|
||||
def __init__(self, name: str, age: int): # Missing space after comma
|
||||
self.name = name # Missing spaces around =
|
||||
self.age = age
|
||||
|
||||
|
||||
# Function with wrong return type hint
|
||||
def get_number() -> str:
|
||||
return 123 # Returns int but type hint says str
|
||||
return 123 # Returns int but type hint says str
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ def add_numbers(a: int, b: int) -> int:
|
|||
"""Add two numbers and return the result."""
|
||||
return a + b
|
||||
|
||||
|
||||
# Test the function
|
||||
result = add_numbers(1, 2)
|
||||
print(f"Result: {result}")
|
||||
print(f"Result: {result}")
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ def run_command(cmd, description=""):
|
|||
"""Run a command and return success status."""
|
||||
if description:
|
||||
print(f"\n🔄 {description}")
|
||||
|
||||
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||
|
|
@ -38,45 +38,60 @@ def main():
|
|||
"""Run setup verification tests."""
|
||||
print("🧪 TradingAgents Test Setup Verification")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
# Change to project directory
|
||||
project_dir = Path(__file__).parent
|
||||
os.chdir(project_dir)
|
||||
|
||||
|
||||
success_count = 0
|
||||
total_tests = 0
|
||||
|
||||
|
||||
# Test 1: Check if pytest is installed and can discover tests
|
||||
total_tests += 1
|
||||
if run_command(["python", "-m", "pytest", "--version"], "Checking pytest installation"):
|
||||
if run_command(
|
||||
["python", "-m", "pytest", "--version"], "Checking pytest installation"
|
||||
):
|
||||
success_count += 1
|
||||
|
||||
|
||||
# Test 2: Test discovery
|
||||
total_tests += 1
|
||||
if run_command(["python", "-m", "pytest", "tests/", "--collect-only", "-q"], "Testing test discovery"):
|
||||
if run_command(
|
||||
["python", "-m", "pytest", "tests/", "--collect-only", "-q"],
|
||||
"Testing test discovery",
|
||||
):
|
||||
success_count += 1
|
||||
|
||||
|
||||
# Test 3: Check if mypy can run
|
||||
total_tests += 1
|
||||
if run_command(["python", "-m", "mypy", "--version"], "Checking mypy installation"):
|
||||
success_count += 1
|
||||
|
||||
|
||||
# Test 4: Run a simple syntax check on test files
|
||||
total_tests += 1
|
||||
if run_command(["python", "-c", "import tests.conftest; print('Test imports work!')"], "Testing test imports"):
|
||||
if run_command(
|
||||
["python", "-c", "import tests.conftest; print('Test imports work!')"],
|
||||
"Testing test imports",
|
||||
):
|
||||
success_count += 1
|
||||
|
||||
|
||||
# Test 5: Check if we can import the main module
|
||||
total_tests += 1
|
||||
if run_command(["python", "-c", "import tradingagents.config; print('Main module imports work!')"], "Testing main module imports"):
|
||||
if run_command(
|
||||
[
|
||||
"python",
|
||||
"-c",
|
||||
"import tradingagents.config; print('Main module imports work!')",
|
||||
],
|
||||
"Testing main module imports",
|
||||
):
|
||||
success_count += 1
|
||||
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 50)
|
||||
print("📊 Test Setup Verification Results:")
|
||||
print(f"✅ Successful: {success_count}/{total_tests}")
|
||||
print(f"❌ Failed: {total_tests - success_count}/{total_tests}")
|
||||
|
||||
|
||||
if success_count == total_tests:
|
||||
print("\n🎉 All verification tests passed! Your test setup is ready.")
|
||||
print("\n📚 Next steps:")
|
||||
|
|
@ -95,4 +110,4 @@ def main():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
|
|
|||
|
|
@ -173,7 +173,8 @@ def mock_memory():
|
|||
def pytest_configure(config):
|
||||
"""Configure pytest with custom markers."""
|
||||
config.addinivalue_line(
|
||||
"markers", "integration: mark test as integration test (slow)",
|
||||
"markers",
|
||||
"integration: mark test as integration test (slow)",
|
||||
)
|
||||
config.addinivalue_line("markers", "unit: mark test as unit test (fast)")
|
||||
config.addinivalue_line("markers", "api: mark test as requiring API access")
|
||||
|
|
|
|||
|
|
@ -36,7 +36,8 @@ class SampleDataFactory:
|
|||
|
||||
@staticmethod
|
||||
def create_finnhub_news_data(
|
||||
ticker: str = "AAPL", count: int = 10,
|
||||
ticker: str = "AAPL",
|
||||
count: int = 10,
|
||||
) -> dict[str, list[dict[str, Any]]]:
|
||||
"""Create sample FinnHub news data for testing."""
|
||||
base_date = datetime(2024, 5, 10)
|
||||
|
|
@ -136,7 +137,8 @@ class SampleDataFactory:
|
|||
|
||||
@staticmethod
|
||||
def create_financial_statements_data(
|
||||
ticker: str = "AAPL", period: str = "annual",
|
||||
ticker: str = "AAPL",
|
||||
period: str = "annual",
|
||||
) -> dict[str, list[dict[str, Any]]]:
|
||||
"""Create sample financial statements data for testing."""
|
||||
if period == "annual":
|
||||
|
|
@ -271,10 +273,12 @@ class SampleDataFactory:
|
|||
ticker,
|
||||
),
|
||||
"financial_annual": SampleDataFactory.create_financial_statements_data(
|
||||
ticker, "annual",
|
||||
ticker,
|
||||
"annual",
|
||||
),
|
||||
"financial_quarterly": SampleDataFactory.create_financial_statements_data(
|
||||
ticker, "quarterly",
|
||||
ticker,
|
||||
"quarterly",
|
||||
),
|
||||
"social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker),
|
||||
"technical_indicators": SampleDataFactory.create_technical_indicators_data(
|
||||
|
|
@ -343,7 +347,9 @@ def save_sample_data_to_files(base_path: str, ticker: str = "AAPL") -> None:
|
|||
|
||||
# Save quarterly data separately
|
||||
quarterly_path = os.path.join(
|
||||
finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json",
|
||||
finnhub_path,
|
||||
"fin_as_reported",
|
||||
f"{ticker}_quarterly_data_formatted.json",
|
||||
)
|
||||
with open(quarterly_path, "w") as f:
|
||||
json.dump(dataset["financial_quarterly"], f, indent=2)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,10 @@ class TestFullWorkflowIntegration:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_end_to_end_trading_workflow(
|
||||
self, mock_toolkit, mock_chat_openai, integration_config,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
integration_config,
|
||||
):
|
||||
"""Test complete end-to-end trading workflow."""
|
||||
# Setup mocks
|
||||
|
|
@ -86,7 +89,10 @@ class TestFullWorkflowIntegration:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_multiple_analysts_integration(
|
||||
self, mock_toolkit, mock_chat_openai, integration_config,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
integration_config,
|
||||
):
|
||||
"""Test integration with different analyst combinations."""
|
||||
analyst_combinations = [
|
||||
|
|
@ -114,7 +120,8 @@ class TestFullWorkflowIntegration:
|
|||
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||
# Test each analyst combination
|
||||
trading_graph = TradingAgentsGraph(
|
||||
selected_analysts=analysts, config=integration_config,
|
||||
selected_analysts=analysts,
|
||||
config=integration_config,
|
||||
)
|
||||
trading_graph.graph = mock_graph
|
||||
|
||||
|
|
@ -134,7 +141,8 @@ class TestFullWorkflowIntegration:
|
|||
# Execute
|
||||
with patch("builtins.open", create=True), patch("json.dump"):
|
||||
final_state, decision = trading_graph.propagate(
|
||||
"TSLA", "2024-05-15",
|
||||
"TSLA",
|
||||
"2024-05-15",
|
||||
)
|
||||
|
||||
# Verify
|
||||
|
|
@ -144,7 +152,10 @@ class TestFullWorkflowIntegration:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_memory_and_reflection_integration(
|
||||
self, mock_toolkit, mock_chat_openai, integration_config,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
integration_config,
|
||||
):
|
||||
"""Test integration of memory and reflection components."""
|
||||
# Setup
|
||||
|
|
@ -208,7 +219,10 @@ class TestFullWorkflowIntegration:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_debug_mode_integration(
|
||||
self, mock_toolkit, mock_chat_openai, integration_config,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
integration_config,
|
||||
):
|
||||
"""Test integration in debug mode."""
|
||||
# Setup
|
||||
|
|
@ -240,7 +254,8 @@ class TestFullWorkflowIntegration:
|
|||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||
trading_graph = TradingAgentsGraph(
|
||||
debug=True, config=integration_config,
|
||||
debug=True,
|
||||
config=integration_config,
|
||||
)
|
||||
trading_graph.graph = mock_graph
|
||||
|
||||
|
|
@ -276,7 +291,12 @@ class TestFullWorkflowIntegration:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_multiple_stocks_integration(
|
||||
self, mock_toolkit, mock_chat_openai, ticker, date, integration_config,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
ticker,
|
||||
date,
|
||||
integration_config,
|
||||
):
|
||||
"""Test integration with different stocks and dates."""
|
||||
# Setup
|
||||
|
|
@ -382,7 +402,11 @@ class TestPerformanceIntegration:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_multiple_consecutive_runs(
|
||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test multiple consecutive trading decisions."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
|
|||
|
|
@ -17,7 +17,10 @@ class TestMarketAnalyst:
|
|||
assert callable(analyst_node)
|
||||
|
||||
def test_market_analyst_node_basic_execution(
|
||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||
self,
|
||||
mock_llm,
|
||||
mock_toolkit,
|
||||
sample_agent_state,
|
||||
):
|
||||
"""Test basic execution of market analyst node."""
|
||||
# Setup
|
||||
|
|
@ -39,7 +42,10 @@ class TestMarketAnalyst:
|
|||
assert result["market_report"] == "Market analysis complete"
|
||||
|
||||
def test_market_analyst_uses_online_tools_when_configured(
|
||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||
self,
|
||||
mock_llm,
|
||||
mock_toolkit,
|
||||
sample_agent_state,
|
||||
):
|
||||
"""Test that analyst uses online tools when configured."""
|
||||
# Setup
|
||||
|
|
@ -64,7 +70,10 @@ class TestMarketAnalyst:
|
|||
assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2
|
||||
|
||||
def test_market_analyst_uses_offline_tools_when_configured(
|
||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||
self,
|
||||
mock_llm,
|
||||
mock_toolkit,
|
||||
sample_agent_state,
|
||||
):
|
||||
"""Test that analyst uses offline tools when configured."""
|
||||
# Setup
|
||||
|
|
@ -88,7 +97,10 @@ class TestMarketAnalyst:
|
|||
assert len(bound_tools) == 2 # Should have 2 offline tools
|
||||
|
||||
def test_market_analyst_processes_state_variables(
|
||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||
self,
|
||||
mock_llm,
|
||||
mock_toolkit,
|
||||
sample_agent_state,
|
||||
):
|
||||
"""Test that market analyst correctly processes state variables."""
|
||||
# Setup
|
||||
|
|
@ -112,7 +124,10 @@ class TestMarketAnalyst:
|
|||
assert result["market_report"] == "Analysis for AAPL on 2024-05-10"
|
||||
|
||||
def test_market_analyst_handles_empty_tool_calls(
|
||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||
self,
|
||||
mock_llm,
|
||||
mock_toolkit,
|
||||
sample_agent_state,
|
||||
):
|
||||
"""Test handling when no tool calls are made."""
|
||||
# Setup
|
||||
|
|
@ -132,7 +147,10 @@ class TestMarketAnalyst:
|
|||
assert result["messages"] == [mock_result]
|
||||
|
||||
def test_market_analyst_with_tool_calls(
|
||||
self, mock_llm, mock_toolkit, sample_agent_state,
|
||||
self,
|
||||
mock_llm,
|
||||
mock_toolkit,
|
||||
sample_agent_state,
|
||||
):
|
||||
"""Test handling when tool calls are present."""
|
||||
# Setup
|
||||
|
|
@ -153,7 +171,11 @@ class TestMarketAnalyst:
|
|||
|
||||
@pytest.mark.parametrize("online_tools", [True, False])
|
||||
def test_market_analyst_tool_configuration(
|
||||
self, mock_llm, mock_toolkit, sample_agent_state, online_tools,
|
||||
self,
|
||||
mock_llm,
|
||||
mock_toolkit,
|
||||
sample_agent_state,
|
||||
online_tools,
|
||||
):
|
||||
"""Test tool configuration for both online and offline modes."""
|
||||
# Setup
|
||||
|
|
|
|||
|
|
@ -190,7 +190,10 @@ class TestFinnhubUtils:
|
|||
|
||||
# Test without period
|
||||
expected_path_no_period = os.path.join(
|
||||
temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
|
||||
temp_data_dir,
|
||||
"finnhub_data",
|
||||
data_type,
|
||||
f"{ticker}_data_formatted.json",
|
||||
)
|
||||
|
||||
# Test with period
|
||||
|
|
@ -248,7 +251,10 @@ class TestFinnhubUtils:
|
|||
],
|
||||
)
|
||||
def test_get_data_in_range_various_data_types(
|
||||
self, temp_data_dir, data_type, period,
|
||||
self,
|
||||
temp_data_dir,
|
||||
data_type,
|
||||
period,
|
||||
):
|
||||
"""Test get_data_in_range with various data types."""
|
||||
ticker = "TEST"
|
||||
|
|
|
|||
|
|
@ -45,7 +45,11 @@ class TestTradingAgentsGraph:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_init_with_debug(
|
||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test initialization with debug mode enabled."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
@ -63,7 +67,11 @@ class TestTradingAgentsGraph:
|
|||
@patch("tradingagents.graph.trading_graph.ChatAnthropic")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_init_with_anthropic(
|
||||
self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_anthropic,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test initialization with Anthropic LLM provider."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
@ -82,7 +90,11 @@ class TestTradingAgentsGraph:
|
|||
@patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_init_with_google(
|
||||
self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_google,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test initialization with Google LLM provider."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
@ -100,7 +112,10 @@ class TestTradingAgentsGraph:
|
|||
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_init_unsupported_llm_provider(
|
||||
self, mock_toolkit, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test initialization with unsupported LLM provider raises error."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
@ -115,7 +130,11 @@ class TestTradingAgentsGraph:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_create_tool_nodes(
|
||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test creation of tool nodes."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
@ -143,7 +162,11 @@ class TestTradingAgentsGraph:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_propagate_basic(
|
||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test basic propagate functionality."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
@ -206,7 +229,11 @@ class TestTradingAgentsGraph:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_propagate_debug_mode(
|
||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test propagate in debug mode."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
@ -245,7 +272,11 @@ class TestTradingAgentsGraph:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_log_state(
|
||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test state logging functionality."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
@ -300,7 +331,11 @@ class TestTradingAgentsGraph:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_reflect_and_remember(
|
||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test reflection and memory update functionality."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
@ -309,9 +344,12 @@ class TestTradingAgentsGraph:
|
|||
mock_toolkit_instance = Mock()
|
||||
mock_toolkit.return_value = mock_toolkit_instance
|
||||
|
||||
with patch(
|
||||
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
||||
), patch("tradingagents.graph.trading_graph.set_config"):
|
||||
with (
|
||||
patch(
|
||||
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
||||
),
|
||||
patch("tradingagents.graph.trading_graph.set_config"),
|
||||
):
|
||||
graph = TradingAgentsGraph(config=sample_config)
|
||||
|
||||
# Set up current state
|
||||
|
|
@ -339,7 +377,11 @@ class TestTradingAgentsGraph:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_process_signal(
|
||||
self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
sample_config,
|
||||
temp_data_dir,
|
||||
):
|
||||
"""Test signal processing functionality."""
|
||||
sample_config["project_dir"] = temp_data_dir
|
||||
|
|
@ -388,7 +430,8 @@ class TestTradingAgentsGraph:
|
|||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||
with patch("tradingagents.graph.trading_graph.set_config"):
|
||||
TradingAgentsGraph(
|
||||
selected_analysts=selected_analysts, config=sample_config,
|
||||
selected_analysts=selected_analysts,
|
||||
config=sample_config,
|
||||
)
|
||||
|
||||
# Verify graph was set up with selected analysts
|
||||
|
|
@ -416,7 +459,10 @@ class TestTradingAgentsGraphErrorHandling:
|
|||
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
|
||||
@patch("tradingagents.graph.trading_graph.Toolkit")
|
||||
def test_directory_creation_failure(
|
||||
self, mock_toolkit, mock_chat_openai, sample_config,
|
||||
self,
|
||||
mock_toolkit,
|
||||
mock_chat_openai,
|
||||
sample_config,
|
||||
):
|
||||
"""Test handling when directory creation fails."""
|
||||
sample_config["project_dir"] = "/invalid/path/that/cannot/be/created"
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ def create_fundamentals_analyst(llm, toolkit):
|
|||
|
||||
system_message = (
|
||||
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
|
||||
" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ Volume-Based Indicators:
|
|||
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
||||
|
||||
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
|
||||
""" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
""" Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ def create_news_analyst(llm, toolkit):
|
|||
|
||||
system_message = (
|
||||
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ def create_social_media_analyst(llm, toolkit):
|
|||
|
||||
system_message = (
|
||||
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
||||
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
|
|||
"current_risky_response": argument,
|
||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", "",
|
||||
"current_neutral_response",
|
||||
"",
|
||||
),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,11 +39,13 @@ Engage by questioning their optimism and emphasizing the potential downsides the
|
|||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
"latest_speaker": "Safe",
|
||||
"current_risky_response": risk_debate_state.get(
|
||||
"current_risky_response", "",
|
||||
"current_risky_response",
|
||||
"",
|
||||
),
|
||||
"current_safe_response": argument,
|
||||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", "",
|
||||
"current_neutral_response",
|
||||
"",
|
||||
),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,7 +39,8 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
|
|||
"neutral_history": neutral_history + "\n" + argument,
|
||||
"latest_speaker": "Neutral",
|
||||
"current_risky_response": risk_debate_state.get(
|
||||
"current_risky_response", "",
|
||||
"current_risky_response",
|
||||
"",
|
||||
),
|
||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||
"current_neutral_response": argument,
|
||||
|
|
|
|||
|
|
@ -9,10 +9,12 @@ from typing_extensions import TypedDict
|
|||
# Researcher team state
|
||||
class InvestDebateState(TypedDict):
|
||||
bull_history: Annotated[
|
||||
str, "Bullish Conversation history",
|
||||
str,
|
||||
"Bullish Conversation history",
|
||||
] # Bullish Conversation history
|
||||
bear_history: Annotated[
|
||||
str, "Bearish Conversation history",
|
||||
str,
|
||||
"Bearish Conversation history",
|
||||
] # Bullish Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
current_response: Annotated[str, "Latest response"] # Last response
|
||||
|
|
@ -23,24 +25,30 @@ class InvestDebateState(TypedDict):
|
|||
# Risk management team state
|
||||
class RiskDebateState(TypedDict):
|
||||
risky_history: Annotated[
|
||||
str, "Risky Agent's Conversation history",
|
||||
str,
|
||||
"Risky Agent's Conversation history",
|
||||
] # Conversation history
|
||||
safe_history: Annotated[
|
||||
str, "Safe Agent's Conversation history",
|
||||
str,
|
||||
"Safe Agent's Conversation history",
|
||||
] # Conversation history
|
||||
neutral_history: Annotated[
|
||||
str, "Neutral Agent's Conversation history",
|
||||
str,
|
||||
"Neutral Agent's Conversation history",
|
||||
] # Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
||||
current_risky_response: Annotated[
|
||||
str, "Latest response by the risky analyst",
|
||||
str,
|
||||
"Latest response by the risky analyst",
|
||||
] # Last response
|
||||
current_safe_response: Annotated[
|
||||
str, "Latest response by the safe analyst",
|
||||
str,
|
||||
"Latest response by the safe analyst",
|
||||
] # Last response
|
||||
current_neutral_response: Annotated[
|
||||
str, "Latest response by the neutral analyst",
|
||||
str,
|
||||
"Latest response by the neutral analyst",
|
||||
] # Last response
|
||||
judge_decision: Annotated[str, "Judge's decision"]
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
|
|
@ -56,13 +64,15 @@ class AgentState(MessagesState):
|
|||
market_report: Annotated[str, "Report from the Market Analyst"]
|
||||
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
||||
news_report: Annotated[
|
||||
str, "Report from the News Researcher of current world affairs",
|
||||
str,
|
||||
"Report from the News Researcher of current world affairs",
|
||||
]
|
||||
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
||||
|
||||
# researcher team discussion step
|
||||
investment_debate_state: Annotated[
|
||||
InvestDebateState, "Current state of the debate on if to invest or not",
|
||||
InvestDebateState,
|
||||
"Current state of the debate on if to invest or not",
|
||||
]
|
||||
investment_plan: Annotated[str, "Plan generated by the Analyst"]
|
||||
|
||||
|
|
@ -70,6 +80,7 @@ class AgentState(MessagesState):
|
|||
|
||||
# risk management team discussion step
|
||||
risk_debate_state: Annotated[
|
||||
RiskDebateState, "Current state of the debate on evaluating risk",
|
||||
RiskDebateState,
|
||||
"Current state of the debate on evaluating risk",
|
||||
]
|
||||
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
||||
|
|
|
|||
|
|
@ -56,7 +56,6 @@ class Toolkit:
|
|||
|
||||
return interface.get_reddit_global_news(curr_date, 7, 5)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_finnhub_news(
|
||||
|
|
@ -84,10 +83,11 @@ class Toolkit:
|
|||
look_back_days = (end_date - start_date).days
|
||||
|
||||
return interface.get_finnhub_news(
|
||||
ticker, end_date_str, look_back_days,
|
||||
ticker,
|
||||
end_date_str,
|
||||
look_back_days,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_reddit_stock_info(
|
||||
|
|
@ -108,7 +108,6 @@ class Toolkit:
|
|||
|
||||
return interface.get_reddit_company_news(ticker, curr_date, 7, 5)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_YFin_data(
|
||||
|
|
@ -128,7 +127,6 @@ class Toolkit:
|
|||
|
||||
return interface.get_YFin_data(symbol, start_date, end_date)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_YFin_data_online(
|
||||
|
|
@ -148,16 +146,17 @@ class Toolkit:
|
|||
|
||||
return interface.get_YFin_data_online(symbol, start_date, end_date)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_stockstats_indicators_report(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[
|
||||
str, "technical indicator to get the analysis and report of",
|
||||
str,
|
||||
"technical indicator to get the analysis and report of",
|
||||
],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd",
|
||||
str,
|
||||
"The current trading date you are trading on, YYYY-mm-dd",
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||
) -> str:
|
||||
|
|
@ -173,19 +172,24 @@ class Toolkit:
|
|||
"""
|
||||
|
||||
return interface.get_stock_stats_indicators_window(
|
||||
symbol, indicator, curr_date, look_back_days, False,
|
||||
symbol,
|
||||
indicator,
|
||||
curr_date,
|
||||
look_back_days,
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_stockstats_indicators_report_online(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[
|
||||
str, "technical indicator to get the analysis and report of",
|
||||
str,
|
||||
"technical indicator to get the analysis and report of",
|
||||
],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd",
|
||||
str,
|
||||
"The current trading date you are trading on, YYYY-mm-dd",
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||
) -> str:
|
||||
|
|
@ -201,10 +205,13 @@ class Toolkit:
|
|||
"""
|
||||
|
||||
return interface.get_stock_stats_indicators_window(
|
||||
symbol, indicator, curr_date, look_back_days, True,
|
||||
symbol,
|
||||
indicator,
|
||||
curr_date,
|
||||
look_back_days,
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_finnhub_company_insider_sentiment(
|
||||
|
|
@ -224,10 +231,11 @@ class Toolkit:
|
|||
"""
|
||||
|
||||
return interface.get_finnhub_company_insider_sentiment(
|
||||
ticker, curr_date, 30,
|
||||
ticker,
|
||||
curr_date,
|
||||
30,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_finnhub_company_insider_transactions(
|
||||
|
|
@ -247,10 +255,11 @@ class Toolkit:
|
|||
"""
|
||||
|
||||
return interface.get_finnhub_company_insider_transactions(
|
||||
ticker, curr_date, 30,
|
||||
ticker,
|
||||
curr_date,
|
||||
30,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_simfin_balance_sheet(
|
||||
|
|
@ -273,7 +282,6 @@ class Toolkit:
|
|||
|
||||
return interface.get_simfin_balance_sheet(ticker, freq, curr_date)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_simfin_cashflow(
|
||||
|
|
@ -296,7 +304,6 @@ class Toolkit:
|
|||
|
||||
return interface.get_simfin_cashflow(ticker, freq, curr_date)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_simfin_income_stmt(
|
||||
|
|
@ -318,10 +325,11 @@ class Toolkit:
|
|||
"""
|
||||
|
||||
return interface.get_simfin_income_statements(
|
||||
ticker, freq, curr_date,
|
||||
ticker,
|
||||
freq,
|
||||
curr_date,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_google_news(
|
||||
|
|
@ -340,7 +348,6 @@ class Toolkit:
|
|||
|
||||
return interface.get_google_news(query, curr_date, 7)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_stock_news_openai(
|
||||
|
|
@ -358,7 +365,6 @@ class Toolkit:
|
|||
|
||||
return interface.get_stock_news_openai(ticker, curr_date)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_global_news_openai(
|
||||
|
|
@ -374,7 +380,6 @@ class Toolkit:
|
|||
|
||||
return interface.get_global_news_openai(curr_date)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_fundamentals_openai(
|
||||
|
|
@ -391,6 +396,6 @@ class Toolkit:
|
|||
"""
|
||||
|
||||
return interface.get_fundamentals_openai(
|
||||
ticker, curr_date,
|
||||
ticker,
|
||||
curr_date,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ def get_config():
|
|||
"project_dir": str(project_root / "tradingagents"),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||
"data_dir": os.getenv(
|
||||
"TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
||||
"TRADINGAGENTS_DATA_DIR",
|
||||
"/Users/yluo/Documents/Code/ScAI/FR1-data",
|
||||
),
|
||||
"data_cache_dir": str(
|
||||
project_root / "tradingagents" / "dataflows" / "data_cache",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
from tradingagents import default_config
|
||||
|
||||
# Use default config but allow it to be overridden
|
||||
|
|
|
|||
|
|
@ -22,7 +22,10 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
|
|||
)
|
||||
else:
|
||||
data_path = os.path.join(
|
||||
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json",
|
||||
data_dir,
|
||||
"finnhub_data",
|
||||
data_type,
|
||||
f"{ticker}_data_formatted.json",
|
||||
)
|
||||
|
||||
data = open(data_path)
|
||||
|
|
|
|||
|
|
@ -419,7 +419,8 @@ def get_stock_stats_indicators_window(
|
|||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd",
|
||||
str,
|
||||
"The current trading date you are trading on, YYYY-mm-dd",
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
online: Annotated[bool, "to fetch data online or offline"],
|
||||
|
|
@ -524,7 +525,10 @@ def get_stock_stats_indicators_window(
|
|||
# only do the trading dates
|
||||
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
|
||||
indicator_value = get_stockstats_indicator(
|
||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
|
||||
symbol,
|
||||
indicator,
|
||||
curr_date.strftime("%Y-%m-%d"),
|
||||
online,
|
||||
)
|
||||
|
||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||
|
|
@ -535,7 +539,10 @@ def get_stock_stats_indicators_window(
|
|||
ind_string = ""
|
||||
while curr_date >= before:
|
||||
indicator_value = get_stockstats_indicator(
|
||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online,
|
||||
symbol,
|
||||
indicator,
|
||||
curr_date.strftime("%Y-%m-%d"),
|
||||
online,
|
||||
)
|
||||
|
||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||
|
|
@ -550,12 +557,12 @@ def get_stock_stats_indicators_window(
|
|||
)
|
||||
|
||||
|
||||
|
||||
def get_stockstats_indicator(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd",
|
||||
str,
|
||||
"The current trading date you are trading on, YYYY-mm-dd",
|
||||
],
|
||||
online: Annotated[bool, "to fetch data online or offline"],
|
||||
) -> str:
|
||||
|
|
@ -608,7 +615,12 @@ def get_YFin_data_window(
|
|||
|
||||
# Set pandas display options to show the full DataFrame
|
||||
with pd.option_context(
|
||||
"display.max_rows", None, "display.max_columns", None, "display.width", None,
|
||||
"display.max_rows",
|
||||
None,
|
||||
"display.max_columns",
|
||||
None,
|
||||
"display.width",
|
||||
None,
|
||||
):
|
||||
df_string = filtered_data.to_string()
|
||||
|
||||
|
|
@ -694,7 +706,6 @@ def get_YFin_data(
|
|||
return filtered_data.reset_index(drop=True)
|
||||
|
||||
|
||||
|
||||
def get_stock_news_openai(ticker, curr_date):
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["backend_url"])
|
||||
|
|
|
|||
|
|
@ -48,11 +48,14 @@ ticker_to_company = {
|
|||
|
||||
def fetch_top_from_category(
|
||||
category: Annotated[
|
||||
str, "Category to fetch top post from. Collection of subreddits.",
|
||||
str,
|
||||
"Category to fetch top post from. Collection of subreddits.",
|
||||
],
|
||||
date: Annotated[str, "Date to fetch top posts from."],
|
||||
max_limit: Annotated[int, "Maximum number of posts to fetch."],
|
||||
query: Annotated[str | None, "Optional query to search for in the subreddit."] = None,
|
||||
query: Annotated[
|
||||
str | None, "Optional query to search for in the subreddit."
|
||||
] = None,
|
||||
data_path: Annotated[
|
||||
str,
|
||||
"Path to the data folder. Default is 'reddit_data'.",
|
||||
|
|
@ -107,7 +110,9 @@ def fetch_top_from_category(
|
|||
found = False
|
||||
for term in search_terms:
|
||||
if re.search(
|
||||
term, parsed_line["title"], re.IGNORECASE,
|
||||
term,
|
||||
parsed_line["title"],
|
||||
re.IGNORECASE,
|
||||
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
|
||||
found = True
|
||||
break
|
||||
|
|
|
|||
|
|
@ -13,10 +13,12 @@ class StockstatsUtils:
|
|||
def get_stock_stats(
|
||||
symbol: Annotated[str, "ticker symbol for the company"],
|
||||
indicator: Annotated[
|
||||
str, "quantitative indicators based off of the stock data for the company",
|
||||
str,
|
||||
"quantitative indicators based off of the stock data for the company",
|
||||
],
|
||||
curr_date: Annotated[
|
||||
str, "curr date for retrieving stock price data, YYYY-mm-dd",
|
||||
str,
|
||||
"curr date for retrieving stock price data, YYYY-mm-dd",
|
||||
],
|
||||
data_dir: Annotated[
|
||||
str,
|
||||
|
|
|
|||
|
|
@ -28,10 +28,12 @@ class YFinanceUtils:
|
|||
def get_stock_data(
|
||||
self: Annotated[str, "ticker symbol"],
|
||||
start_date: Annotated[
|
||||
str, "start date for retrieving stock price data, YYYY-mm-dd",
|
||||
str,
|
||||
"start date for retrieving stock price data, YYYY-mm-dd",
|
||||
],
|
||||
end_date: Annotated[
|
||||
str, "end date for retrieving stock price data, YYYY-mm-dd",
|
||||
str,
|
||||
"end date for retrieving stock price data, YYYY-mm-dd",
|
||||
],
|
||||
save_path: SavePathType = None,
|
||||
) -> DataFrame:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ class Propagator:
|
|||
self.max_recur_limit = max_recur_limit
|
||||
|
||||
def create_initial_state(
|
||||
self, company_name: str, trade_date: str,
|
||||
self,
|
||||
company_name: str,
|
||||
trade_date: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Create the initial state for the agent graph."""
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -57,7 +57,11 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
|||
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
|
||||
|
||||
def _reflect_on_component(
|
||||
self, component_type: str, report: str, situation: str, returns_losses,
|
||||
self,
|
||||
component_type: str,
|
||||
report: str,
|
||||
situation: str,
|
||||
returns_losses,
|
||||
) -> str:
|
||||
"""Generate reflection for a component."""
|
||||
messages = [
|
||||
|
|
@ -76,7 +80,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
|||
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"BULL", bull_debate_history, situation, returns_losses,
|
||||
"BULL",
|
||||
bull_debate_history,
|
||||
situation,
|
||||
returns_losses,
|
||||
)
|
||||
bull_memory.add_situations([(situation, result)])
|
||||
|
||||
|
|
@ -86,7 +93,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
|||
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"BEAR", bear_debate_history, situation, returns_losses,
|
||||
"BEAR",
|
||||
bear_debate_history,
|
||||
situation,
|
||||
returns_losses,
|
||||
)
|
||||
bear_memory.add_situations([(situation, result)])
|
||||
|
||||
|
|
@ -96,7 +106,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
|||
trader_decision = current_state["trader_investment_plan"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"TRADER", trader_decision, situation, returns_losses,
|
||||
"TRADER",
|
||||
trader_decision,
|
||||
situation,
|
||||
returns_losses,
|
||||
)
|
||||
trader_memory.add_situations([(situation, result)])
|
||||
|
||||
|
|
@ -106,7 +119,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
|||
judge_decision = current_state["investment_debate_state"]["judge_decision"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"INVEST JUDGE", judge_decision, situation, returns_losses,
|
||||
"INVEST JUDGE",
|
||||
judge_decision,
|
||||
situation,
|
||||
returns_losses,
|
||||
)
|
||||
invest_judge_memory.add_situations([(situation, result)])
|
||||
|
||||
|
|
@ -116,6 +132,9 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
|
|||
judge_decision = current_state["risk_debate_state"]["judge_decision"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"RISK JUDGE", judge_decision, situation, returns_losses,
|
||||
"RISK JUDGE",
|
||||
judge_decision,
|
||||
situation,
|
||||
returns_losses,
|
||||
)
|
||||
risk_manager_memory.add_situations([(situation, result)])
|
||||
|
|
|
|||
|
|
@ -55,7 +55,8 @@ class GraphSetup:
|
|||
self.conditional_logic = conditional_logic
|
||||
|
||||
def setup_graph(
|
||||
self, selected_analysts=None,
|
||||
self,
|
||||
selected_analysts=None,
|
||||
):
|
||||
"""Set up and compile the agent workflow graph.
|
||||
|
||||
|
|
@ -79,41 +80,48 @@ class GraphSetup:
|
|||
|
||||
if "market" in selected_analysts:
|
||||
analyst_nodes["market"] = create_market_analyst(
|
||||
self.quick_thinking_llm, self.toolkit,
|
||||
self.quick_thinking_llm,
|
||||
self.toolkit,
|
||||
)
|
||||
delete_nodes["market"] = create_msg_delete()
|
||||
tool_nodes["market"] = self.tool_nodes["market"]
|
||||
|
||||
if "social" in selected_analysts:
|
||||
analyst_nodes["social"] = create_social_media_analyst(
|
||||
self.quick_thinking_llm, self.toolkit,
|
||||
self.quick_thinking_llm,
|
||||
self.toolkit,
|
||||
)
|
||||
delete_nodes["social"] = create_msg_delete()
|
||||
tool_nodes["social"] = self.tool_nodes["social"]
|
||||
|
||||
if "news" in selected_analysts:
|
||||
analyst_nodes["news"] = create_news_analyst(
|
||||
self.quick_thinking_llm, self.toolkit,
|
||||
self.quick_thinking_llm,
|
||||
self.toolkit,
|
||||
)
|
||||
delete_nodes["news"] = create_msg_delete()
|
||||
tool_nodes["news"] = self.tool_nodes["news"]
|
||||
|
||||
if "fundamentals" in selected_analysts:
|
||||
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
||||
self.quick_thinking_llm, self.toolkit,
|
||||
self.quick_thinking_llm,
|
||||
self.toolkit,
|
||||
)
|
||||
delete_nodes["fundamentals"] = create_msg_delete()
|
||||
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
||||
|
||||
# Create researcher and manager nodes
|
||||
bull_researcher_node = create_bull_researcher(
|
||||
self.quick_thinking_llm, self.bull_memory,
|
||||
self.quick_thinking_llm,
|
||||
self.bull_memory,
|
||||
)
|
||||
bear_researcher_node = create_bear_researcher(
|
||||
self.quick_thinking_llm, self.bear_memory,
|
||||
self.quick_thinking_llm,
|
||||
self.bear_memory,
|
||||
)
|
||||
research_manager_node = create_research_manager(
|
||||
self.deep_thinking_llm, self.invest_judge_memory,
|
||||
self.deep_thinking_llm,
|
||||
self.invest_judge_memory,
|
||||
)
|
||||
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
||||
|
||||
|
|
@ -122,7 +130,8 @@ class GraphSetup:
|
|||
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
||||
safe_analyst = create_safe_debator(self.quick_thinking_llm)
|
||||
risk_manager_node = create_risk_manager(
|
||||
self.deep_thinking_llm, self.risk_manager_memory,
|
||||
self.deep_thinking_llm,
|
||||
self.risk_manager_memory,
|
||||
)
|
||||
|
||||
# Create workflow
|
||||
|
|
@ -132,7 +141,8 @@ class GraphSetup:
|
|||
for analyst_type, node in analyst_nodes.items():
|
||||
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
||||
workflow.add_node(
|
||||
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type],
|
||||
f"Msg Clear {analyst_type.capitalize()}",
|
||||
delete_nodes[analyst_type],
|
||||
)
|
||||
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,8 @@ class TradingAgentsGraph:
|
|||
or self.config["llm_provider"] == "openrouter"
|
||||
):
|
||||
self.deep_thinking_llm = ChatOpenAI(
|
||||
model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
|
||||
model=self.config["deep_think_llm"],
|
||||
base_url=self.config["backend_url"],
|
||||
)
|
||||
self.quick_thinking_llm = ChatOpenAI(
|
||||
model=self.config["quick_think_llm"],
|
||||
|
|
@ -67,7 +68,8 @@ class TradingAgentsGraph:
|
|||
)
|
||||
elif self.config["llm_provider"].lower() == "anthropic":
|
||||
self.deep_thinking_llm = ChatAnthropic(
|
||||
model=self.config["deep_think_llm"], base_url=self.config["backend_url"],
|
||||
model=self.config["deep_think_llm"],
|
||||
base_url=self.config["backend_url"],
|
||||
)
|
||||
self.quick_thinking_llm = ChatAnthropic(
|
||||
model=self.config["quick_think_llm"],
|
||||
|
|
@ -91,10 +93,12 @@ class TradingAgentsGraph:
|
|||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||
self.invest_judge_memory = FinancialSituationMemory(
|
||||
"invest_judge_memory", self.config,
|
||||
"invest_judge_memory",
|
||||
self.config,
|
||||
)
|
||||
self.risk_manager_memory = FinancialSituationMemory(
|
||||
"risk_manager_memory", self.config,
|
||||
"risk_manager_memory",
|
||||
self.config,
|
||||
)
|
||||
|
||||
# Create tool nodes
|
||||
|
|
@ -179,7 +183,8 @@ class TradingAgentsGraph:
|
|||
|
||||
# Initialize state
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date,
|
||||
company_name,
|
||||
trade_date,
|
||||
)
|
||||
args = self.propagator.get_graph_args()
|
||||
|
||||
|
|
@ -252,19 +257,29 @@ class TradingAgentsGraph:
|
|||
def reflect_and_remember(self, returns_losses):
|
||||
"""Reflect on decisions and update memory based on returns."""
|
||||
self.reflector.reflect_bull_researcher(
|
||||
self.curr_state, returns_losses, self.bull_memory,
|
||||
self.curr_state,
|
||||
returns_losses,
|
||||
self.bull_memory,
|
||||
)
|
||||
self.reflector.reflect_bear_researcher(
|
||||
self.curr_state, returns_losses, self.bear_memory,
|
||||
self.curr_state,
|
||||
returns_losses,
|
||||
self.bear_memory,
|
||||
)
|
||||
self.reflector.reflect_trader(
|
||||
self.curr_state, returns_losses, self.trader_memory,
|
||||
self.curr_state,
|
||||
returns_losses,
|
||||
self.trader_memory,
|
||||
)
|
||||
self.reflector.reflect_invest_judge(
|
||||
self.curr_state, returns_losses, self.invest_judge_memory,
|
||||
self.curr_state,
|
||||
returns_losses,
|
||||
self.invest_judge_memory,
|
||||
)
|
||||
self.reflector.reflect_risk_manager(
|
||||
self.curr_state, returns_losses, self.risk_manager_memory,
|
||||
self.curr_state,
|
||||
returns_losses,
|
||||
self.risk_manager_memory,
|
||||
)
|
||||
|
||||
def process_signal(self, full_signal):
|
||||
|
|
|
|||
Loading…
Reference in New Issue