# TradingAgents/tradingagents/agents/utils/agent_utils.py
# (listing metadata: 136 lines, 5.1 KiB, Python)

from typing import Any, Callable, Dict
from langchain_core.messages import HumanMessage, RemoveMessage
from tradingagents.agents.utils.llm_utils import (
create_and_invoke_chain,
parse_llm_response,
)
from tradingagents.agents.utils.prompt_templates import format_analyst_prompt
from tradingagents.tools.generator import ALL_TOOLS, get_agent_tools
# Re-export individual tools from the ALL_TOOLS registry as module-level names,
# so callers that import these names from this module keep working.
# NOTE(review): each lookup indexes ALL_TOOLS directly, so a missing key fails
# at import time of this module (KeyError if ALL_TOOLS is a plain dict).
get_stock_data = ALL_TOOLS["get_stock_data"]
validate_ticker = ALL_TOOLS["validate_ticker"] # registry key is "validate_ticker" (older code used "validate_ticker_tool")
get_indicators = ALL_TOOLS["get_indicators"]
get_fundamentals = ALL_TOOLS["get_fundamentals"]
get_balance_sheet = ALL_TOOLS["get_balance_sheet"]
get_cashflow = ALL_TOOLS["get_cashflow"]
get_income_statement = ALL_TOOLS["get_income_statement"]
get_recommendation_trends = ALL_TOOLS["get_recommendation_trends"]
get_news = ALL_TOOLS["get_news"]
get_global_news = ALL_TOOLS["get_global_news"]
get_insider_sentiment = ALL_TOOLS["get_insider_sentiment"]
get_insider_transactions = ALL_TOOLS["get_insider_transactions"]
# Legacy alias: the same tool object under its old name, for callers that
# still import ``validate_ticker_tool``.
validate_ticker_tool = validate_ticker
def create_msg_delete():
    """Return a graph node that clears the conversation history.

    The returned callable takes a state mapping with a ``"messages"`` list and
    returns an update that removes every existing message (one
    ``RemoveMessage`` per message id) followed by a single
    ``HumanMessage("Continue")`` placeholder.
    """

    def delete_messages(state):
        """Remove all messages, then append a placeholder message.

        Per the original note, the placeholder exists for Anthropic
        compatibility (presumably so the next request is not sent with an
        empty message list — confirm).
        """
        updates = []
        for message in state["messages"]:
            updates.append(RemoveMessage(id=message.id))
        updates.append(HumanMessage(content="Continue"))
        return {"messages": updates}

    return delete_messages
def format_memory_context(memory: Any, state: Dict[str, Any], n_matches: int = 2) -> str:
    """Fetch past memories similar to the current situation and format them for a prompt.

    The "current situation" is the market, sentiment, news and fundamentals
    reports from ``state`` joined by blank lines. It is passed to
    ``memory.get_memories`` together with ``n_matches``. This helper is shared
    by several agent nodes that previously duplicated the logic.

    Args:
        memory: Memory store exposing ``get_memories(situation, n_matches=...)``.
            Each returned item must support ``item["recommendation"]``.
            A falsy value (e.g. ``None``) disables the lookup.
        state: Graph state. It must contain ``market_report``,
            ``sentiment_report``, ``news_report`` and ``fundamentals_report``
            whenever ``memory`` is truthy.
        n_matches: Value forwarded as ``n_matches`` to ``get_memories``.

    Returns:
        A "### Past Lessons Applied" markdown section, or ``""`` when there is
        no memory store or ``get_memories`` returns nothing.
    """
    # Check the store first: callers without memory then need no report keys
    # in state, and the query string is not built for nothing.
    if not memory:
        return ""

    curr_situation = "\n\n".join(
        (
            state["market_report"],
            state["sentiment_report"],
            state["news_report"],
            state["fundamentals_report"],
        )
    )
    past_memories = memory.get_memories(curr_situation, n_matches=n_matches)
    if not past_memories:
        return ""

    # Real newline characters (not escaped backslash-n) so the section renders
    # as separate lines in the prompt.
    recommendations = "".join(rec["recommendation"] + "\n\n" for rec in past_memories)
    return (
        "### Past Lessons Applied\n**Reflections from Similar Situations:**\n"
        + recommendations
        + "\n\n**How I'm Using These Lessons:**\n"
        + "- [Specific adjustment based on past mistake/success]\n"
        + "- [Impact on current conviction level]\n"
    )
