feat: generic AI trading agent interface (#264)

Standardized input/output contract for all agents, enabling pluggable
composition and benchmarking.

- AgentInput/AgentOutput Pydantic schemas (5-tier rating, confidence, targets)
- BaseAgent abstract class with analyze(input) -> output contract
- AgentRegistry for pluggable agent discovery
- Existing analysts refactored to implement BaseAgent
- Agent benchmarking: compare outputs across different LLM backends

Closes #264
This commit is contained in:
Clayton Brown 2026-04-21 08:25:27 +10:00
parent fa4d01c23a
commit 18ac47f5d6
6 changed files with 391 additions and 0 deletions

View File

@ -1,11 +1,21 @@
from .base_agent import BaseAgent
from .benchmark import BenchmarkReport, BenchmarkResult, LLMBackend, benchmark_agent, benchmark_agents
from .registry import AgentRegistry
from .utils.agent_utils import create_msg_delete
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.memory import FinancialSituationMemory
from .utils.schemas import AgentInput, AgentOutput, PriceTargets
from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
from .analysts.base_analysts import (
FundamentalsAgent,
SentimentAgent,
NewsAgent,
TechnicalAgent,
)
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
@ -20,8 +30,22 @@ from .managers.portfolio_manager import create_portfolio_manager
from .trader.trader import create_trader
__all__ = [
"AgentRegistry",
"BaseAgent",
"BenchmarkReport",
"BenchmarkResult",
"LLMBackend",
"benchmark_agent",
"benchmark_agents",
"FundamentalsAgent",
"SentimentAgent",
"NewsAgent",
"TechnicalAgent",
"FinancialSituationMemory",
"AgentState",
"AgentInput",
"AgentOutput",
"PriceTargets",
"create_msg_delete",
"InvestDebateState",
"RiskDebateState",

View File

@ -0,0 +1,109 @@
"""BaseAgent implementations for the four analyst types.
Each class wraps the existing analyst logic behind the standardized
``BaseAgent.analyze(AgentInput) -> AgentOutput`` contract while the
original ``create_*`` factory functions remain unchanged for LangGraph
node compatibility.
"""
from __future__ import annotations
from langchain_core.messages import HumanMessage
from tradingagents.agents.base_agent import BaseAgent
from tradingagents.agents.utils.schemas import AgentInput, AgentOutput
# Shared prompt that asks the LLM to return a JSON matching AgentOutput.
_STRUCTURED_SUFFIX = (
"\n\nAfter your analysis, provide a final JSON object with these exact keys:\n"
'- "rating": one of "BUY", "OVERWEIGHT", "HOLD", "UNDERWEIGHT", "SELL"\n'
'- "confidence": float 0.0-1.0\n'
'- "thesis": one-paragraph summary\n'
'- "risk_factors": list of strings\n'
"Return ONLY the JSON object, no other text."
)
def _invoke_structured(llm, role_prompt: str, agent_input: AgentInput) -> AgentOutput:
"""Ask *llm* to produce an ``AgentOutput`` via structured output."""
full_prompt = (
f"{role_prompt}\n\n"
f"Ticker: {agent_input.ticker}\n"
f"Date: {agent_input.date}\n"
)
if agent_input.context:
for k, v in agent_input.context.items():
full_prompt += f"\n--- {k} ---\n{v}\n"
full_prompt += _STRUCTURED_SUFFIX
structured_llm = llm.with_structured_output(AgentOutput)
return structured_llm.invoke([HumanMessage(content=full_prompt)])
class FundamentalsAgent(BaseAgent):
"""Standardized fundamentals analyst."""
name: str = "fundamentals_analyst"
def __init__(self, llm) -> None:
self.llm = llm
def analyze(self, agent_input: AgentInput) -> AgentOutput:
return _invoke_structured(
self.llm,
"You are a fundamentals analyst. Evaluate the company's financial health "
"using balance sheets, cash flow, income statements, and key ratios.",
agent_input,
)
class SentimentAgent(BaseAgent):
"""Standardized sentiment / social-media analyst."""
name: str = "sentiment_analyst"
def __init__(self, llm) -> None:
self.llm = llm
def analyze(self, agent_input: AgentInput) -> AgentOutput:
return _invoke_structured(
self.llm,
"You are a sentiment analyst. Evaluate public sentiment from social media, "
"news headlines, and community discussions about the company.",
agent_input,
)
class NewsAgent(BaseAgent):
"""Standardized news analyst."""
name: str = "news_analyst"
def __init__(self, llm) -> None:
self.llm = llm
def analyze(self, agent_input: AgentInput) -> AgentOutput:
return _invoke_structured(
self.llm,
"You are a news analyst. Evaluate recent news, macroeconomic events, "
"and geopolitical developments relevant to the company.",
agent_input,
)
class TechnicalAgent(BaseAgent):
"""Standardized technical / market analyst."""
name: str = "technical_analyst"
def __init__(self, llm) -> None:
self.llm = llm
def analyze(self, agent_input: AgentInput) -> AgentOutput:
return _invoke_structured(
self.llm,
"You are a technical analyst. Evaluate price action, volume, moving averages, "
"MACD, RSI, Bollinger Bands, and other technical indicators.",
agent_input,
)

View File

@ -0,0 +1,23 @@
"""Abstract base class for trading agents with a standardized analyze contract."""
from __future__ import annotations
from abc import ABC, abstractmethod
from .utils.schemas import AgentInput, AgentOutput
class BaseAgent(ABC):
"""Base class all trading agents must implement.
Subclasses provide ``analyze`` which accepts an :class:`AgentInput` and
returns an :class:`AgentOutput`, ensuring a uniform contract across every
agent in the system.
"""
name: str = "unnamed_agent"
@abstractmethod
def analyze(self, agent_input: AgentInput) -> AgentOutput:
"""Run analysis and return a standardized output."""
...

View File

@ -0,0 +1,136 @@
"""Agent benchmarking: compare outputs across different LLM backends."""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any
from tradingagents.agents.base_agent import BaseAgent
from tradingagents.agents.utils.schemas import AgentInput, AgentOutput
from tradingagents.llm_clients import create_llm_client
@dataclass
class BenchmarkResult:
"""Result of a single agent run against one LLM backend."""
agent_name: str
provider: str
model: str
output: AgentOutput | None
elapsed_seconds: float
error: str | None = None
@dataclass
class BenchmarkReport:
"""Aggregated results from benchmarking one or more agents across backends."""
results: list[BenchmarkResult] = field(default_factory=list)
def summary(self) -> list[dict[str, Any]]:
"""Return a list of dicts summarising each result for easy comparison."""
rows: list[dict[str, Any]] = []
for r in self.results:
row: dict[str, Any] = {
"agent": r.agent_name,
"provider": r.provider,
"model": r.model,
"elapsed_s": round(r.elapsed_seconds, 2),
}
if r.error:
row["error"] = r.error
elif r.output:
row["rating"] = r.output.rating
row["confidence"] = r.output.confidence
row["thesis_len"] = len(r.output.thesis)
row["risk_factors"] = len(r.output.risk_factors)
rows.append(row)
return rows
@dataclass
class LLMBackend:
"""Describes an LLM backend to benchmark against."""
provider: str
model: str
base_url: str | None = None
kwargs: dict[str, Any] = field(default_factory=dict)
def _make_llm(backend: LLMBackend) -> Any:
"""Create a LangChain LLM from a backend spec."""
client = create_llm_client(
provider=backend.provider,
model=backend.model,
base_url=backend.base_url,
**backend.kwargs,
)
return client.get_llm()
def benchmark_agent(
agent_cls: type[BaseAgent],
agent_input: AgentInput,
backends: list[LLMBackend],
) -> BenchmarkReport:
"""Run *agent_cls* with *agent_input* across each backend and collect results.
Args:
agent_cls: A ``BaseAgent`` subclass whose ``__init__`` accepts a single
``llm`` positional argument.
agent_input: The standardized input to feed every agent instance.
backends: LLM backends to compare.
Returns:
A :class:`BenchmarkReport` with one :class:`BenchmarkResult` per backend.
"""
report = BenchmarkReport()
for backend in backends:
t0 = time.monotonic()
try:
llm = _make_llm(backend)
agent = agent_cls(llm)
output = agent.analyze(agent_input)
elapsed = time.monotonic() - t0
report.results.append(
BenchmarkResult(
agent_name=agent.name,
provider=backend.provider,
model=backend.model,
output=output,
elapsed_seconds=elapsed,
)
)
except Exception as exc: # noqa: BLE001
elapsed = time.monotonic() - t0
report.results.append(
BenchmarkResult(
agent_name=agent_cls.name if hasattr(agent_cls, "name") else agent_cls.__name__,
provider=backend.provider,
model=backend.model,
output=None,
elapsed_seconds=elapsed,
error=str(exc),
)
)
return report
def benchmark_agents(
agent_classes: list[type[BaseAgent]],
agent_input: AgentInput,
backends: list[LLMBackend],
) -> BenchmarkReport:
"""Run multiple agent types across multiple backends.
Convenience wrapper that calls :func:`benchmark_agent` for each class and
merges the results into a single report.
"""
merged = BenchmarkReport()
for cls in agent_classes:
report = benchmark_agent(cls, agent_input, backends)
merged.results.extend(report.results)
return merged

View File

@ -0,0 +1,63 @@
"""AgentRegistry for pluggable agent discovery."""
from __future__ import annotations
from typing import Callable
from .base_agent import BaseAgent
class AgentRegistry:
"""Registry that maps agent names to their factory callables.
Usage::
registry = AgentRegistry()
registry.register("fundamentals", FundamentalsAgent, llm=my_llm)
agent = registry.get("fundamentals")
output = agent.analyze(agent_input)
Agents can be registered either as a pre-built instance or as a class
(with optional ``**kwargs`` forwarded to the constructor on first access).
"""
def __init__(self) -> None:
self._factories: dict[str, Callable[[], BaseAgent]] = {}
self._instances: dict[str, BaseAgent] = {}
def register(
self,
name: str,
agent: type[BaseAgent] | BaseAgent,
**kwargs,
) -> None:
"""Register an agent class or instance under *name*.
If *agent* is a class, ``kwargs`` are forwarded to its constructor
when :meth:`get` is called. If it is already an instance, it is
stored directly.
"""
if isinstance(agent, BaseAgent):
self._instances[name] = agent
elif isinstance(agent, type) and issubclass(agent, BaseAgent):
self._factories[name] = lambda: agent(**kwargs)
else:
raise TypeError(f"Expected BaseAgent subclass or instance, got {type(agent)}")
def get(self, name: str) -> BaseAgent:
"""Return the agent registered under *name*, instantiating lazily if needed."""
if name not in self._instances:
if name not in self._factories:
raise KeyError(f"No agent registered under '{name}'")
self._instances[name] = self._factories.pop(name)()
return self._instances[name]
def list(self) -> list[str]:
"""Return sorted list of all registered agent names."""
return sorted({*self._factories, *self._instances})
def __contains__(self, name: str) -> bool:
return name in self._factories or name in self._instances
def __len__(self) -> int:
return len(self.list())

View File

@ -0,0 +1,36 @@
"""Standardized input/output schemas for the generic agent interface."""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
class AgentInput(BaseModel):
"""Standardized input contract for any trading agent."""
ticker: str
date: str
context: dict[str, str] = Field(
default_factory=dict,
description="Optional context keyed by: market_data, news, fundamentals, sentiment, technical_indicators",
)
class PriceTargets(BaseModel):
"""Entry, target, and stop-loss price levels."""
entry: float
target: float
stop_loss: float
class AgentOutput(BaseModel):
"""Standardized output contract for any trading agent."""
rating: Literal["BUY", "OVERWEIGHT", "HOLD", "UNDERWEIGHT", "SELL"]
confidence: float = Field(ge=0.0, le=1.0)
price_targets: PriceTargets | None = None
thesis: str
risk_factors: list[str] = Field(default_factory=list)