commodity support and cleanup cli code

This commit is contained in:
Marvin Gabler 2025-10-21 14:46:10 +02:00
parent 13b826a31d
commit 0c917f01d1
21 changed files with 992 additions and 717 deletions

34
cli/asset_detection.py Normal file
View File

@ -0,0 +1,34 @@
"""Asset class detection utilities for the CLI."""
# Known commodities from Alpha Vantage
KNOWN_COMMODITIES = {
"WTI",
"BRENT",
"NATURAL_GAS",
"COPPER",
"ALUMINUM",
"WHEAT",
"CORN",
"SUGAR",
"COTTON",
"COFFEE",
}
def detect_asset_class(symbol: str) -> str:
"""
Automatically detect if a symbol is a commodity or equity.
Args:
symbol: The ticker symbol (e.g., "BRENT", "AAPL")
Returns:
"commodity" if the symbol matches a known commodity, "equity" otherwise
"""
return "commodity" if symbol.upper() in KNOWN_COMMODITIES else "equity"
def get_asset_class_display_name(asset_class: str) -> str:
"""Get a human-friendly display name for the asset class."""
return asset_class.capitalize()

28
cli/helper_functions.py Normal file
View File

@ -0,0 +1,28 @@
"""Helper functions for the TradingAgents CLI."""
def update_research_team_status(message_buffer, status):
"""Update status for all research team members and trader."""
research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"]
for agent in research_team:
message_buffer.update_agent_status(agent, status)
def extract_content_string(content):
"""Extract string content from various message formats."""
if isinstance(content, str):
return content
elif isinstance(content, list):
# Handle Anthropic's list format
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get('type') == 'text':
text_parts.append(item.get('text', ''))
elif item.get('type') == 'tool_use':
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
else:
text_parts.append(str(item))
return ' '.join(text_parts)
else:
return str(content)

View File

@ -28,6 +28,11 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.models import AnalystType
from cli.utils import *
from cli.message_buffer import MessageBuffer
from cli.ui_display import create_layout, update_display
from cli.report_display import display_complete_report
from cli.helper_functions import update_research_team_status, extract_content_string
from cli.asset_detection import detect_asset_class, get_asset_class_display_name
console = Console()
@ -38,362 +43,10 @@ app = typer.Typer(
)
# Create a deque to store recent messages with a maximum length
class MessageBuffer:
def __init__(self, max_length=100):
self.messages = deque(maxlen=max_length)
self.tool_calls = deque(maxlen=max_length)
self.current_report = None
self.final_report = None # Store the complete final report
self.agent_status = {
# Analyst Team
"Market Analyst": "pending",
"Social Analyst": "pending",
"News Analyst": "pending",
"Fundamentals Analyst": "pending",
# Research Team
"Bull Researcher": "pending",
"Bear Researcher": "pending",
"Research Manager": "pending",
# Trading Team
"Trader": "pending",
# Risk Management Team
"Risky Analyst": "pending",
"Neutral Analyst": "pending",
"Safe Analyst": "pending",
# Portfolio Management Team
"Portfolio Manager": "pending",
}
self.current_agent = None
self.report_sections = {
"market_report": None,
"sentiment_report": None,
"news_report": None,
"fundamentals_report": None,
"investment_plan": None,
"trader_investment_plan": None,
"final_trade_decision": None,
}
def add_message(self, message_type, content):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.messages.append((timestamp, message_type, content))
def add_tool_call(self, tool_name, args):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.tool_calls.append((timestamp, tool_name, args))
def update_agent_status(self, agent, status):
if agent in self.agent_status:
self.agent_status[agent] = status
self.current_agent = agent
def update_report_section(self, section_name, content):
if section_name in self.report_sections:
self.report_sections[section_name] = content
self._update_current_report()
def _update_current_report(self):
# For the panel display, only show the most recently updated section
latest_section = None
latest_content = None
# Find the most recently updated section
for section, content in self.report_sections.items():
if content is not None:
latest_section = section
latest_content = content
if latest_section and latest_content:
# Format the current section for display
section_titles = {
"market_report": "Market Analysis",
"sentiment_report": "Social Sentiment",
"news_report": "News Analysis",
"fundamentals_report": "Fundamentals Analysis",
"investment_plan": "Research Team Decision",
"trader_investment_plan": "Trading Team Plan",
"final_trade_decision": "Portfolio Management Decision",
}
self.current_report = (
f"### {section_titles[latest_section]}\n{latest_content}"
)
# Update the final complete report
self._update_final_report()
def _update_final_report(self):
report_parts = []
# Analyst Team Reports
if any(
self.report_sections[section]
for section in [
"market_report",
"sentiment_report",
"news_report",
"fundamentals_report",
]
):
report_parts.append("## Analyst Team Reports")
if self.report_sections["market_report"]:
report_parts.append(
f"### Market Analysis\n{self.report_sections['market_report']}"
)
if self.report_sections["sentiment_report"]:
report_parts.append(
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
)
if self.report_sections["news_report"]:
report_parts.append(
f"### News Analysis\n{self.report_sections['news_report']}"
)
if self.report_sections["fundamentals_report"]:
report_parts.append(
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
)
# Research Team Reports
if self.report_sections["investment_plan"]:
report_parts.append("## Research Team Decision")
report_parts.append(f"{self.report_sections['investment_plan']}")
# Trading Team Reports
if self.report_sections["trader_investment_plan"]:
report_parts.append("## Trading Team Plan")
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
# Portfolio Management Decision
if self.report_sections["final_trade_decision"]:
report_parts.append("## Portfolio Management Decision")
report_parts.append(f"{self.report_sections['final_trade_decision']}")
self.final_report = "\n\n".join(report_parts) if report_parts else None
# Create a global message buffer instance
message_buffer = MessageBuffer()
def create_layout():
layout = Layout()
layout.split_column(
Layout(name="header", size=3),
Layout(name="main"),
Layout(name="footer", size=3),
)
layout["main"].split_column(
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5)
)
layout["upper"].split_row(
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3)
)
return layout
def update_display(layout, spinner_text=None):
# Header with welcome message
layout["header"].update(
Panel(
"[bold green]Welcome to TradingAgents CLI[/bold green]\n"
"[dim]© [Tauric Research](https://github.com/TauricResearch)[/dim]",
title="Welcome to TradingAgents",
border_style="green",
padding=(1, 2),
expand=True,
)
)
# Progress panel showing agent status
progress_table = Table(
show_header=True,
header_style="bold magenta",
show_footer=False,
box=box.SIMPLE_HEAD, # Use simple header with horizontal lines
title=None, # Remove the redundant Progress title
padding=(0, 2), # Add horizontal padding
expand=True, # Make table expand to fill available space
)
progress_table.add_column("Team", style="cyan", justify="center", width=20)
progress_table.add_column("Agent", style="green", justify="center", width=20)
progress_table.add_column("Status", style="yellow", justify="center", width=20)
# Group agents by team
teams = {
"Analyst Team": [
"Market Analyst",
"Social Analyst",
"News Analyst",
"Fundamentals Analyst",
],
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
"Trading Team": ["Trader"],
"Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"],
"Portfolio Management": ["Portfolio Manager"],
}
for team, agents in teams.items():
# Add first agent with team name
first_agent = agents[0]
status = message_buffer.agent_status[first_agent]
if status == "in_progress":
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
)
status_cell = spinner
else:
status_color = {
"pending": "yellow",
"completed": "green",
"error": "red",
}.get(status, "white")
status_cell = f"[{status_color}]{status}[/{status_color}]"
progress_table.add_row(team, first_agent, status_cell)
# Add remaining agents in team
for agent in agents[1:]:
status = message_buffer.agent_status[agent]
if status == "in_progress":
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
)
status_cell = spinner
else:
status_color = {
"pending": "yellow",
"completed": "green",
"error": "red",
}.get(status, "white")
status_cell = f"[{status_color}]{status}[/{status_color}]"
progress_table.add_row("", agent, status_cell)
# Add horizontal line after each team
progress_table.add_row("" * 20, "" * 20, "" * 20, style="dim")
layout["progress"].update(
Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2))
)
# Messages panel showing recent messages and tool calls
messages_table = Table(
show_header=True,
header_style="bold magenta",
show_footer=False,
expand=True, # Make table expand to fill available space
box=box.MINIMAL, # Use minimal box style for a lighter look
show_lines=True, # Keep horizontal lines
padding=(0, 1), # Add some padding between columns
)
messages_table.add_column("Time", style="cyan", width=8, justify="center")
messages_table.add_column("Type", style="green", width=10, justify="center")
messages_table.add_column(
"Content", style="white", no_wrap=False, ratio=1
) # Make content column expand
# Combine tool calls and messages
all_messages = []
# Add tool calls
for timestamp, tool_name, args in message_buffer.tool_calls:
# Truncate tool call args if too long
if isinstance(args, str) and len(args) > 100:
args = args[:97] + "..."
all_messages.append((timestamp, "Tool", f"{tool_name}: {args}"))
# Add regular messages
for timestamp, msg_type, content in message_buffer.messages:
# Convert content to string if it's not already
content_str = content
if isinstance(content, list):
# Handle list of content blocks (Anthropic format)
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get('type') == 'text':
text_parts.append(item.get('text', ''))
elif item.get('type') == 'tool_use':
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
else:
text_parts.append(str(item))
content_str = ' '.join(text_parts)
elif not isinstance(content_str, str):
content_str = str(content)
# Truncate message content if too long
if len(content_str) > 200:
content_str = content_str[:197] + "..."
all_messages.append((timestamp, msg_type, content_str))
# Sort by timestamp
all_messages.sort(key=lambda x: x[0])
# Calculate how many messages we can show based on available space
# Start with a reasonable number and adjust based on content length
max_messages = 12 # Increased from 8 to better fill the space
# Get the last N messages that will fit in the panel
recent_messages = all_messages[-max_messages:]
# Add messages to table
for timestamp, msg_type, content in recent_messages:
# Format content with word wrapping
wrapped_content = Text(content, overflow="fold")
messages_table.add_row(timestamp, msg_type, wrapped_content)
if spinner_text:
messages_table.add_row("", "Spinner", spinner_text)
# Add a footer to indicate if messages were truncated
if len(all_messages) > max_messages:
messages_table.footer = (
f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]"
)
layout["messages"].update(
Panel(
messages_table,
title="Messages & Tools",
border_style="blue",
padding=(1, 2),
)
)
# Analysis panel showing current report
if message_buffer.current_report:
layout["analysis"].update(
Panel(
Markdown(message_buffer.current_report),
title="Current Report",
border_style="green",
padding=(1, 2),
)
)
else:
layout["analysis"].update(
Panel(
"[italic]Waiting for analysis report...[/italic]",
title="Current Report",
border_style="green",
padding=(1, 2),
)
)
# Footer with statistics
tool_calls_count = len(message_buffer.tool_calls)
llm_calls_count = sum(
1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
)
reports_count = sum(
1 for content in message_buffer.report_sections.values() if content is not None
)
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
stats_table.add_column("Stats", justify="center")
stats_table.add_row(
f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
)
layout["footer"].update(Panel(stats_table, border_style="grey50"))
def get_user_selections():
"""Get all user selections before starting the analysis display."""
@ -436,6 +89,12 @@ def get_user_selections():
)
)
selected_ticker = get_ticker()
# Auto-detect asset class from ticker
asset_class = detect_asset_class(selected_ticker)
console.print(
f"[dim]→ Detected asset class: [bold]{get_asset_class_display_name(asset_class)}[/bold][/dim]\n"
)
# Step 2: Analysis date
default_date = datetime.datetime.now().strftime("%Y-%m-%d")
@ -454,7 +113,7 @@ def get_user_selections():
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis"
)
)
selected_analysts = select_analysts()
selected_analysts = select_analysts(asset_class)
console.print(
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
)
@ -467,10 +126,10 @@ def get_user_selections():
)
selected_research_depth = select_research_depth()
# Step 5: OpenAI backend
# Step 5: LLM backend
console.print(
create_question_box(
"Step 5: OpenAI backend", "Select which service to talk to"
"Step 5: LLM Backend", "Select which service to talk to"
)
)
selected_llm_provider, backend_url = select_llm_provider()
@ -487,6 +146,7 @@ def get_user_selections():
return {
"ticker": selected_ticker,
"analysis_date": analysis_date,
"asset_class": asset_class,
"analysts": selected_analysts,
"research_depth": selected_research_depth,
"llm_provider": selected_llm_provider.lower(),
@ -520,220 +180,6 @@ def get_analysis_date():
)
def display_complete_report(final_state):
"""Display the complete analysis report with team-based panels."""
console.print("\n[bold green]Complete Analysis Report[/bold green]\n")
# I. Analyst Team Reports
analyst_reports = []
# Market Analyst Report
if final_state.get("market_report"):
analyst_reports.append(
Panel(
Markdown(final_state["market_report"]),
title="Market Analyst",
border_style="blue",
padding=(1, 2),
)
)
# Social Analyst Report
if final_state.get("sentiment_report"):
analyst_reports.append(
Panel(
Markdown(final_state["sentiment_report"]),
title="Social Analyst",
border_style="blue",
padding=(1, 2),
)
)
# News Analyst Report
if final_state.get("news_report"):
analyst_reports.append(
Panel(
Markdown(final_state["news_report"]),
title="News Analyst",
border_style="blue",
padding=(1, 2),
)
)
# Fundamentals Analyst Report
if final_state.get("fundamentals_report"):
analyst_reports.append(
Panel(
Markdown(final_state["fundamentals_report"]),
title="Fundamentals Analyst",
border_style="blue",
padding=(1, 2),
)
)
if analyst_reports:
console.print(
Panel(
Columns(analyst_reports, equal=True, expand=True),
title="I. Analyst Team Reports",
border_style="cyan",
padding=(1, 2),
)
)
# II. Research Team Reports
if final_state.get("investment_debate_state"):
research_reports = []
debate_state = final_state["investment_debate_state"]
# Bull Researcher Analysis
if debate_state.get("bull_history"):
research_reports.append(
Panel(
Markdown(debate_state["bull_history"]),
title="Bull Researcher",
border_style="blue",
padding=(1, 2),
)
)
# Bear Researcher Analysis
if debate_state.get("bear_history"):
research_reports.append(
Panel(
Markdown(debate_state["bear_history"]),
title="Bear Researcher",
border_style="blue",
padding=(1, 2),
)
)
# Research Manager Decision
if debate_state.get("judge_decision"):
research_reports.append(
Panel(
Markdown(debate_state["judge_decision"]),
title="Research Manager",
border_style="blue",
padding=(1, 2),
)
)
if research_reports:
console.print(
Panel(
Columns(research_reports, equal=True, expand=True),
title="II. Research Team Decision",
border_style="magenta",
padding=(1, 2),
)
)
# III. Trading Team Reports
if final_state.get("trader_investment_plan"):
console.print(
Panel(
Panel(
Markdown(final_state["trader_investment_plan"]),
title="Trader",
border_style="blue",
padding=(1, 2),
),
title="III. Trading Team Plan",
border_style="yellow",
padding=(1, 2),
)
)
# IV. Risk Management Team Reports
if final_state.get("risk_debate_state"):
risk_reports = []
risk_state = final_state["risk_debate_state"]
# Aggressive (Risky) Analyst Analysis
if risk_state.get("risky_history"):
risk_reports.append(
Panel(
Markdown(risk_state["risky_history"]),
title="Aggressive Analyst",
border_style="blue",
padding=(1, 2),
)
)
# Conservative (Safe) Analyst Analysis
if risk_state.get("safe_history"):
risk_reports.append(
Panel(
Markdown(risk_state["safe_history"]),
title="Conservative Analyst",
border_style="blue",
padding=(1, 2),
)
)
# Neutral Analyst Analysis
if risk_state.get("neutral_history"):
risk_reports.append(
Panel(
Markdown(risk_state["neutral_history"]),
title="Neutral Analyst",
border_style="blue",
padding=(1, 2),
)
)
if risk_reports:
console.print(
Panel(
Columns(risk_reports, equal=True, expand=True),
title="IV. Risk Management Team Decision",
border_style="red",
padding=(1, 2),
)
)
# V. Portfolio Manager Decision
if risk_state.get("judge_decision"):
console.print(
Panel(
Panel(
Markdown(risk_state["judge_decision"]),
title="Portfolio Manager",
border_style="blue",
padding=(1, 2),
),
title="V. Portfolio Manager Decision",
border_style="green",
padding=(1, 2),
)
)
def update_research_team_status(status):
"""Update status for all research team members and trader."""
research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"]
for agent in research_team:
message_buffer.update_agent_status(agent, status)
def extract_content_string(content):
"""Extract string content from various message formats."""
if isinstance(content, str):
return content
elif isinstance(content, list):
# Handle Anthropic's list format
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get('type') == 'text':
text_parts.append(item.get('text', ''))
elif item.get('type') == 'tool_use':
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
else:
text_parts.append(str(item))
return ' '.join(text_parts)
else:
return str(content)
def run_analysis():
# First get all user selections
@ -747,6 +193,7 @@ def run_analysis():
config["deep_think_llm"] = selections["deep_thinker"]
config["backend_url"] = selections["backend_url"]
config["llm_provider"] = selections["llm_provider"].lower()
config["asset_class"] = selections["asset_class"]
# Initialize the graph
graph = TradingAgentsGraph(
@ -805,7 +252,7 @@ def run_analysis():
with Live(layout, refresh_per_second=4) as live:
# Initial display
update_display(layout)
update_display(layout, message_buffer)
# Add initial messages
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
@ -816,7 +263,7 @@ def run_analysis():
"System",
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
)
update_display(layout)
update_display(layout, message_buffer)
# Reset agent statuses
for agent in message_buffer.agent_status:
@ -831,18 +278,20 @@ def run_analysis():
# Update agent status to in_progress for the first analyst
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
message_buffer.update_agent_status(first_analyst, "in_progress")
update_display(layout)
update_display(layout, message_buffer)
# Create spinner text
spinner_text = (
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
)
update_display(layout, spinner_text)
update_display(layout, message_buffer, spinner_text)
# Initialize state and get graph args
init_agent_state = graph.propagator.create_initial_state(
selections["ticker"], selections["analysis_date"]
)
# CRITICAL: Add asset_class to state so market analyst can branch correctly
init_agent_state["asset_class"] = selections["asset_class"]
args = graph.propagator.get_graph_args()
# Stream the analysis
@ -875,49 +324,24 @@ def run_analysis():
message_buffer.add_tool_call(tool_call.name, tool_call.args)
# Update reports and agent status based on chunk content
# Analyst Team Reports
if "market_report" in chunk and chunk["market_report"]:
message_buffer.update_report_section(
"market_report", chunk["market_report"]
)
message_buffer.update_agent_status("Market Analyst", "completed")
# Set next analyst to in_progress
if "social" in selections["analysts"]:
message_buffer.update_agent_status(
"Social Analyst", "in_progress"
)
if "sentiment_report" in chunk and chunk["sentiment_report"]:
message_buffer.update_report_section(
"sentiment_report", chunk["sentiment_report"]
)
message_buffer.update_agent_status("Social Analyst", "completed")
# Set next analyst to in_progress
if "news" in selections["analysts"]:
message_buffer.update_agent_status(
"News Analyst", "in_progress"
)
if "news_report" in chunk and chunk["news_report"]:
message_buffer.update_report_section(
"news_report", chunk["news_report"]
)
message_buffer.update_agent_status("News Analyst", "completed")
# Set next analyst to in_progress
if "fundamentals" in selections["analysts"]:
message_buffer.update_agent_status(
"Fundamentals Analyst", "in_progress"
)
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
message_buffer.update_report_section(
"fundamentals_report", chunk["fundamentals_report"]
)
message_buffer.update_agent_status(
"Fundamentals Analyst", "completed"
)
# Set all research team members to in_progress
update_research_team_status("in_progress")
# Analyst Team Reports - use a mapping to reduce repetition
analyst_mappings = [
("market_report", "Market Analyst", "social", "Social Analyst"),
("sentiment_report", "Social Analyst", "news", "News Analyst"),
("news_report", "News Analyst", "fundamentals", "Fundamentals Analyst"),
("fundamentals_report", "Fundamentals Analyst", None, None),
]
for report_key, analyst_name, next_type, next_analyst in analyst_mappings:
if report_key in chunk and chunk[report_key]:
message_buffer.update_report_section(report_key, chunk[report_key])
message_buffer.update_agent_status(analyst_name, "completed")
if report_key == "fundamentals_report":
# Special case: set all research team to in_progress
update_research_team_status(message_buffer, "in_progress")
elif next_type and next_type in [a.value for a in selections["analysts"]]:
message_buffer.update_agent_status(next_analyst, "in_progress")
# Research Team - Handle Investment Debate State
if (
@ -929,7 +353,7 @@ def run_analysis():
# Update Bull Researcher status and report
if "bull_history" in debate_state and debate_state["bull_history"]:
# Keep all research team members in progress
update_research_team_status("in_progress")
update_research_team_status(message_buffer, "in_progress")
# Extract latest bull response
bull_responses = debate_state["bull_history"].split("\n")
latest_bull = bull_responses[-1] if bull_responses else ""
@ -944,7 +368,7 @@ def run_analysis():
# Update Bear Researcher status and report
if "bear_history" in debate_state and debate_state["bear_history"]:
# Keep all research team members in progress
update_research_team_status("in_progress")
update_research_team_status(message_buffer, "in_progress")
# Extract latest bear response
bear_responses = debate_state["bear_history"].split("\n")
latest_bear = bear_responses[-1] if bear_responses else ""
@ -962,7 +386,7 @@ def run_analysis():
and debate_state["judge_decision"]
):
# Keep all research team members in progress until final decision
update_research_team_status("in_progress")
update_research_team_status(message_buffer, "in_progress")
message_buffer.add_message(
"Reasoning",
f"Research Manager: {debate_state['judge_decision']}",
@ -973,7 +397,7 @@ def run_analysis():
f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}",
)
# Mark all research team members as completed
update_research_team_status("completed")
update_research_team_status(message_buffer, "completed")
# Set first risk analyst to in_progress
message_buffer.update_agent_status(
"Risky Analyst", "in_progress"
@ -993,60 +417,25 @@ def run_analysis():
# Risk Management Team - Handle Risk Debate State
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
risk_state = chunk["risk_debate_state"]
# Update Risky Analyst status and report
if (
"current_risky_response" in risk_state
and risk_state["current_risky_response"]
):
message_buffer.update_agent_status(
"Risky Analyst", "in_progress"
)
message_buffer.add_message(
"Reasoning",
f"Risky Analyst: {risk_state['current_risky_response']}",
)
# Update risk report with risky analyst's latest analysis only
message_buffer.update_report_section(
"final_trade_decision",
f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}",
)
# Update Safe Analyst status and report
if (
"current_safe_response" in risk_state
and risk_state["current_safe_response"]
):
message_buffer.update_agent_status(
"Safe Analyst", "in_progress"
)
message_buffer.add_message(
"Reasoning",
f"Safe Analyst: {risk_state['current_safe_response']}",
)
# Update risk report with safe analyst's latest analysis only
message_buffer.update_report_section(
"final_trade_decision",
f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}",
)
# Update Neutral Analyst status and report
if (
"current_neutral_response" in risk_state
and risk_state["current_neutral_response"]
):
message_buffer.update_agent_status(
"Neutral Analyst", "in_progress"
)
message_buffer.add_message(
"Reasoning",
f"Neutral Analyst: {risk_state['current_neutral_response']}",
)
# Update risk report with neutral analyst's latest analysis only
message_buffer.update_report_section(
"final_trade_decision",
f"### Neutral Analyst Analysis\n{risk_state['current_neutral_response']}",
)
# Handle all risk analysts with a mapping
risk_analysts = [
("current_risky_response", "Risky Analyst"),
("current_safe_response", "Safe Analyst"),
("current_neutral_response", "Neutral Analyst"),
]
for response_key, analyst_name in risk_analysts:
if response_key in risk_state and risk_state[response_key]:
message_buffer.update_agent_status(analyst_name, "in_progress")
message_buffer.add_message(
"Reasoning",
f"{analyst_name}: {risk_state[response_key]}",
)
message_buffer.update_report_section(
"final_trade_decision",
f"### {analyst_name} Analysis\n{risk_state[response_key]}",
)
# Update Portfolio Manager status and final decision
if "judge_decision" in risk_state and risk_state["judge_decision"]:
@ -1073,7 +462,7 @@ def run_analysis():
)
# Update the display
update_display(layout)
update_display(layout, message_buffer)
trace.append(chunk)
@ -1097,7 +486,7 @@ def run_analysis():
# Display the complete final report
display_complete_report(final_state)
update_display(layout)
update_display(layout, message_buffer)
@app.command()

128
cli/message_buffer.py Normal file
View File

@ -0,0 +1,128 @@
"""Message buffer for tracking agent messages and reports in the CLI."""
from collections import deque
import datetime
class MessageBuffer:
"""Stores and manages messages, tool calls, and reports for the trading agents UI."""
def __init__(self, max_length=100):
self.messages = deque(maxlen=max_length)
self.tool_calls = deque(maxlen=max_length)
self.current_report = None
self.final_report = None # Store the complete final report
# Initialize all agents as pending
all_agents = [
"Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst",
"Bull Researcher", "Bear Researcher", "Research Manager",
"Trader",
"Risky Analyst", "Neutral Analyst", "Safe Analyst",
"Portfolio Manager"
]
self.agent_status = {agent: "pending" for agent in all_agents}
self.current_agent = None
self.report_sections = {
"market_report": None,
"sentiment_report": None,
"news_report": None,
"fundamentals_report": None,
"investment_plan": None,
"trader_investment_plan": None,
"final_trade_decision": None,
}
def add_message(self, message_type, content):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.messages.append((timestamp, message_type, content))
def add_tool_call(self, tool_name, args):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.tool_calls.append((timestamp, tool_name, args))
def update_agent_status(self, agent, status):
if agent in self.agent_status:
self.agent_status[agent] = status
self.current_agent = agent
def update_report_section(self, section_name, content):
if section_name in self.report_sections:
self.report_sections[section_name] = content
self._update_current_report()
def _update_current_report(self):
# For the panel display, only show the most recently updated section
latest_section = None
latest_content = None
# Find the most recently updated section
for section, content in self.report_sections.items():
if content is not None:
latest_section = section
latest_content = content
if latest_section and latest_content:
# Format the current section for display
section_titles = {
"market_report": "Market Analysis",
"sentiment_report": "Social Sentiment",
"news_report": "News Analysis",
"fundamentals_report": "Fundamentals Analysis",
"investment_plan": "Research Team Decision",
"trader_investment_plan": "Trading Team Plan",
"final_trade_decision": "Portfolio Management Decision",
}
self.current_report = (
f"### {section_titles[latest_section]}\n{latest_content}"
)
# Update the final complete report
self._update_final_report()
def _update_final_report(self):
report_parts = []
# Analyst Team Reports
if any(
self.report_sections[section]
for section in [
"market_report",
"sentiment_report",
"news_report",
"fundamentals_report",
]
):
report_parts.append("## Analyst Team Reports")
if self.report_sections["market_report"]:
report_parts.append(
f"### Market Analysis\n{self.report_sections['market_report']}"
)
if self.report_sections["sentiment_report"]:
report_parts.append(
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
)
if self.report_sections["news_report"]:
report_parts.append(
f"### News Analysis\n{self.report_sections['news_report']}"
)
if self.report_sections["fundamentals_report"]:
report_parts.append(
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
)
# Research Team Reports
if self.report_sections["investment_plan"]:
report_parts.append("## Research Team Decision")
report_parts.append(f"{self.report_sections['investment_plan']}")
# Trading Team Reports
if self.report_sections["trader_investment_plan"]:
report_parts.append("## Trading Team Plan")
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
# Portfolio Management Decision
if self.report_sections["final_trade_decision"]:
report_parts.append("## Portfolio Management Decision")
report_parts.append(f"{self.report_sections['final_trade_decision']}")
self.final_report = "\n\n".join(report_parts) if report_parts else None