def update_risk_debate_state(
    debate_state: Dict[str, Any], argument: str, role: str
) -> Dict[str, Any]:
    """Build the updated risk debate state after a debator speaks.

    The argument is appended (newline-separated) to the shared ``history``
    and to the speaker's own ``<role>_history``. It is also stored as
    ``current_<role>_response``. The other speakers' histories and responses
    are copied through, defaulting to ``""``. ``latest_speaker`` is set to
    ``role`` and ``count`` is incremented. The input dict is not mutated.

    Args:
        debate_state: Current risk_debate_state dict; must contain ``count``.
        argument: The formatted argument string (e.g. "Safe Analyst: ...").
        role: One of "Safe", "Risky", "Neutral" (matched case-insensitively).

    Returns:
        A new dict with the history, per-speaker history/response,
        ``latest_speaker`` and ``count`` keys.

    Raises:
        ValueError: If ``role`` is not one of the three debate roles.
            Previously an unknown role silently produced stray
            ``<role>_history`` keys.
        KeyError: If ``debate_state`` has no ``count``.
    """
    speakers = ("risky", "safe", "neutral")
    role_key = role.lower()
    if role_key not in speakers:
        raise ValueError(
            f"Unknown risk debate role {role!r}; expected one of 'Safe', 'Risky', 'Neutral'"
        )

    new_state: Dict[str, Any] = {
        "history": debate_state.get("history", "") + "\n" + argument,
        "latest_speaker": role,
        "count": debate_state["count"] + 1,
    }
    for speaker in speakers:
        history_key = f"{speaker}_history"
        response_key = f"current_{speaker}_response"
        if speaker == role_key:
            # The speaker's own history grows; their current response is the new argument.
            new_state[history_key] = debate_state.get(history_key, "") + "\n" + argument
            new_state[response_key] = argument
        else:
            new_state[history_key] = debate_state.get(history_key, "")
            new_state[response_key] = debate_state.get(response_key, "")
    return new_state
def create_analyst_node(
    llm: Any,
    tool_group: str,
    output_key: str,
    prompt_builder: Callable[[str, str], str],
) -> Callable:
    """Create a graph node function for one analyst.

    On each call the node:
      1. reads ``company_of_interest`` and ``trade_date`` from the state;
      2. fetches the tools for ``tool_group``;
      3. builds the analyst prompt;
      4. invokes the LLM chain on ``state["messages"]``.

    It returns the LLM result as a new message. Under ``output_key`` it stores
    the parsed report when the result has no tool calls, and ``""`` otherwise.

    Args:
        llm: The LLM to use.
        tool_group: Tool group name for ``get_agent_tools`` (e.g. "fundamentals").
        output_key: State key for the report (e.g. "fundamentals_report").
        prompt_builder: ``(ticker, current_date) -> system_message`` callable.
    """

    def analyst_node(state: Dict[str, Any]) -> Dict[str, Any]:
        ticker, current_date = state["company_of_interest"], state["trade_date"]
        agent_tools = get_agent_tools(tool_group)
        prompt = format_analyst_prompt(
            prompt_builder(ticker, current_date),
            current_date,
            ticker,
            ", ".join(t.name for t in agent_tools),
        )
        response = create_and_invoke_chain(llm, agent_tools, prompt, state["messages"])
        # Only a response without tool calls is parsed into the report;
        # otherwise the report is left empty.
        report = parse_llm_response(response.content) if len(response.tool_calls) == 0 else ""
        return {"messages": [response], output_key: report}

    return analyst_node