fix: harden valuation analyst parsing
This commit is contained in:
parent
85377d27e2
commit
be2694367a
|
|
@ -2,7 +2,10 @@ import json
|
|||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
from tradingagents.graph.setup import GraphSetup
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
|
|
@ -203,3 +206,114 @@ def test_valuation_analyst_returns_structured_valuation_data():
|
|||
|
||||
assert result["valuation_data"] == response
|
||||
assert list(result) == ["messages", "valuation_data"]
|
||||
|
||||
|
||||
def test_valuation_analyst_marks_parse_failure_without_changing_shape():
|
||||
from tradingagents.agents.analysts.valuation_analyst import create_valuation_analyst
|
||||
|
||||
class FakeLLM:
|
||||
def bind_tools(self, _tools):
|
||||
return RunnableLambda(
|
||||
lambda _inputs: AIMessage(content="not valid json", tool_calls=[])
|
||||
)
|
||||
|
||||
node = create_valuation_analyst(FakeLLM())
|
||||
result = node(
|
||||
{
|
||||
"trade_date": "2026-03-24",
|
||||
"company_of_interest": "NVDA",
|
||||
"messages": [("human", "Value NVDA")],
|
||||
}
|
||||
)
|
||||
|
||||
assert set(result["valuation_data"]) == {
|
||||
"fair_value_range",
|
||||
"expected_return_pct",
|
||||
"primary_method",
|
||||
"thesis",
|
||||
}
|
||||
assert result["valuation_data"]["fair_value_range"] == {"low": None, "high": None}
|
||||
assert result["valuation_data"]["expected_return_pct"] is None
|
||||
assert result["valuation_data"]["primary_method"] == "parse_error"
|
||||
assert result["valuation_data"]["thesis"] == "not valid json"
|
||||
|
||||
|
||||
def test_valuation_analyst_populates_structured_data_after_tool_loop(monkeypatch):
|
||||
import tradingagents.dataflows.interface as interface
|
||||
from tradingagents.agents.analysts.valuation_analyst import create_valuation_analyst
|
||||
from tradingagents.agents.utils.valuation_tools import get_valuation_inputs
|
||||
|
||||
llm_responses = iter(
|
||||
[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "get_valuation_inputs",
|
||||
"args": {"ticker": "NVDA", "curr_date": "2026-03-24"},
|
||||
"id": "call_1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
),
|
||||
AIMessage(
|
||||
content=json.dumps(
|
||||
{
|
||||
"fair_value_range": {"low": 120.5, "high": 145.0},
|
||||
"expected_return_pct": 18.2,
|
||||
"primary_method": "discounted cash flow",
|
||||
"thesis": "Free cash flow implies upside versus the current price.",
|
||||
}
|
||||
),
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_route_to_vendor(method, *args, **kwargs):
|
||||
calls.append((method, args, kwargs))
|
||||
return "valuation inputs"
|
||||
|
||||
monkeypatch.setattr(interface, "route_to_vendor", fake_route_to_vendor)
|
||||
|
||||
class FakeLLM:
|
||||
def bind_tools(self, _tools):
|
||||
return RunnableLambda(lambda _inputs: next(llm_responses))
|
||||
|
||||
node = create_valuation_analyst(FakeLLM())
|
||||
workflow = StateGraph(AgentState)
|
||||
workflow.add_node("Valuation Analyst", node)
|
||||
workflow.add_node("tools_valuation", ToolNode([get_valuation_inputs]))
|
||||
workflow.add_node("Msg Clear Valuation", lambda _state: {})
|
||||
workflow.add_edge(START, "Valuation Analyst")
|
||||
workflow.add_conditional_edges(
|
||||
"Valuation Analyst",
|
||||
lambda state: (
|
||||
"tools_valuation"
|
||||
if getattr(state["messages"][-1], "tool_calls", None)
|
||||
else "Msg Clear Valuation"
|
||||
),
|
||||
["tools_valuation", "Msg Clear Valuation"],
|
||||
)
|
||||
workflow.add_edge("tools_valuation", "Valuation Analyst")
|
||||
workflow.add_edge("Msg Clear Valuation", END)
|
||||
|
||||
final_state = workflow.compile().invoke(
|
||||
{
|
||||
"trade_date": "2026-03-24",
|
||||
"company_of_interest": "NVDA",
|
||||
"messages": [("human", "Value NVDA")],
|
||||
}
|
||||
)
|
||||
|
||||
assert final_state["valuation_data"] == {
|
||||
"fair_value_range": {"low": 120.5, "high": 145.0},
|
||||
"expected_return_pct": 18.2,
|
||||
"primary_method": "discounted cash flow",
|
||||
"thesis": "Free cash flow implies upside versus the current price.",
|
||||
}
|
||||
assert calls == [
|
||||
("get_fundamentals", (), {"ticker": "NVDA", "curr_date": "2026-03-24"})
|
||||
]
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def _coerce_optional_float(value):
|
|||
def _parse_json_payload(raw_text: str):
|
||||
text = raw_text.strip()
|
||||
if not text:
|
||||
return {}
|
||||
return {}, False
|
||||
|
||||
candidates = [text]
|
||||
fenced_blocks = re.findall(r"```(?:json)?\s*(.*?)```", text, flags=re.DOTALL)
|
||||
|
|
@ -45,12 +45,13 @@ def _parse_json_payload(raw_text: str):
|
|||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return {}
|
||||
return parsed, True
|
||||
return {}, False
|
||||
|
||||
|
||||
def _parse_valuation_data(content):
|
||||
payload = _parse_json_payload(_content_to_text(content))
|
||||
raw_text = _content_to_text(content).strip()
|
||||
payload, parsed = _parse_json_payload(raw_text)
|
||||
valuation_data = make_default_valuation_data()
|
||||
|
||||
fair_value_range = payload.get("fair_value_range")
|
||||
|
|
@ -66,6 +67,10 @@ def _parse_valuation_data(content):
|
|||
valuation_data["primary_method"] = str(payload.get("primary_method") or "")
|
||||
valuation_data["thesis"] = str(payload.get("thesis") or "")
|
||||
|
||||
if not parsed:
|
||||
valuation_data["primary_method"] = "parse_error"
|
||||
valuation_data["thesis"] = raw_text or "[empty model response]"
|
||||
|
||||
return valuation_data
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue