From be2694367af5caff5b2509c6b89392ef67a8a6bf Mon Sep 17 00:00:00 2001 From: Garrick Date: Tue, 24 Mar 2026 16:58:09 -0700 Subject: [PATCH] fix: harden valuation analyst parsing --- tests/test_valuation_analyst.py | 114 ++++++++++++++++++ .../agents/analysts/valuation_analyst.py | 13 +- 2 files changed, 123 insertions(+), 4 deletions(-) diff --git a/tests/test_valuation_analyst.py b/tests/test_valuation_analyst.py index fa1cc53d..d3a7750c 100644 --- a/tests/test_valuation_analyst.py +++ b/tests/test_valuation_analyst.py @@ -2,7 +2,10 @@ import json from langchain_core.messages import AIMessage from langchain_core.runnables import RunnableLambda +from langgraph.graph import END, START, StateGraph +from langgraph.prebuilt import ToolNode +from tradingagents.agents.utils.agent_states import AgentState from tradingagents.graph.setup import GraphSetup from tradingagents.graph.trading_graph import TradingAgentsGraph @@ -203,3 +206,114 @@ def test_valuation_analyst_returns_structured_valuation_data(): assert result["valuation_data"] == response assert list(result) == ["messages", "valuation_data"] + + +def test_valuation_analyst_marks_parse_failure_without_changing_shape(): + from tradingagents.agents.analysts.valuation_analyst import create_valuation_analyst + + class FakeLLM: + def bind_tools(self, _tools): + return RunnableLambda( + lambda _inputs: AIMessage(content="not valid json", tool_calls=[]) + ) + + node = create_valuation_analyst(FakeLLM()) + result = node( + { + "trade_date": "2026-03-24", + "company_of_interest": "NVDA", + "messages": [("human", "Value NVDA")], + } + ) + + assert set(result["valuation_data"]) == { + "fair_value_range", + "expected_return_pct", + "primary_method", + "thesis", + } + assert result["valuation_data"]["fair_value_range"] == {"low": None, "high": None} + assert result["valuation_data"]["expected_return_pct"] is None + assert result["valuation_data"]["primary_method"] == "parse_error" + assert result["valuation_data"]["thesis"] == "not valid json" + + +def test_valuation_analyst_populates_structured_data_after_tool_loop(monkeypatch): + import tradingagents.dataflows.interface as interface + from tradingagents.agents.analysts.valuation_analyst import create_valuation_analyst + from tradingagents.agents.utils.valuation_tools import get_valuation_inputs + + llm_responses = iter( + [ + AIMessage( + content="", + tool_calls=[ + { + "name": "get_valuation_inputs", + "args": {"ticker": "NVDA", "curr_date": "2026-03-24"}, + "id": "call_1", + "type": "tool_call", + } + ], + ), + AIMessage( + content=json.dumps( + { + "fair_value_range": {"low": 120.5, "high": 145.0}, + "expected_return_pct": 18.2, + "primary_method": "discounted cash flow", + "thesis": "Free cash flow implies upside versus the current price.", + } + ), + tool_calls=[], + ), + ] + ) + + calls = [] + + def fake_route_to_vendor(method, *args, **kwargs): + calls.append((method, args, kwargs)) + return "valuation inputs" + + monkeypatch.setattr(interface, "route_to_vendor", fake_route_to_vendor) + + class FakeLLM: + def bind_tools(self, _tools): + return RunnableLambda(lambda _inputs: next(llm_responses)) + + node = create_valuation_analyst(FakeLLM()) + workflow = StateGraph(AgentState) + workflow.add_node("Valuation Analyst", node) + workflow.add_node("tools_valuation", ToolNode([get_valuation_inputs])) + workflow.add_node("Msg Clear Valuation", lambda _state: {}) + workflow.add_edge(START, "Valuation Analyst") + workflow.add_conditional_edges( + "Valuation Analyst", + lambda state: ( + "tools_valuation" + if getattr(state["messages"][-1], "tool_calls", None) + else "Msg Clear Valuation" + ), + ["tools_valuation", "Msg Clear Valuation"], + ) + workflow.add_edge("tools_valuation", "Valuation Analyst") + workflow.add_edge("Msg Clear Valuation", END) + + final_state = workflow.compile().invoke( + { + "trade_date": "2026-03-24", + "company_of_interest": "NVDA", + "messages": [("human", "Value NVDA")], + } + ) + + assert final_state["valuation_data"] == { + "fair_value_range": {"low": 120.5, "high": 145.0}, + "expected_return_pct": 18.2, + "primary_method": "discounted cash flow", + "thesis": "Free cash flow implies upside versus the current price.", + } + assert calls == [ + ("get_fundamentals", (), {"ticker": "NVDA", "curr_date": "2026-03-24"}) + ] diff --git a/tradingagents/agents/analysts/valuation_analyst.py b/tradingagents/agents/analysts/valuation_analyst.py index f0e859aa..d02ac8a3 100644 --- a/tradingagents/agents/analysts/valuation_analyst.py +++ b/tradingagents/agents/analysts/valuation_analyst.py @@ -33,7 +33,7 @@ def _coerce_optional_float(value): def _parse_json_payload(raw_text: str): text = raw_text.strip() if not text: - return {} + return {}, False candidates = [text] fenced_blocks = re.findall(r"```(?:json)?\s*(.*?)```", text, flags=re.DOTALL) @@ -45,12 +45,13 @@ def _parse_json_payload(raw_text: str): except json.JSONDecodeError: continue if isinstance(parsed, dict): - return parsed - return {} + return parsed, True + return {}, False def _parse_valuation_data(content): - payload = _parse_json_payload(_content_to_text(content)) + raw_text = _content_to_text(content).strip() + payload, parsed = _parse_json_payload(raw_text) valuation_data = make_default_valuation_data() fair_value_range = payload.get("fair_value_range") @@ -66,6 +67,10 @@ def _parse_valuation_data(content): valuation_data["primary_method"] = str(payload.get("primary_method") or "") valuation_data["thesis"] = str(payload.get("thesis") or "") + if not parsed: + valuation_data["primary_method"] = "parse_error" + valuation_data["thesis"] = raw_text or "[empty model response]" + return valuation_data