159
cli/report_display.py Normal file
View File

@ -0,0 +1,159 @@
"""Report display functions for the TradingAgents CLI."""
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.columns import Columns
console = Console()
def display_complete_report(final_state):
"""Display the complete analysis report with team-based panels."""
console.print("\n[bold green]Complete Analysis Report[/bold green]\n")
# I. Analyst Team Reports
analyst_reports = []
# Map report keys to analyst names
analyst_report_map = [
("market_report", "Market Analyst"),
("sentiment_report", "Social Analyst"),
("news_report", "News Analyst"),
("fundamentals_report", "Fundamentals Analyst"),
]
for report_key, analyst_name in analyst_report_map:
if final_state.get(report_key):
analyst_reports.append(
Panel(
Markdown(final_state[report_key]),
title=analyst_name,
border_style="blue",
padding=(1, 2),
)
)
if analyst_reports:
console.print(
Panel(
Columns(analyst_reports, equal=True, expand=True),
title="I. Analyst Team Reports",
border_style="cyan",
padding=(1, 2),
)
)
# II. Research Team Reports
if final_state.get("investment_debate_state"):
research_reports = []
debate_state = final_state["investment_debate_state"]
# Bull Researcher Analysis
if debate_state.get("bull_history"):
research_reports.append(
Panel(
Markdown(debate_state["bull_history"]),
title="Bull Researcher",
border_style="blue",
padding=(1, 2),
)
)
# Bear Researcher Analysis
if debate_state.get("bear_history"):
research_reports.append(
Panel(
Markdown(debate_state["bear_history"]),
title="Bear Researcher",
border_style="blue",
padding=(1, 2),
)
)
# Research Manager Decision
if debate_state.get("judge_decision"):
research_reports.append(
Panel(
Markdown(debate_state["judge_decision"]),
title="Research Manager",
border_style="blue",
padding=(1, 2),
)
)
if research_reports:
console.print(
Panel(
Columns(research_reports, equal=True, expand=True),
title="II. Research Team Decision",
border_style="magenta",
padding=(1, 2),
)
)
# III. Trading Team Reports
if final_state.get("trader_investment_plan"):
console.print(
Panel(
Panel(
Markdown(final_state["trader_investment_plan"]),
title="Trader",
border_style="blue",
padding=(1, 2),
),
title="III. Trading Team Plan",
border_style="yellow",
padding=(1, 2),
)
)
# IV. Risk Management Team Reports
if final_state.get("risk_debate_state"):
risk_reports = []
risk_state = final_state["risk_debate_state"]
# Map risk history keys to analyst names
risk_analyst_map = [
("risky_history", "Aggressive Analyst"),
("safe_history", "Conservative Analyst"),
("neutral_history", "Neutral Analyst"),
]
for history_key, analyst_name in risk_analyst_map:
if risk_state.get(history_key):
risk_reports.append(
Panel(
Markdown(risk_state[history_key]),
title=analyst_name,
border_style="blue",
padding=(1, 2),
)
)
if risk_reports:
console.print(
Panel(
Columns(risk_reports, equal=True, expand=True),
title="IV. Risk Management Team Decision",
border_style="red",
padding=(1, 2),
)
)
# V. Portfolio Manager Decision
if risk_state.get("judge_decision"):
console.print(
Panel(
Panel(
Markdown(risk_state["judge_decision"]),
title="Portfolio Manager",
border_style="blue",
padding=(1, 2),
),
title="V. Portfolio Manager Decision",
border_style="green",
padding=(1, 2),
)
)

232
cli/ui_display.py Normal file
View File

@ -0,0 +1,232 @@
"""UI display functions for the TradingAgents CLI using Rich library."""
from rich.panel import Panel
from rich.spinner import Spinner
from rich.layout import Layout
from rich.text import Text
from rich.table import Table
from rich import box
def create_layout():
"""Create the main layout structure for the CLI display."""
layout = Layout()
layout.split_column(
Layout(name="header", size=3),
Layout(name="main"),
Layout(name="footer", size=3),
)
layout["main"].split_column(
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5)
)
layout["upper"].split_row(
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3)
)
return layout
def update_display(layout, message_buffer, spinner_text=None):
"""Update all panels in the display layout with current data."""
# Header with welcome message
layout["header"].update(
Panel(
"[bold green]Welcome to TradingAgents CLI[/bold green]\n"
"[dim]© [Tauric Research](https://github.com/TauricResearch)[/dim]",
title="Welcome to TradingAgents",
border_style="green",
padding=(1, 2),
expand=True,
)
)
# Progress panel showing agent status
progress_table = Table(
show_header=True,
header_style="bold magenta",
show_footer=False,
box=box.SIMPLE_HEAD, # Use simple header with horizontal lines
title=None, # Remove the redundant Progress title
padding=(0, 2), # Add horizontal padding
expand=True, # Make table expand to fill available space
)
progress_table.add_column("Team", style="cyan", justify="center", width=20)
progress_table.add_column("Agent", style="green", justify="center", width=20)
progress_table.add_column("Status", style="yellow", justify="center", width=20)
# Group agents by team
teams = {
"Analyst Team": [
"Market Analyst",
"Social Analyst",
"News Analyst",
"Fundamentals Analyst",
],
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
"Trading Team": ["Trader"],
"Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"],
"Portfolio Management": ["Portfolio Manager"],
}
for team, agents in teams.items():
# Add first agent with team name
first_agent = agents[0]
status = message_buffer.agent_status[first_agent]
if status == "in_progress":
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
)
status_cell = spinner
else:
status_color = {
"pending": "yellow",
"completed": "green",
"error": "red",
}.get(status, "white")
status_cell = f"[{status_color}]{status}[/{status_color}]"
progress_table.add_row(team, first_agent, status_cell)
# Add remaining agents in team
for agent in agents[1:]:
status = message_buffer.agent_status[agent]
if status == "in_progress":
spinner = Spinner(
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
)
status_cell = spinner
else:
status_color = {
"pending": "yellow",
"completed": "green",
"error": "red",
}.get(status, "white")
status_cell = f"[{status_color}]{status}[/{status_color}]"
progress_table.add_row("", agent, status_cell)
# Add horizontal line after each team
progress_table.add_row("" * 20, "" * 20, "" * 20, style="dim")
layout["progress"].update(
Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2))
)
# Messages panel showing recent messages and tool calls
messages_table = Table(
show_header=True,
header_style="bold magenta",
show_footer=False,
expand=True, # Make table expand to fill available space
box=box.MINIMAL, # Use minimal box style for a lighter look
show_lines=True, # Keep horizontal lines
padding=(0, 1), # Add some padding between columns
)
messages_table.add_column("Time", style="cyan", width=8, justify="center")
messages_table.add_column("Type", style="green", width=10, justify="center")
messages_table.add_column(
"Content", style="white", no_wrap=False, ratio=1
) # Make content column expand
# Combine tool calls and messages
all_messages = []
# Add tool calls
for timestamp, tool_name, args in message_buffer.tool_calls:
# Truncate tool call args if too long
if isinstance(args, str) and len(args) > 100:
args = args[:97] + "..."
all_messages.append((timestamp, "Tool", f"{tool_name}: {args}"))
# Add regular messages
for timestamp, msg_type, content in message_buffer.messages:
# Convert content to string if it's not already
content_str = content
if isinstance(content, list):
# Handle list of content blocks (Anthropic format)
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get('type') == 'text':
text_parts.append(item.get('text', ''))
elif item.get('type') == 'tool_use':
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
else:
text_parts.append(str(item))
content_str = ' '.join(text_parts)
elif not isinstance(content_str, str):
content_str = str(content)
# Truncate message content if too long
if len(content_str) > 200:
content_str = content_str[:197] + "..."
all_messages.append((timestamp, msg_type, content_str))
# Sort by timestamp
all_messages.sort(key=lambda x: x[0])
# Calculate how many messages we can show based on available space
# Start with a reasonable number and adjust based on content length
max_messages = 12 # Increased from 8 to better fill the space
# Get the last N messages that will fit in the panel
recent_messages = all_messages[-max_messages:]
# Add messages to table
for timestamp, msg_type, content in recent_messages:
# Format content with word wrapping
wrapped_content = Text(content, overflow="fold")
messages_table.add_row(timestamp, msg_type, wrapped_content)
if spinner_text:
messages_table.add_row("", "Spinner", spinner_text)
# Add a footer to indicate if messages were truncated
if len(all_messages) > max_messages:
messages_table.footer = (
f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]"
)
layout["messages"].update(
Panel(
messages_table,
title="Messages & Tools",
border_style="blue",
padding=(1, 2),
)
)
# Analysis panel showing current report
if message_buffer.current_report:
from rich.markdown import Markdown
layout["analysis"].update(
Panel(
Markdown(message_buffer.current_report),
title="Current Report",
border_style="green",
padding=(1, 2),
)
)
else:
layout["analysis"].update(
Panel(
"[italic]Waiting for analysis report...[/italic]",
title="Current Report",
border_style="green",
padding=(1, 2),
)
)
# Footer with statistics
tool_calls_count = len(message_buffer.tool_calls)
llm_calls_count = sum(
1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
)
reports_count = sum(
1 for content in message_buffer.report_sections.values() if content is not None
)
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
stats_table.add_column("Stats", justify="center")
stats_table.add_row(
f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
)
layout["footer"].update(Panel(stats_table, border_style="grey50"))

View File

@ -64,12 +64,19 @@ def get_analysis_date() -> str:
return date.strip()
def select_analysts() -> List[AnalystType]:
"""Select analysts using an interactive checkbox."""
def select_analysts(asset_class: str | None = None) -> List[AnalystType]:
"""Select analysts using an interactive checkbox.
If asset_class is 'commodity', hide Fundamentals Analyst.
"""
order = ANALYST_ORDER
if asset_class and asset_class.lower() == "commodity":
order = [(d, v) for (d, v) in ANALYST_ORDER if v != AnalystType.FUNDAMENTALS]
choices = questionary.checkbox(
"Select Your [Analysts Team]:",
choices=[
questionary.Choice(display, value=value) for display, value in ANALYST_ORDER
questionary.Choice(display, value=value) for display, value in order
],
instruction="\n- Press Space to select/unselect analysts\n- Press 'a' to select/unselect all\n- Press Enter when done",
validate=lambda x: len(x) > 0 or "You must select at least one analyst.",

View File

@ -2,6 +2,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicators
from tradingagents.agents.utils.commodity_data_tools import get_commodity_data
from tradingagents.dataflows.config import get_config
@ -12,10 +13,16 @@ def create_market_analyst(llm):
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
tools = [
get_stock_data,
get_indicators,
]
asset_class = state.get("asset_class", "equity")
if asset_class == "commodity":
tools = [
get_commodity_data,
]
else:
tools = [
get_stock_data,
get_indicators,
]
system_message = (
"""You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
@ -40,12 +47,17 @@ Volatility Indicators:
- atr: ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy.
Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
)
if asset_class == "equity":
system_message += " ""Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names."""
else:
system_message += " ""For commodities, call get_commodity_data to retrieve the price series (value column). You may analyze trends directly on the series or proceed without additional indicators."""
prompt = ChatPromptTemplate.from_messages(
[
(

View File

@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import get_news, get_global_news
from tradingagents.agents.utils.agent_utils import get_news, get_commodity_news, get_global_news
from tradingagents.dataflows.config import get_config
@ -9,16 +9,38 @@ def create_news_analyst(llm):
def news_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
asset_class = state.get("asset_class", "equity")
is_commodity = asset_class.lower() == "commodity"
tools = [
get_news,
get_global_news,
]
system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
)
# Branch tools based on asset class
if is_commodity:
tools = [
get_commodity_news,
get_global_news,
]
system_message = (
f"You are a news researcher tasked with analyzing recent news and trends for the commodity {ticker}. "
"Please write a comprehensive report of relevant news over the past week that impacts this commodity's price. "
"Use the available tools: get_commodity_news(commodity, start_date, end_date) for commodity-specific news (searches by topic like 'energy' for oil, 'economy_macro' for agriculture), "
"and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic context. "
"IMPORTANT: If get_commodity_news returns limited results, make sure to use get_global_news to provide additional market context. "
"Focus on supply/demand factors, geopolitical events, weather impacts (for agriculture), and macroeconomic trends. "
"Do not simply state the trends are mixed, provide detailed and fine-grained analysis."
+ """ Make sure to append a Markdown table at the end of the report to organize key points."""
)
else:
tools = [
get_news,
get_global_news,
]
system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. "
"Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. "
"Use the available tools: get_news(ticker, start_date, end_date) for company-specific or targeted news searches, "
"and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. "
"Do not simply state the trends are mixed, provide detailed and fine-grained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
)
prompt = ChatPromptTemplate.from_messages(
[

View File

@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import get_news
from tradingagents.agents.utils.agent_utils import get_news, get_commodity_news, get_global_news
from tradingagents.dataflows.config import get_config
@ -9,16 +9,38 @@ def create_social_media_analyst(llm):
def social_media_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
asset_class = state.get("asset_class", "equity")
is_commodity = asset_class.lower() == "commodity"
tools = [
get_news,
]
system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""",
)
# Branch tools based on asset class
if is_commodity:
tools = [
get_commodity_news,
get_global_news,
]
system_message = (
f"You are a social media and news researcher/analyst tasked with analyzing recent discussions and sentiment for the commodity {ticker}. "
"Your objective is to write a comprehensive report detailing market sentiment, trader discussions, and public perception over the past week. "
"Use get_commodity_news(commodity, start_date, end_date) to search for commodity-related news and discussions (searches by topic like 'energy' for oil). "
"IMPORTANT: If get_commodity_news returns limited results, supplement with get_global_news(curr_date, look_back_days, limit) for broader market context. "
"Focus on trader sentiment, supply/demand expectations, geopolitical concerns, and market psychology. "
"Do not simply state the trends are mixed, provide detailed and fine-grained analysis."
+ """ Make sure to append a Markdown table at the end of the report to organize key points."""
)
else:
tools = [
get_news,
get_global_news,
]
system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. "
"Your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, "
"analyzing sentiment data of what people feel each day about the company, and looking at recent company news. "
"Use the get_news(ticker, start_date, end_date) tool to search for company-specific news and social media discussions. "
"If needed, use get_global_news(curr_date, look_back_days, limit) for broader market context. "
"Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and fine-grained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
)
prompt = ChatPromptTemplate.from_messages(
[

View File

@ -50,6 +50,7 @@ class RiskDebateState(TypedDict):
class AgentState(MessagesState):
company_of_interest: Annotated[str, "Company that we are interested in trading"]
trade_date: Annotated[str, "What date we are trading at"]
asset_class: Annotated[str, "Asset class: equity or commodity"]
sender: Annotated[str, "Agent that sent this message"]

View File

@ -15,6 +15,7 @@ from tradingagents.agents.utils.fundamental_data_tools import (
)
from tradingagents.agents.utils.news_data_tools import (
get_news,
get_commodity_news,
get_insider_sentiment,
get_insider_transactions,
get_global_news

View File

@ -0,0 +1,19 @@
from langchain_core.tools import tool
from typing import Annotated
from tradingagents.dataflows.interface import route_to_vendor
@tool
def get_commodity_data(
commodity: Annotated[str, "name like WTI, BRENT, NATURAL_GAS, COPPER"],
start_date: Annotated[str, "YYYY-mm-dd"],
end_date: Annotated[str, "YYYY-mm-dd"],
interval: Annotated[str, "daily|weekly|monthly"] = "monthly",
) -> str:
"""
Retrieve commodity price data for a given commodity symbol.
Uses the configured commodity_data vendor.
"""
return route_to_vendor("get_commodity_data", commodity, start_date, end_date, interval)

View File

@ -38,6 +38,26 @@ def get_global_news(
"""
return route_to_vendor("get_global_news", curr_date, look_back_days, limit)
@tool
def get_commodity_news(
commodity: Annotated[str, "Commodity symbol like BRENT, WTI, COPPER"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve news data for a commodity (oil, metals, agriculture).
Uses topic-based search since commodities don't have stock tickers.
Searches news by relevant topics (energy, economy, etc.) and filters for the commodity.
Args:
commodity (str): Commodity symbol (e.g., "BRENT", "WTI", "COPPER")
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A formatted string containing commodity-related news data
"""
return route_to_vendor("get_commodity_news", commodity, start_date, end_date)
@tool
def get_insider_sentiment(
ticker: Annotated[str, "ticker symbol for the company"],

View File

@ -2,4 +2,5 @@
from .alpha_vantage_stock import get_stock
from .alpha_vantage_indicator import get_indicator
from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
from .alpha_vantage_news import get_news, get_insider_transactions
from .alpha_vantage_news import get_news, get_insider_transactions, get_commodity_news
from .alpha_vantage_commodity import get_commodity

View File

@ -0,0 +1,88 @@
from .alpha_vantage_common import _make_api_request
from datetime import datetime, timedelta
# Map human-friendly names to Alpha Vantage commodity functions
FUNCTIONS = {
"WTI": "WTI",
"BRENT": "BRENT",
"NATURAL_GAS": "NATURAL_GAS",
"COPPER": "COPPER",
"ALUMINUM": "ALUMINUM",
"WHEAT": "WHEAT",
"CORN": "CORN",
"SUGAR": "SUGAR",
"COTTON": "COTTON",
"COFFEE": "COFFEE",
}
def get_commodity(
commodity: str,
start_date: str,
end_date: str,
interval: str = "monthly",
) -> str:
"""
Fetch commodity price series from Alpha Vantage and return as CSV with columns time,value.
Args:
commodity: e.g. WTI, BRENT, NATURAL_GAS, COPPER
start_date: YYYY-mm-dd
end_date: YYYY-mm-dd
interval: daily|weekly|monthly (depends on AV endpoint support)
Returns:
CSV string with headers time,value
"""
func = FUNCTIONS.get(commodity.upper())
if not func:
raise ValueError(f"Unsupported commodity: {commodity}")
params = {
"interval": interval,
"datatype": "json",
}
raw = _make_api_request(func, params)
# Convert AV JSON payload to simple CSV
import json
import io
try:
payload = json.loads(raw)
series = payload.get("data") or []
s_dt = datetime.strptime(start_date, "%Y-%m-%d")
e_dt = datetime.strptime(end_date, "%Y-%m-%d")
# If user passed a very narrow window (e.g., single day) on a monthly/weekly series,
# widen to a reasonable historical window to ensure data presence.
if interval == "monthly" and (e_dt - s_dt).days < 28:
s_dt = e_dt - timedelta(days=365)
elif interval == "weekly" and (e_dt - s_dt).days < 7:
s_dt = e_dt - timedelta(days=180)
rows = []
for item in series:
# items are like {"date": "2025-05-01", "value": "xxxx"}
d = datetime.strptime(item["date"], "%Y-%m-%d")
if s_dt <= d <= e_dt:
rows.append((item["date"], item.get("value")))
out = io.StringIO()
out.write("time,value\n")
for d, v in sorted(rows):
out.write(f"{d},{v}\n")
csv_text = out.getvalue()
# If still empty after widening, return header + latest few rows without filtering
if csv_text.strip() == "time,value":
out = io.StringIO()
out.write("time,value\n")
for item in series[:24]: # last ~2 years monthly
out.write(f"{item['date']},{item.get('value')}\n")
return out.getvalue()
return csv_text
except Exception:
# Fallback: return raw payload for debugging
return raw

View File

@ -1,5 +1,39 @@
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
# Map commodity symbols to Alpha Vantage NEWS_SENTIMENT topics
COMMODITY_TOPIC_MAP = {
# Energy commodities
"WTI": "energy",
"BRENT": "energy",
"NATURAL_GAS": "energy",
# Metals
"COPPER": "technology", # Copper is heavily used in tech/manufacturing
"ALUMINUM": "technology",
# Agriculture
"WHEAT": "economy_macro", # Agriculture affects macro economy
"CORN": "economy_macro",
"SUGAR": "economy_macro",
"COTTON": "economy_macro",
"COFFEE": "economy_macro",
}
# Map commodities to search keywords for better context
COMMODITY_KEYWORDS = {
"WTI": "WTI crude oil",
"BRENT": "Brent crude oil",
"NATURAL_GAS": "natural gas",
"COPPER": "copper commodity",
"ALUMINUM": "aluminum commodity",
"WHEAT": "wheat commodity",
"CORN": "corn commodity",
"SUGAR": "sugar commodity",
"COTTON": "cotton commodity",
"COFFEE": "coffee commodity",
}
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
@ -24,6 +58,55 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
return _make_api_request("NEWS_SENTIMENT", params)
def get_commodity_news(commodity: str, start_date: str, end_date: str) -> dict[str, str] | str:
"""Returns news data for commodities using topic-based search.
Since Alpha Vantage NEWS_SENTIMENT doesn't support commodity symbols directly,
this function uses the 'topics' parameter to find relevant news.
Args:
commodity: Commodity symbol (e.g., "BRENT", "WTI", "COPPER")
start_date: Start date for news search.
end_date: End date for news search.
Returns:
Dictionary containing news sentiment data or JSON string.
"""
commodity_upper = commodity.upper()
topic = COMMODITY_TOPIC_MAP.get(commodity_upper, "economy_macro")
keyword = COMMODITY_KEYWORDS.get(commodity_upper, commodity)
# Use topics parameter instead of tickers for commodities
params = {
"topics": topic,
"time_from": format_datetime_for_api(start_date),
"time_to": format_datetime_for_api(end_date),
"sort": "LATEST",
"limit": "50", # Get more results to filter for commodity-specific news
}
result = _make_api_request("NEWS_SENTIMENT", params)
# Add metadata to help the LLM understand this is commodity-filtered news
import json
try:
data = json.loads(result) if isinstance(result, str) else result
if isinstance(data, dict) and "feed" in data:
# Add a note about the commodity and topic used
data["_commodity_context"] = {
"commodity": commodity,
"topic": topic,
"keyword_filter": keyword,
"note": f"News filtered by topic '{topic}'. Look for articles mentioning '{keyword}' for most relevant results."
}
return json.dumps(data)
except (json.JSONDecodeError, TypeError):
pass
return result
def get_insider_transactions(symbol: str) -> dict[str, str] | str:
"""Returns latest and historical insider transactions by key stakeholders.

View File

@ -13,7 +13,9 @@ from .alpha_vantage import (
get_cashflow as get_alpha_vantage_cashflow,
get_income_statement as get_alpha_vantage_income_statement,
get_insider_transactions as get_alpha_vantage_insider_transactions,
get_news as get_alpha_vantage_news
get_news as get_alpha_vantage_news,
get_commodity_news as get_alpha_vantage_commodity_news,
get_commodity as get_alpha_vantage_commodity
)
from .alpha_vantage_common import AlphaVantageRateLimitError
@ -28,6 +30,12 @@ TOOLS_CATEGORIES = {
"get_stock_data"
]
},
"commodity_data": {
"description": "Commodity price data",
"tools": [
"get_commodity_data"
]
},
"technical_indicators": {
"description": "Technical analysis indicators",
"tools": [
@ -69,6 +77,10 @@ VENDOR_METHODS = {
"yfinance": get_YFin_data_online,
"local": get_YFin_data,
},
# commodity_data
"get_commodity_data": {
"alpha_vantage": get_alpha_vantage_commodity,
},
# technical_indicators
"get_indicators": {
"alpha_vantage": get_alpha_vantage_indicator,
@ -102,6 +114,10 @@ VENDOR_METHODS = {
"google": get_google_news,
"local": [get_finnhub_news, get_reddit_company_news, get_google_news],
},
"get_commodity_news": {
"alpha_vantage": get_alpha_vantage_commodity_news,
"openai": get_stock_news_openai, # Fallback to OpenAI web search
},
"get_global_news": {
"openai": get_global_news_openai,
"local": get_reddit_global_news

View File

@ -13,6 +13,8 @@ DEFAULT_CONFIG = {
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"backend_url": "https://api.openai.com/v1",
# Asset class (equity | commodity)
"asset_class": "equity",
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
@ -24,6 +26,7 @@ DEFAULT_CONFIG = {
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
"fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
"commodity_data": "alpha_vantage", # Options: alpha_vantage
},
# Tool-level configuration (takes precedence over category-level)
"tool_vendors": {

View File

@ -35,6 +35,7 @@ from tradingagents.agents.utils.agent_utils import (
get_insider_transactions,
get_global_news
)
from tradingagents.agents.utils.commodity_data_tools import get_commodity_data
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
@ -122,15 +123,22 @@ class TradingAgentsGraph:
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
"""Create tool nodes for different data sources using abstract methods."""
is_commodity = self.config.get("asset_class", "equity").lower() == "commodity"
market_tools = []
if is_commodity:
# Only expose commodity tool to prevent LLM from selecting stock data
market_tools = [
get_commodity_data,
]
else:
market_tools = [
get_stock_data,
get_indicators,
]
return {
"market": ToolNode(
[
# Core stock data tools
get_stock_data,
# Technical indicators
get_indicators,
]
),
"market": ToolNode(market_tools),
"social": ToolNode(
[
# News tools for social media analysis
@ -166,6 +174,8 @@ class TradingAgentsGraph:
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date
)
# Pass asset class into state for downstream branching
init_agent_state["asset_class"] = self.config.get("asset_class", "equity")
args = self.propagator.get_graph_args()
if self.debug: