fix: harden valuation analyst parsing
This commit is contained in:
parent
85377d27e2
commit
be2694367a
|
|
@ -2,7 +2,10 @@ import json
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.runnables import RunnableLambda
|
from langchain_core.runnables import RunnableLambda
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.agent_states import AgentState
|
||||||
from tradingagents.graph.setup import GraphSetup
|
from tradingagents.graph.setup import GraphSetup
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
|
|
@ -203,3 +206,114 @@ def test_valuation_analyst_returns_structured_valuation_data():
|
||||||
|
|
||||||
assert result["valuation_data"] == response
|
assert result["valuation_data"] == response
|
||||||
assert list(result) == ["messages", "valuation_data"]
|
assert list(result) == ["messages", "valuation_data"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_valuation_analyst_marks_parse_failure_without_changing_shape():
|
||||||
|
from tradingagents.agents.analysts.valuation_analyst import create_valuation_analyst
|
||||||
|
|
||||||
|
class FakeLLM:
|
||||||
|
def bind_tools(self, _tools):
|
||||||
|
return RunnableLambda(
|
||||||
|
lambda _inputs: AIMessage(content="not valid json", tool_calls=[])
|
||||||
|
)
|
||||||
|
|
||||||
|
node = create_valuation_analyst(FakeLLM())
|
||||||
|
result = node(
|
||||||
|
{
|
||||||
|
"trade_date": "2026-03-24",
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"messages": [("human", "Value NVDA")],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert set(result["valuation_data"]) == {
|
||||||
|
"fair_value_range",
|
||||||
|
"expected_return_pct",
|
||||||
|
"primary_method",
|
||||||
|
"thesis",
|
||||||
|
}
|
||||||
|
assert result["valuation_data"]["fair_value_range"] == {"low": None, "high": None}
|
||||||
|
assert result["valuation_data"]["expected_return_pct"] is None
|
||||||
|
assert result["valuation_data"]["primary_method"] == "parse_error"
|
||||||
|
assert result["valuation_data"]["thesis"] == "not valid json"
|
||||||
|
|
||||||
|
|
||||||
|
def test_valuation_analyst_populates_structured_data_after_tool_loop(monkeypatch):
|
||||||
|
import tradingagents.dataflows.interface as interface
|
||||||
|
from tradingagents.agents.analysts.valuation_analyst import create_valuation_analyst
|
||||||
|
from tradingagents.agents.utils.valuation_tools import get_valuation_inputs
|
||||||
|
|
||||||
|
llm_responses = iter(
|
||||||
|
[
|
||||||
|
AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "get_valuation_inputs",
|
||||||
|
"args": {"ticker": "NVDA", "curr_date": "2026-03-24"},
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessage(
|
||||||
|
content=json.dumps(
|
||||||
|
{
|
||||||
|
"fair_value_range": {"low": 120.5, "high": 145.0},
|
||||||
|
"expected_return_pct": 18.2,
|
||||||
|
"primary_method": "discounted cash flow",
|
||||||
|
"thesis": "Free cash flow implies upside versus the current price.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
tool_calls=[],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_route_to_vendor(method, *args, **kwargs):
|
||||||
|
calls.append((method, args, kwargs))
|
||||||
|
return "valuation inputs"
|
||||||
|
|
||||||
|
monkeypatch.setattr(interface, "route_to_vendor", fake_route_to_vendor)
|
||||||
|
|
||||||
|
class FakeLLM:
|
||||||
|
def bind_tools(self, _tools):
|
||||||
|
return RunnableLambda(lambda _inputs: next(llm_responses))
|
||||||
|
|
||||||
|
node = create_valuation_analyst(FakeLLM())
|
||||||
|
workflow = StateGraph(AgentState)
|
||||||
|
workflow.add_node("Valuation Analyst", node)
|
||||||
|
workflow.add_node("tools_valuation", ToolNode([get_valuation_inputs]))
|
||||||
|
workflow.add_node("Msg Clear Valuation", lambda _state: {})
|
||||||
|
workflow.add_edge(START, "Valuation Analyst")
|
||||||
|
workflow.add_conditional_edges(
|
||||||
|
"Valuation Analyst",
|
||||||
|
lambda state: (
|
||||||
|
"tools_valuation"
|
||||||
|
if getattr(state["messages"][-1], "tool_calls", None)
|
||||||
|
else "Msg Clear Valuation"
|
||||||
|
),
|
||||||
|
["tools_valuation", "Msg Clear Valuation"],
|
||||||
|
)
|
||||||
|
workflow.add_edge("tools_valuation", "Valuation Analyst")
|
||||||
|
workflow.add_edge("Msg Clear Valuation", END)
|
||||||
|
|
||||||
|
final_state = workflow.compile().invoke(
|
||||||
|
{
|
||||||
|
"trade_date": "2026-03-24",
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"messages": [("human", "Value NVDA")],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert final_state["valuation_data"] == {
|
||||||
|
"fair_value_range": {"low": 120.5, "high": 145.0},
|
||||||
|
"expected_return_pct": 18.2,
|
||||||
|
"primary_method": "discounted cash flow",
|
||||||
|
"thesis": "Free cash flow implies upside versus the current price.",
|
||||||
|
}
|
||||||
|
assert calls == [
|
||||||
|
("get_fundamentals", (), {"ticker": "NVDA", "curr_date": "2026-03-24"})
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ def _coerce_optional_float(value):
|
||||||
def _parse_json_payload(raw_text: str):
|
def _parse_json_payload(raw_text: str):
|
||||||
text = raw_text.strip()
|
text = raw_text.strip()
|
||||||
if not text:
|
if not text:
|
||||||
return {}
|
return {}, False
|
||||||
|
|
||||||
candidates = [text]
|
candidates = [text]
|
||||||
fenced_blocks = re.findall(r"```(?:json)?\s*(.*?)```", text, flags=re.DOTALL)
|
fenced_blocks = re.findall(r"```(?:json)?\s*(.*?)```", text, flags=re.DOTALL)
|
||||||
|
|
@ -45,12 +45,13 @@ def _parse_json_payload(raw_text: str):
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
continue
|
continue
|
||||||
if isinstance(parsed, dict):
|
if isinstance(parsed, dict):
|
||||||
return parsed
|
return parsed, True
|
||||||
return {}
|
return {}, False
|
||||||
|
|
||||||
|
|
||||||
def _parse_valuation_data(content):
|
def _parse_valuation_data(content):
|
||||||
payload = _parse_json_payload(_content_to_text(content))
|
raw_text = _content_to_text(content).strip()
|
||||||
|
payload, parsed = _parse_json_payload(raw_text)
|
||||||
valuation_data = make_default_valuation_data()
|
valuation_data = make_default_valuation_data()
|
||||||
|
|
||||||
fair_value_range = payload.get("fair_value_range")
|
fair_value_range = payload.get("fair_value_range")
|
||||||
|
|
@ -66,6 +67,10 @@ def _parse_valuation_data(content):
|
||||||
valuation_data["primary_method"] = str(payload.get("primary_method") or "")
|
valuation_data["primary_method"] = str(payload.get("primary_method") or "")
|
||||||
valuation_data["thesis"] = str(payload.get("thesis") or "")
|
valuation_data["thesis"] = str(payload.get("thesis") or "")
|
||||||
|
|
||||||
|
if not parsed:
|
||||||
|
valuation_data["primary_method"] = "parse_error"
|
||||||
|
valuation_data["thesis"] = raw_text or "[empty model response]"
|
||||||
|
|
||||||
return valuation_data
|
return valuation_data
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue