feat: generic AI trading agent interface (#264)
Standardized input/output contract for all agents, enabling pluggable composition and benchmarking.

- AgentInput/AgentOutput Pydantic schemas (5-tier rating, confidence, price targets)
- BaseAgent abstract class with an analyze(input) -> output contract
- AgentRegistry for pluggable agent discovery
- Existing analysts wrapped in BaseAgent implementations (the original create_* factories are unchanged)
- Agent benchmarking: compare outputs across different LLM backends

Closes #264
This commit is contained in:
parent
fa4d01c23a
commit
18ac47f5d6
|
|
@ -1,11 +1,21 @@
|
||||||
|
from .base_agent import BaseAgent
|
||||||
|
from .benchmark import BenchmarkReport, BenchmarkResult, LLMBackend, benchmark_agent, benchmark_agents
|
||||||
|
from .registry import AgentRegistry
|
||||||
from .utils.agent_utils import create_msg_delete
|
from .utils.agent_utils import create_msg_delete
|
||||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||||
from .utils.memory import FinancialSituationMemory
|
from .utils.memory import FinancialSituationMemory
|
||||||
|
from .utils.schemas import AgentInput, AgentOutput, PriceTargets
|
||||||
|
|
||||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||||
from .analysts.market_analyst import create_market_analyst
|
from .analysts.market_analyst import create_market_analyst
|
||||||
from .analysts.news_analyst import create_news_analyst
|
from .analysts.news_analyst import create_news_analyst
|
||||||
from .analysts.social_media_analyst import create_social_media_analyst
|
from .analysts.social_media_analyst import create_social_media_analyst
|
||||||
|
from .analysts.base_analysts import (
|
||||||
|
FundamentalsAgent,
|
||||||
|
SentimentAgent,
|
||||||
|
NewsAgent,
|
||||||
|
TechnicalAgent,
|
||||||
|
)
|
||||||
|
|
||||||
from .researchers.bear_researcher import create_bear_researcher
|
from .researchers.bear_researcher import create_bear_researcher
|
||||||
from .researchers.bull_researcher import create_bull_researcher
|
from .researchers.bull_researcher import create_bull_researcher
|
||||||
|
|
@ -20,8 +30,22 @@ from .managers.portfolio_manager import create_portfolio_manager
|
||||||
from .trader.trader import create_trader
|
from .trader.trader import create_trader
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AgentRegistry",
|
||||||
|
"BaseAgent",
|
||||||
|
"BenchmarkReport",
|
||||||
|
"BenchmarkResult",
|
||||||
|
"LLMBackend",
|
||||||
|
"benchmark_agent",
|
||||||
|
"benchmark_agents",
|
||||||
|
"FundamentalsAgent",
|
||||||
|
"SentimentAgent",
|
||||||
|
"NewsAgent",
|
||||||
|
"TechnicalAgent",
|
||||||
"FinancialSituationMemory",
|
"FinancialSituationMemory",
|
||||||
"AgentState",
|
"AgentState",
|
||||||
|
"AgentInput",
|
||||||
|
"AgentOutput",
|
||||||
|
"PriceTargets",
|
||||||
"create_msg_delete",
|
"create_msg_delete",
|
||||||
"InvestDebateState",
|
"InvestDebateState",
|
||||||
"RiskDebateState",
|
"RiskDebateState",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""BaseAgent implementations for the four analyst types.
|
||||||
|
|
||||||
|
Each class wraps the existing analyst logic behind the standardized
|
||||||
|
``BaseAgent.analyze(AgentInput) -> AgentOutput`` contract while the
|
||||||
|
original ``create_*`` factory functions remain unchanged for LangGraph
|
||||||
|
node compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from tradingagents.agents.base_agent import BaseAgent
|
||||||
|
from tradingagents.agents.utils.schemas import AgentInput, AgentOutput
|
||||||
|
|
||||||
|
# Shared prompt suffix appended to every analyst prompt by ``_invoke_structured``.
# It asks the LLM for a JSON object with the "rating", "confidence", "thesis"
# and "risk_factors" keys of ``AgentOutput``; the optional ``price_targets``
# field of that schema is not requested here.
_STRUCTURED_SUFFIX = (
    "\n\nAfter your analysis, provide a final JSON object with these exact keys:\n"
    '- "rating": one of "BUY", "OVERWEIGHT", "HOLD", "UNDERWEIGHT", "SELL"\n'
    '- "confidence": float 0.0-1.0\n'
    '- "thesis": one-paragraph summary\n'
    '- "risk_factors": list of strings\n'
    "Return ONLY the JSON object, no other text."
)
|
||||||
|
|
||||||
|
|
||||||
|
def _invoke_structured(llm, role_prompt: str, agent_input: AgentInput) -> AgentOutput:
    """Build the analyst prompt and run it through *llm* with structured output.

    The prompt consists of *role_prompt*, the ticker and date taken from
    *agent_input*, one ``--- key ---`` section per entry of
    ``agent_input.context`` (skipped when the context is empty), and the
    shared ``_STRUCTURED_SUFFIX``.

    Args:
        llm: Model object exposing ``with_structured_output``.
        role_prompt: Role description placed at the top of the prompt.
        agent_input: Ticker, date and optional context for the analysis.

    Returns:
        The result of invoking the structured-output wrapper that was
        configured with the ``AgentOutput`` schema.
    """
    pieces = [
        f"{role_prompt}\n\n",
        f"Ticker: {agent_input.ticker}\n",
        f"Date: {agent_input.date}\n",
    ]
    if agent_input.context:
        pieces.extend(
            f"\n--- {key} ---\n{value}\n"
            for key, value in agent_input.context.items()
        )
    pieces.append(_STRUCTURED_SUFFIX)

    runner = llm.with_structured_output(AgentOutput)
    # The whole prompt is sent as a single human message.
    return runner.invoke([HumanMessage(content="".join(pieces))])
|
||||||
|
|
||||||
|
|
||||||
|
class FundamentalsAgent(BaseAgent):
    """Fundamentals analyst exposed through the ``BaseAgent`` contract."""

    name: str = "fundamentals_analyst"

    def __init__(self, llm) -> None:
        """Store the *llm* used for every ``analyze`` call."""
        self.llm = llm

    def analyze(self, agent_input: AgentInput) -> AgentOutput:
        """Run the fundamentals role prompt for *agent_input* via ``_invoke_structured``."""
        role_prompt = (
            "You are a fundamentals analyst. Evaluate the company's financial health "
            "using balance sheets, cash flow, income statements, and key ratios."
        )
        return _invoke_structured(self.llm, role_prompt, agent_input)
|
||||||
|
|
||||||
|
|
||||||
|
class SentimentAgent(BaseAgent):
    """Sentiment / social-media analyst exposed through the ``BaseAgent`` contract."""

    name: str = "sentiment_analyst"

    def __init__(self, llm) -> None:
        """Store the *llm* used for every ``analyze`` call."""
        self.llm = llm

    def analyze(self, agent_input: AgentInput) -> AgentOutput:
        """Run the sentiment role prompt for *agent_input* via ``_invoke_structured``."""
        role_prompt = (
            "You are a sentiment analyst. Evaluate public sentiment from social media, "
            "news headlines, and community discussions about the company."
        )
        return _invoke_structured(self.llm, role_prompt, agent_input)
|
||||||
|
|
||||||
|
|
||||||
|
class NewsAgent(BaseAgent):
    """News analyst exposed through the ``BaseAgent`` contract."""

    name: str = "news_analyst"

    def __init__(self, llm) -> None:
        """Store the *llm* used for every ``analyze`` call."""
        self.llm = llm

    def analyze(self, agent_input: AgentInput) -> AgentOutput:
        """Run the news role prompt for *agent_input* via ``_invoke_structured``."""
        role_prompt = (
            "You are a news analyst. Evaluate recent news, macroeconomic events, "
            "and geopolitical developments relevant to the company."
        )
        return _invoke_structured(self.llm, role_prompt, agent_input)
|
||||||
|
|
||||||
|
|
||||||
|
class TechnicalAgent(BaseAgent):
    """Technical / market analyst exposed through the ``BaseAgent`` contract."""

    name: str = "technical_analyst"

    def __init__(self, llm) -> None:
        """Store the *llm* used for every ``analyze`` call."""
        self.llm = llm

    def analyze(self, agent_input: AgentInput) -> AgentOutput:
        """Run the technical role prompt for *agent_input* via ``_invoke_structured``."""
        role_prompt = (
            "You are a technical analyst. Evaluate price action, volume, moving averages, "
            "MACD, RSI, Bollinger Bands, and other technical indicators."
        )
        return _invoke_structured(self.llm, role_prompt, agent_input)
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
"""Abstract base class for trading agents with a standardized analyze contract."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from .utils.schemas import AgentInput, AgentOutput
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgent(ABC):
    """Uniform interface implemented by every trading agent.

    A concrete agent overrides :meth:`analyze`, which takes an
    :class:`AgentInput` and produces an :class:`AgentOutput`. Because
    ``analyze`` is abstract, a subclass that does not override it cannot be
    instantiated.
    """

    # Identifier of the agent; the default marks agents that did not set one.
    name: str = "unnamed_agent"

    @abstractmethod
    def analyze(self, agent_input: AgentInput) -> AgentOutput:
        """Analyze *agent_input* and return the standardized result."""
|
||||||
|
|
@ -0,0 +1,136 @@
|
||||||
|
"""Agent benchmarking: compare outputs across different LLM backends."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from tradingagents.agents.base_agent import BaseAgent
|
||||||
|
from tradingagents.agents.utils.schemas import AgentInput, AgentOutput
|
||||||
|
from tradingagents.llm_clients import create_llm_client
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BenchmarkResult:
    """Result of a single agent run against one LLM backend.

    A successful run carries the agent's ``output`` and leaves ``error`` as
    ``None``; a failed run has ``output`` set to ``None`` and a string
    describing the exception in ``error``.
    """

    agent_name: str  # name of the agent that was run
    provider: str  # provider of the backend the run used
    model: str  # model of the backend the run used
    output: AgentOutput | None  # agent result, or None when the run failed
    elapsed_seconds: float  # duration of the run in seconds
    error: str | None = None  # failure description; None when the run succeeded
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BenchmarkReport:
    """Aggregated results from benchmarking one or more agents across backends."""

    # One entry per (agent, backend) run, in the order the runs were appended.
    results: list[BenchmarkResult] = field(default_factory=list)

    def summary(self) -> list[dict[str, Any]]:
        """Return one flat dict per result for easy side-by-side comparison.

        Every row has ``agent``, ``provider``, ``model`` and ``elapsed_s``
        (seconds rounded to two decimals). A failed run additionally has
        ``error``; a successful run has ``rating``, ``confidence``,
        ``thesis_len`` (character count of the thesis) and ``risk_factors``
        (number of risk factors).

        Returns:
            A list of rows in the same order as ``results``.
        """
        rows: list[dict[str, Any]] = []
        for r in self.results:
            row: dict[str, Any] = {
                "agent": r.agent_name,
                "provider": r.provider,
                "model": r.model,
                "elapsed_s": round(r.elapsed_seconds, 2),
            }
            # Compare against None rather than relying on truthiness: an
            # exception without a message stringifies to "", and such a
            # failure must still be reported as an error row.
            if r.error is not None:
                row["error"] = r.error
            elif r.output is not None:
                row["rating"] = r.output.rating
                row["confidence"] = r.output.confidence
                row["thesis_len"] = len(r.output.thesis)
                row["risk_factors"] = len(r.output.risk_factors)
            rows.append(row)
        return rows
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LLMBackend:
    """Describes an LLM backend to benchmark against.

    The fields are forwarded to ``create_llm_client`` by ``_make_llm``.
    """

    provider: str  # passed as ``provider=``
    model: str  # passed as ``model=``
    base_url: str | None = None  # passed as ``base_url=``; None when not set
    # Extra keyword arguments expanded into the ``create_llm_client`` call.
    kwargs: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_llm(backend: LLMBackend) -> Any:
    """Build the LLM object described by *backend*.

    The backend's provider, model and base URL, plus any extra ``kwargs``,
    are handed to ``create_llm_client``; the LLM is then obtained from the
    resulting client with ``get_llm()``.
    """
    return create_llm_client(
        provider=backend.provider,
        model=backend.model,
        base_url=backend.base_url,
        **backend.kwargs,
    ).get_llm()
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_agent(
    agent_cls: type[BaseAgent],
    agent_input: AgentInput,
    backends: list[LLMBackend],
) -> BenchmarkReport:
    """Run *agent_cls* with *agent_input* across each backend and collect results.

    A failure on one backend (while building the LLM, constructing the agent
    or analyzing) is recorded in that backend's result and does not stop the
    remaining backends from running.

    Args:
        agent_cls: A ``BaseAgent`` subclass whose ``__init__`` accepts a single
            ``llm`` positional argument.
        agent_input: The standardized input to feed every agent instance.
        backends: LLM backends to compare.

    Returns:
        A :class:`BenchmarkReport` with one :class:`BenchmarkResult` per backend.
    """
    report = BenchmarkReport()
    # Name used for failed runs, where no agent instance may exist yet.
    fallback_name = getattr(agent_cls, "name", agent_cls.__name__)

    for backend in backends:
        agent_name = fallback_name
        output = None
        error = None
        # The timer also covers LLM client creation and agent construction.
        t0 = time.monotonic()
        try:
            agent = agent_cls(_make_llm(backend))
            output = agent.analyze(agent_input)
            agent_name = agent.name
        except Exception as exc:  # noqa: BLE001 - one bad backend must not abort the benchmark
            output = None
            # Exceptions raised without a message stringify to ""; fall back
            # to the exception type so the failure stays visible in reports.
            error = str(exc) or type(exc).__name__
        elapsed = time.monotonic() - t0

        report.results.append(
            BenchmarkResult(
                agent_name=agent_name,
                provider=backend.provider,
                model=backend.model,
                output=output,
                elapsed_seconds=elapsed,
                error=error,
            )
        )
    return report
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_agents(
    agent_classes: list[type[BaseAgent]],
    agent_input: AgentInput,
    backends: list[LLMBackend],
) -> BenchmarkReport:
    """Benchmark several agent classes against the same backends.

    Each class is passed to :func:`benchmark_agent` in turn, and the results
    of those per-class reports are concatenated, in order, into one report.
    """
    combined = BenchmarkReport()
    for agent_cls in agent_classes:
        combined.results.extend(
            benchmark_agent(agent_cls, agent_input, backends).results
        )
    return combined
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""AgentRegistry for pluggable agent discovery."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from .base_agent import BaseAgent
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRegistry:
    """Registry that maps agent names to agent instances or lazy factories.

    Usage::

        registry = AgentRegistry()
        registry.register("fundamentals", FundamentalsAgent, llm=my_llm)
        agent = registry.get("fundamentals")
        output = agent.analyze(agent_input)

    Agents can be registered either as a pre-built instance or as a class
    (with optional ``**kwargs`` forwarded to the constructor on first access).
    Registering a name again replaces whatever was stored under it before.
    """

    def __init__(self) -> None:
        # Classes awaiting their first ``get``; entries move to ``_instances``
        # once they have been built successfully.
        self._factories: dict[str, Callable[[], BaseAgent]] = {}
        self._instances: dict[str, BaseAgent] = {}

    def register(
        self,
        name: str,
        agent: type[BaseAgent] | BaseAgent,
        **kwargs,
    ) -> None:
        """Register an agent class or instance under *name*.

        If *agent* is a class, ``kwargs`` are forwarded to its constructor
        when :meth:`get` is called. If it is already an instance, it is
        stored directly and ``kwargs`` are ignored.

        Raises:
            TypeError: If *agent* is neither a ``BaseAgent`` instance nor a
                ``BaseAgent`` subclass.
        """
        if isinstance(agent, BaseAgent):
            # The new instance supersedes any factory still pending for this name.
            self._factories.pop(name, None)
            self._instances[name] = agent
        elif isinstance(agent, type) and issubclass(agent, BaseAgent):
            # Drop a previously built instance; otherwise ``get`` would keep
            # returning it and the new registration would never take effect.
            self._instances.pop(name, None)
            self._factories[name] = lambda: agent(**kwargs)
        else:
            raise TypeError(f"Expected BaseAgent subclass or instance, got {type(agent)}")

    def get(self, name: str) -> BaseAgent:
        """Return the agent registered under *name*, instantiating lazily if needed.

        Raises:
            KeyError: If nothing is registered under *name*.
        """
        if name in self._instances:
            return self._instances[name]
        factory = self._factories.get(name)
        if factory is None:
            raise KeyError(f"No agent registered under '{name}'")
        # Build first and only then retire the factory, so that a constructor
        # failure leaves the registration intact for a later retry.
        instance = factory()
        self._instances[name] = instance
        del self._factories[name]
        return instance

    def list(self) -> list[str]:
        """Return sorted list of all registered agent names."""
        return sorted(self._factories.keys() | self._instances.keys())

    def __contains__(self, name: str) -> bool:
        """Return whether *name* is registered, built or not."""
        return name in self._factories or name in self._instances

    def __len__(self) -> int:
        """Return the number of distinct registered names."""
        return len(self._factories.keys() | self._instances.keys())
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""Standardized input/output schemas for the generic agent interface."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInput(BaseModel):
    """Standardized input contract for any trading agent."""

    ticker: str  # instrument to analyze
    date: str  # analysis date as a string; NOTE(review): expected format is not specified here
    # Free-form text blocks keyed by topic; defaults to an empty dict.
    context: dict[str, str] = Field(
        default_factory=dict,
        description="Optional context keyed by: market_data, news, fundamentals, sentiment, technical_indicators",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class PriceTargets(BaseModel):
    """Entry, target, and stop-loss price levels."""

    entry: float  # entry price level
    target: float  # target price level
    stop_loss: float  # stop-loss price level
|
||||||
|
|
||||||
|
|
||||||
|
class AgentOutput(BaseModel):
    """Standardized output contract for any trading agent."""

    # Five-tier recommendation; only these exact strings are accepted.
    rating: Literal["BUY", "OVERWEIGHT", "HOLD", "UNDERWEIGHT", "SELL"]
    # Confidence in the rating, constrained to the closed range 0.0-1.0.
    confidence: float = Field(ge=0.0, le=1.0)
    # Optional price levels; None when the agent supplies none.
    price_targets: PriceTargets | None = None
    # Text summary supporting the rating.
    thesis: str
    # Risk descriptions; defaults to an empty list.
    risk_factors: list[str] = Field(default_factory=list)
|
||||||
Loading…
Reference in New Issue