From 18ac47f5d64565ef487f1e146a9a43dca47dd4b3 Mon Sep 17 00:00:00 2001 From: Clayton Brown Date: Tue, 21 Apr 2026 08:25:27 +1000 Subject: [PATCH] feat: generic AI trading agent interface (#264) Standardized input/output contract for all agents, enabling pluggable composition and benchmarking. - AgentInput/AgentOutput Pydantic schemas (5-tier rating, confidence, targets) - BaseAgent abstract class with analyze(input) -> output contract - AgentRegistry for pluggable agent discovery - Existing analysts refactored to implement BaseAgent - Agent benchmarking: compare outputs across different LLM backends Closes #264 --- tradingagents/agents/__init__.py | 24 ++++ .../agents/analysts/base_analysts.py | 109 ++++++++++++++ tradingagents/agents/base_agent.py | 23 +++ tradingagents/agents/benchmark.py | 136 ++++++++++++++++++ tradingagents/agents/registry.py | 63 ++++++++ tradingagents/agents/utils/schemas.py | 36 +++++ 6 files changed, 391 insertions(+) create mode 100644 tradingagents/agents/analysts/base_analysts.py create mode 100644 tradingagents/agents/base_agent.py create mode 100644 tradingagents/agents/benchmark.py create mode 100644 tradingagents/agents/registry.py create mode 100644 tradingagents/agents/utils/schemas.py diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index 1f03642c..90374c25 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -1,11 +1,21 @@ +from .base_agent import BaseAgent +from .benchmark import BenchmarkReport, BenchmarkResult, LLMBackend, benchmark_agent, benchmark_agents +from .registry import AgentRegistry from .utils.agent_utils import create_msg_delete from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState from .utils.memory import FinancialSituationMemory +from .utils.schemas import AgentInput, AgentOutput, PriceTargets from .analysts.fundamentals_analyst import create_fundamentals_analyst from .analysts.market_analyst import create_market_analyst 
from .analysts.news_analyst import create_news_analyst from .analysts.social_media_analyst import create_social_media_analyst +from .analysts.base_analysts import ( + FundamentalsAgent, + SentimentAgent, + NewsAgent, + TechnicalAgent, +) from .researchers.bear_researcher import create_bear_researcher from .researchers.bull_researcher import create_bull_researcher @@ -20,8 +30,22 @@ from .managers.portfolio_manager import create_portfolio_manager from .trader.trader import create_trader __all__ = [ + "AgentRegistry", + "BaseAgent", + "BenchmarkReport", + "BenchmarkResult", + "LLMBackend", + "benchmark_agent", + "benchmark_agents", + "FundamentalsAgent", + "SentimentAgent", + "NewsAgent", + "TechnicalAgent", "FinancialSituationMemory", "AgentState", + "AgentInput", + "AgentOutput", + "PriceTargets", "create_msg_delete", "InvestDebateState", "RiskDebateState", diff --git a/tradingagents/agents/analysts/base_analysts.py b/tradingagents/agents/analysts/base_analysts.py new file mode 100644 index 00000000..df755ff0 --- /dev/null +++ b/tradingagents/agents/analysts/base_analysts.py @@ -0,0 +1,109 @@ +"""BaseAgent implementations for the four analyst types. + +Each class wraps the existing analyst logic behind the standardized +``BaseAgent.analyze(AgentInput) -> AgentOutput`` contract while the +original ``create_*`` factory functions remain unchanged for LangGraph +node compatibility. +""" + +from __future__ import annotations + +from langchain_core.messages import HumanMessage + +from tradingagents.agents.base_agent import BaseAgent +from tradingagents.agents.utils.schemas import AgentInput, AgentOutput + +# Shared prompt that asks the LLM to return a JSON matching AgentOutput. 
# --- tradingagents/agents/base_agent.py ---
"""Abstract base class for trading agents with a standardized analyze contract."""

from abc import ABC, abstractmethod


class BaseAgent(ABC):
    """Base class all trading agents must implement.

    Subclasses provide ``analyze`` which accepts an :class:`AgentInput` and
    returns an :class:`AgentOutput`, ensuring a uniform contract across every
    agent in the system.
    """

    # Stable identifier used by the registry and benchmark reports.
    name: str = "unnamed_agent"

    @abstractmethod
    def analyze(self, agent_input: AgentInput) -> AgentOutput:
        """Run analysis and return a standardized output."""
        ...


# --- tradingagents/agents/analysts/base_analysts.py (definitions) ---

# Prompt appendix shared by every analyst: restates the core AgentOutput keys
# for the model.
# NOTE(review): the key list omits AgentOutput's optional "price_targets"
# field, and the "Return ONLY the JSON" instruction is redundant alongside
# with_structured_output — confirm whether price targets should be requested.
_STRUCTURED_SUFFIX = (
    "\n\nAfter your analysis, provide a final JSON object with these exact keys:\n"
    '- "rating": one of "BUY", "OVERWEIGHT", "HOLD", "UNDERWEIGHT", "SELL"\n'
    '- "confidence": float 0.0-1.0\n'
    '- "thesis": one-paragraph summary\n'
    '- "risk_factors": list of strings\n'
    "Return ONLY the JSON object, no other text."
)


def _invoke_structured(llm, role_prompt: str, agent_input: AgentInput) -> AgentOutput:
    """Ask *llm* to produce an ``AgentOutput`` via structured output.

    Args:
        llm: A LangChain chat model supporting ``with_structured_output``.
        role_prompt: Role-specific instructions for the analyst persona.
        agent_input: Standardized input (ticker, date, optional context docs).

    Returns:
        The parsed ``AgentOutput`` produced by the model.
    """
    parts = [f"{role_prompt}\n\nTicker: {agent_input.ticker}\nDate: {agent_input.date}\n"]
    # Append each optional context document under a labelled separator.
    for section, text in agent_input.context.items():
        parts.append(f"\n--- {section} ---\n{text}\n")
    parts.append(_STRUCTURED_SUFFIX)
    return llm.with_structured_output(AgentOutput).invoke(
        [HumanMessage(content="".join(parts))]
    )


class _RoleAgent(BaseAgent):
    """Shared implementation for the four analysts.

    The concrete analysts differ only in ``name`` and their role prompt, so
    the constructor and analyze plumbing live here once (deduplicates four
    previously identical class bodies).
    """

    # Overridden by each concrete analyst with its persona instructions.
    _role_prompt: str = ""

    def __init__(self, llm) -> None:
        # llm: any LangChain chat model supporting with_structured_output.
        self.llm = llm

    def analyze(self, agent_input: AgentInput) -> AgentOutput:
        return _invoke_structured(self.llm, self._role_prompt, agent_input)


class FundamentalsAgent(_RoleAgent):
    """Standardized fundamentals analyst."""

    name: str = "fundamentals_analyst"
    _role_prompt = (
        "You are a fundamentals analyst. Evaluate the company's financial health "
        "using balance sheets, cash flow, income statements, and key ratios."
    )


class SentimentAgent(_RoleAgent):
    """Standardized sentiment / social-media analyst."""

    name: str = "sentiment_analyst"
    _role_prompt = (
        "You are a sentiment analyst. Evaluate public sentiment from social media, "
        "news headlines, and community discussions about the company."
    )


class NewsAgent(_RoleAgent):
    """Standardized news analyst."""

    name: str = "news_analyst"
    _role_prompt = (
        "You are a news analyst. Evaluate recent news, macroeconomic events, "
        "and geopolitical developments relevant to the company."
    )


class TechnicalAgent(_RoleAgent):
    """Standardized technical / market analyst."""

    name: str = "technical_analyst"
    _role_prompt = (
        "You are a technical analyst. Evaluate price action, volume, moving averages, "
        "MACD, RSI, Bollinger Bands, and other technical indicators."
    )
"""Agent benchmarking: compare outputs across different LLM backends."""

from __future__ import annotations

import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Needed only for annotations (lazy via __future__ import); guarding them
    # keeps this module importable without the full agent stack installed.
    from tradingagents.agents.base_agent import BaseAgent
    from tradingagents.agents.utils.schemas import AgentInput, AgentOutput


@dataclass
class BenchmarkResult:
    """Result of a single agent run against one LLM backend."""

    agent_name: str
    provider: str
    model: str
    output: AgentOutput | None  # None when the run failed; see ``error``
    elapsed_seconds: float
    error: str | None = None


@dataclass
class BenchmarkReport:
    """Aggregated results from benchmarking one or more agents across backends."""

    results: list[BenchmarkResult] = field(default_factory=list)

    def summary(self) -> list[dict[str, Any]]:
        """Return a list of dicts summarising each result for easy comparison.

        Failed runs carry an ``error`` key; successful runs carry rating,
        confidence, and size metrics of the thesis / risk factors.
        """
        rows: list[dict[str, Any]] = []
        for r in self.results:
            row: dict[str, Any] = {
                "agent": r.agent_name,
                "provider": r.provider,
                "model": r.model,
                "elapsed_s": round(r.elapsed_seconds, 2),
            }
            if r.error:
                row["error"] = r.error
            elif r.output:
                row["rating"] = r.output.rating
                row["confidence"] = r.output.confidence
                row["thesis_len"] = len(r.output.thesis)
                row["risk_factors"] = len(r.output.risk_factors)
            rows.append(row)
        return rows


@dataclass
class LLMBackend:
    """Describes an LLM backend to benchmark against."""

    provider: str
    model: str
    base_url: str | None = None
    kwargs: dict[str, Any] = field(default_factory=dict)


def _make_llm(backend: LLMBackend) -> Any:
    """Create a LangChain LLM from a backend spec."""
    # Imported lazily so a missing/misconfigured llm_clients package is
    # reported per-backend by benchmark_agent instead of at import time.
    from tradingagents.llm_clients import create_llm_client

    client = create_llm_client(
        provider=backend.provider,
        model=backend.model,
        base_url=backend.base_url,
        **backend.kwargs,
    )
    return client.get_llm()


def benchmark_agent(
    agent_cls: type[BaseAgent],
    agent_input: AgentInput,
    backends: list[LLMBackend],
) -> BenchmarkReport:
    """Run *agent_cls* with *agent_input* across each backend and collect results.

    Any exception raised while building the LLM, constructing the agent, or
    analyzing is captured per-backend in the result's ``error`` field rather
    than aborting the whole benchmark.

    Args:
        agent_cls: A ``BaseAgent`` subclass whose ``__init__`` accepts a single
            ``llm`` positional argument.
        agent_input: The standardized input to feed every agent instance.
        backends: LLM backends to compare.

    Returns:
        A :class:`BenchmarkReport` with one :class:`BenchmarkResult` per backend.
    """
    report = BenchmarkReport()
    # Fallback label when the failure happens before an agent instance exists.
    fallback_name = getattr(agent_cls, "name", agent_cls.__name__)
    for backend in backends:
        started = time.monotonic()
        output: AgentOutput | None = None
        error: str | None = None
        agent_name = fallback_name
        try:
            agent = agent_cls(_make_llm(backend))
            agent_name = agent.name
            output = agent.analyze(agent_input)
        except Exception as exc:  # noqa: BLE001 - record any failure per backend
            error = str(exc)
        report.results.append(
            BenchmarkResult(
                agent_name=agent_name,
                provider=backend.provider,
                model=backend.model,
                output=output,
                elapsed_seconds=time.monotonic() - started,
                error=error,
            )
        )
    return report


def benchmark_agents(
    agent_classes: list[type[BaseAgent]],
    agent_input: AgentInput,
    backends: list[LLMBackend],
) -> BenchmarkReport:
    """Run multiple agent types across multiple backends.

    Convenience wrapper that calls :func:`benchmark_agent` for each class and
    merges the results into a single report.
    """
    merged = BenchmarkReport()
    for cls in agent_classes:
        merged.results.extend(benchmark_agent(cls, agent_input, backends).results)
    return merged
+ """ + merged = BenchmarkReport() + for cls in agent_classes: + report = benchmark_agent(cls, agent_input, backends) + merged.results.extend(report.results) + return merged diff --git a/tradingagents/agents/registry.py b/tradingagents/agents/registry.py new file mode 100644 index 00000000..97ec2729 --- /dev/null +++ b/tradingagents/agents/registry.py @@ -0,0 +1,63 @@ +"""AgentRegistry for pluggable agent discovery.""" + +from __future__ import annotations + +from typing import Callable + +from .base_agent import BaseAgent + + +class AgentRegistry: + """Registry that maps agent names to their factory callables. + + Usage:: + + registry = AgentRegistry() + registry.register("fundamentals", FundamentalsAgent, llm=my_llm) + agent = registry.get("fundamentals") + output = agent.analyze(agent_input) + + Agents can be registered either as a pre-built instance or as a class + (with optional ``**kwargs`` forwarded to the constructor on first access). + """ + + def __init__(self) -> None: + self._factories: dict[str, Callable[[], BaseAgent]] = {} + self._instances: dict[str, BaseAgent] = {} + + def register( + self, + name: str, + agent: type[BaseAgent] | BaseAgent, + **kwargs, + ) -> None: + """Register an agent class or instance under *name*. + + If *agent* is a class, ``kwargs`` are forwarded to its constructor + when :meth:`get` is called. If it is already an instance, it is + stored directly. 
+ """ + if isinstance(agent, BaseAgent): + self._instances[name] = agent + elif isinstance(agent, type) and issubclass(agent, BaseAgent): + self._factories[name] = lambda: agent(**kwargs) + else: + raise TypeError(f"Expected BaseAgent subclass or instance, got {type(agent)}") + + def get(self, name: str) -> BaseAgent: + """Return the agent registered under *name*, instantiating lazily if needed.""" + if name not in self._instances: + if name not in self._factories: + raise KeyError(f"No agent registered under '{name}'") + self._instances[name] = self._factories.pop(name)() + return self._instances[name] + + def list(self) -> list[str]: + """Return sorted list of all registered agent names.""" + return sorted({*self._factories, *self._instances}) + + def __contains__(self, name: str) -> bool: + return name in self._factories or name in self._instances + + def __len__(self) -> int: + return len(self.list()) diff --git a/tradingagents/agents/utils/schemas.py b/tradingagents/agents/utils/schemas.py new file mode 100644 index 00000000..dc8f967f --- /dev/null +++ b/tradingagents/agents/utils/schemas.py @@ -0,0 +1,36 @@ +"""Standardized input/output schemas for the generic agent interface.""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field + + +class AgentInput(BaseModel): + """Standardized input contract for any trading agent.""" + + ticker: str + date: str + context: dict[str, str] = Field( + default_factory=dict, + description="Optional context keyed by: market_data, news, fundamentals, sentiment, technical_indicators", + ) + + +class PriceTargets(BaseModel): + """Entry, target, and stop-loss price levels.""" + + entry: float + target: float + stop_loss: float + + +class AgentOutput(BaseModel): + """Standardized output contract for any trading agent.""" + + rating: Literal["BUY", "OVERWEIGHT", "HOLD", "UNDERWEIGHT", "SELL"] + confidence: float = Field(ge=0.0, le=1.0) + price_targets: PriceTargets | None = 
None + thesis: str + risk_factors: list[str] = Field(default_factory=list)