From 85377d27e2f8bbab0cd49a6373034143544ecc18 Mon Sep 17 00:00:00 2001 From: Garrick Date: Tue, 24 Mar 2026 16:51:14 -0700 Subject: [PATCH] feat: add valuation analyst --- tests/test_valuation_analyst.py | 205 ++++++++++++++++++ tradingagents/agents/__init__.py | 4 + .../agents/analysts/valuation_analyst.py | 120 ++++++++++ tradingagents/agents/utils/agent_utils.py | 4 + tradingagents/agents/utils/valuation_tools.py | 18 ++ tradingagents/graph/setup.py | 11 + tradingagents/graph/trading_graph.py | 8 + 7 files changed, 370 insertions(+) create mode 100644 tests/test_valuation_analyst.py create mode 100644 tradingagents/agents/analysts/valuation_analyst.py create mode 100644 tradingagents/agents/utils/valuation_tools.py diff --git a/tests/test_valuation_analyst.py b/tests/test_valuation_analyst.py new file mode 100644 index 00000000..fa1cc53d --- /dev/null +++ b/tests/test_valuation_analyst.py @@ -0,0 +1,205 @@ +import json + +from langchain_core.messages import AIMessage +from langchain_core.runnables import RunnableLambda + +from tradingagents.graph.setup import GraphSetup +from tradingagents.graph.trading_graph import TradingAgentsGraph + + +class DummyStateGraph: + def __init__(self, _state_type): + self.nodes = {} + self.conditional_edges = {} + + def add_node(self, name, node): + self.nodes[name] = node + + def add_edge(self, *_args, **_kwargs): + return None + + def add_conditional_edges(self, source, condition, destinations): + self.conditional_edges[source] = { + "condition": condition, + "destinations": destinations, + } + + def compile(self): + return { + "nodes": self.nodes, + "conditional_edges": self.conditional_edges, + } + + +class DummyToolNode: + def __init__(self, tools): + self.tools = tools + + +def test_valuation_tools_route_to_vendor(monkeypatch): + import tradingagents.dataflows.interface as interface + from tradingagents.agents.utils.valuation_tools import get_valuation_inputs + + calls = [] + + def fake_route_to_vendor(method, *args, **kwargs): + calls.append((method, args, kwargs)) + return f"{method}-result" + + monkeypatch.setattr(interface, "route_to_vendor", fake_route_to_vendor) + + assert ( + get_valuation_inputs.invoke({"ticker": "NVDA", "curr_date": "2026-03-24"}) + == "get_fundamentals-result" + ) + assert calls == [ + ("get_fundamentals", (), {"ticker": "NVDA", "curr_date": "2026-03-24"}) + ] + + +def test_graph_setup_wires_valuation_analyst_and_tools(monkeypatch): + recorded_llms = {} + + monkeypatch.setattr("tradingagents.graph.setup.StateGraph", DummyStateGraph) + monkeypatch.setattr("tradingagents.graph.setup.create_msg_delete", lambda: "delete") + + def make_factory(node_name): + def factory(llm, *_args): + recorded_llms[node_name] = llm + return node_name + + return factory + + monkeypatch.setattr( + "tradingagents.graph.setup.create_market_analyst", + make_factory("Market Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_valuation_analyst", + make_factory("Valuation Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_social_media_analyst", + make_factory("Social Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_news_analyst", + make_factory("News Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_fundamentals_analyst", + make_factory("Fundamentals Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_factor_rule_analyst", + make_factory("Factor Rules Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_macro_analyst", + make_factory("Macro Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_bull_researcher", + make_factory("Bull Researcher"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_bear_researcher", + make_factory("Bear Researcher"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_research_manager", + make_factory("Research Manager"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_trader", + make_factory("Trader"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_aggressive_debator", + make_factory("Aggressive Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_neutral_debator", + make_factory("Neutral Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_conservative_debator", + make_factory("Conservative Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_portfolio_manager", + make_factory("Portfolio Manager"), + ) + + class PartialConditionalLogic: + def should_continue_market(self, _state): + return "Msg Clear Market" + + def should_continue_debate(self, _state): + return "Research Manager" + + def should_continue_risk_analysis(self, _state): + return "Portfolio Manager" + + setup = GraphSetup( + quick_thinking_llm="quick-llm", + deep_thinking_llm="deep-llm", + tool_nodes={"market": "market-tools", "valuation": "valuation-tools"}, + bull_memory=object(), + bear_memory=object(), + trader_memory=object(), + invest_judge_memory=object(), + portfolio_manager_memory=object(), + conditional_logic=PartialConditionalLogic(), + role_llms={"valuation": "valuation-llm"}, + ) + + graph = setup.setup_graph(selected_analysts=["market", "valuation"]) + + assert recorded_llms["Valuation Analyst"] == "valuation-llm" + assert graph["nodes"]["Valuation Analyst"] == "Valuation Analyst" + assert graph["nodes"]["tools_valuation"] == "valuation-tools" + assert graph["conditional_edges"]["Valuation Analyst"]["destinations"] == [ + "tools_valuation", + "Msg Clear Valuation", + ] + + +def test_trading_graph_creates_valuation_tool_node(monkeypatch): + monkeypatch.setattr("tradingagents.graph.trading_graph.ToolNode", DummyToolNode) + + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + tool_nodes = TradingAgentsGraph._create_tool_nodes(graph) + + assert [tool.name for tool in tool_nodes["valuation"].tools] == [ + "get_valuation_inputs" + ] + + +def test_valuation_analyst_returns_structured_valuation_data(): + from tradingagents.agents.analysts.valuation_analyst import create_valuation_analyst + + response = { + "fair_value_range": {"low": 120.5, "high": 145.0}, + "expected_return_pct": 18.2, + "primary_method": "discounted cash flow", + "thesis": "Free cash flow implies upside versus the current price.", + } + + class FakeLLM: + def bind_tools(self, _tools): + return RunnableLambda( + lambda _inputs: AIMessage(content=json.dumps(response), tool_calls=[]) + ) + + node = create_valuation_analyst(FakeLLM()) + result = node( + { + "trade_date": "2026-03-24", + "company_of_interest": "NVDA", + "messages": [("human", "Value NVDA")], + } + ) + + assert result["valuation_data"] == response + assert list(result) == ["messages", "valuation_data"] diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index 3cd95d7c..263d1ce7 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -3,10 +3,12 @@ from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState from .utils.memory import FinancialSituationMemory from .analysts.fundamentals_analyst import create_fundamentals_analyst +from .analysts.factor_rule_analyst import create_factor_rule_analyst from .analysts.macro_analyst import create_macro_analyst from .analysts.market_analyst import create_market_analyst from .analysts.news_analyst import create_news_analyst from .analysts.social_media_analyst import create_social_media_analyst +from .analysts.valuation_analyst import create_valuation_analyst from .researchers.bear_researcher import create_bear_researcher from .researchers.bull_researcher import create_bull_researcher @@ -29,11 +31,13 @@ __all__ = [ "create_bear_researcher", "create_bull_researcher", "create_research_manager", + "create_factor_rule_analyst", "create_fundamentals_analyst", "create_macro_analyst", "create_market_analyst", "create_neutral_debator", "create_news_analyst", + "create_valuation_analyst", "create_aggressive_debator", "create_portfolio_manager", "create_conservative_debator", diff --git a/tradingagents/agents/analysts/valuation_analyst.py b/tradingagents/agents/analysts/valuation_analyst.py new file mode 100644 index 00000000..f0e859aa --- /dev/null +++ b/tradingagents/agents/analysts/valuation_analyst.py @@ -0,0 +1,120 @@ +import json +import re + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + +from tradingagents.agents.utils.agent_states import make_default_valuation_data +from tradingagents.agents.utils.agent_utils import ( + build_instrument_context, + get_valuation_inputs, +) + + +def _content_to_text(content) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + return "".join( + part.get("text", "") if isinstance(part, dict) else str(part) + for part in content + ) + return str(content) + + +def _coerce_optional_float(value): + if value in (None, ""): + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _parse_json_payload(raw_text: str): + text = raw_text.strip() + if not text: + return {} + + candidates = [text] + fenced_blocks = re.findall(r"```(?:json)?\s*(.*?)```", text, flags=re.DOTALL) + candidates.extend(block.strip() for block in fenced_blocks if block.strip()) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + return parsed + return {} + + +def _parse_valuation_data(content): + payload = _parse_json_payload(_content_to_text(content)) + valuation_data = make_default_valuation_data() + + fair_value_range = payload.get("fair_value_range") + if isinstance(fair_value_range, dict): + valuation_data["fair_value_range"] = { + "low": _coerce_optional_float(fair_value_range.get("low")), + "high": _coerce_optional_float(fair_value_range.get("high")), + } + + valuation_data["expected_return_pct"] = _coerce_optional_float( + payload.get("expected_return_pct") + ) + valuation_data["primary_method"] = str(payload.get("primary_method") or "") + valuation_data["thesis"] = str(payload.get("thesis") or "") + + return valuation_data + + +def create_valuation_analyst(llm): + def valuation_analyst_node(state): + current_date = state["trade_date"] + instrument_context = build_instrument_context(state["company_of_interest"]) + tools = [get_valuation_inputs] + + system_message = ( + "You are a valuation analyst responsible for translating company " + "fundamentals into a concise underwriting view. Use `get_valuation_inputs` " + "to gather valuation context, estimate a fair value range, choose the " + "primary valuation method, and explain the core thesis. Respond with valid " + "JSON only using this exact schema: " + '{"fair_value_range":{"low":null,"high":null},"expected_return_pct":null,' + '"primary_method":"","thesis":""}. ' + "Use null for unknown numeric values and do not add any extra keys." + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." + " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." + " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," + " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." + " You have access to the following tools: {tool_names}.\n{system_message}" + "For your reference, the current date is {current_date}. {instrument_context}", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join(tool.name for tool in tools)) + prompt = prompt.partial(current_date=current_date) + prompt = prompt.partial(instrument_context=instrument_context) + + chain = prompt | llm.bind_tools(tools) + result = chain.invoke(state["messages"]) + + payload = {"messages": [result]} + if len(result.tool_calls) == 0: + payload["valuation_data"] = _parse_valuation_data(result.content) + + return payload + + return valuation_analyst_node diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 39028138..b1df53ed 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -26,6 +26,9 @@ from tradingagents.agents.utils.macro_data_tools import ( get_fed_calendar, get_yield_curve, ) +from tradingagents.agents.utils.valuation_tools import ( + get_valuation_inputs, +) __all__ = [ @@ -43,6 +46,7 @@ __all__ = [ "get_insider_transactions", "get_news", "get_stock_data", + "get_valuation_inputs", "get_yield_curve", ] diff --git a/tradingagents/agents/utils/valuation_tools.py b/tradingagents/agents/utils/valuation_tools.py new file mode 100644 index 00000000..f9ae6aed --- /dev/null +++ b/tradingagents/agents/utils/valuation_tools.py @@ -0,0 +1,18 @@ +from typing import Annotated + +from langchain_core.tools import tool + + +@tool +def get_valuation_inputs( + ticker: Annotated[str, "ticker symbol"], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +) -> str: + """Retrieve valuation-oriented fundamental inputs for a company.""" + from tradingagents.dataflows.interface import route_to_vendor + + return route_to_vendor( + "get_fundamentals", + ticker=ticker, + curr_date=curr_date, + ) diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 334a5b5d..32447ad9 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -47,6 +47,9 @@ class GraphSetup: self.factor_rules_analyst_llm = self._get_role_llm( "factor_rules", self.quick_thinking_llm ) + self.valuation_analyst_llm = self._get_role_llm( + "valuation", self.quick_thinking_llm + ) self.macro_analyst_llm = self._get_role_llm("macro", self.quick_thinking_llm) self.bull_researcher_llm = self._get_role_llm( "bull_researcher", self.quick_thinking_llm @@ -104,6 +107,7 @@ class GraphSetup: - "news": News analyst - "fundamentals": Fundamentals analyst - "factor_rules": Factor rule analyst + - "valuation": Valuation analyst - "macro": Macro analyst """ if len(selected_analysts) == 0: @@ -148,6 +152,13 @@ class GraphSetup: ) delete_nodes["factor_rules"] = create_msg_delete() + if "valuation" in selected_analysts: + analyst_nodes["valuation"] = create_valuation_analyst( + self.valuation_analyst_llm + ) + delete_nodes["valuation"] = create_msg_delete() + tool_nodes["valuation"] = self.tool_nodes["valuation"] + if "macro" in selected_analysts: analyst_nodes["macro"] = create_macro_analyst(self.macro_analyst_llm) delete_nodes["macro"] = create_msg_delete() diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 294b25a5..05cbee23 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -33,6 +33,7 @@ from tradingagents.agents.utils.agent_utils import ( get_news, get_insider_transactions, get_global_news, + get_valuation_inputs, get_yield_curve, ) @@ -62,6 +63,7 @@ class TradingAgentsGraph: "news", "fundamentals", "factor_rules", + "valuation", "macro", "bull_researcher", "bear_researcher", @@ -304,6 +306,12 @@ class TradingAgentsGraph: get_income_statement, ] ), + "valuation": ToolNode( + [ + # Valuation analysis tools + get_valuation_inputs, + ] + ), "macro": ToolNode( [ # Macroeconomic analysis tools