feat: generic AI trading agent interface (#264)
Standardized input/output contract for all agents, enabling pluggable composition and benchmarking. - AgentInput/AgentOutput Pydantic schemas (5-tier rating, confidence, targets) - BaseAgent abstract class with analyze(input) -> output contract - AgentRegistry for pluggable agent discovery - Existing analysts refactored to implement BaseAgent - Agent benchmarking: compare outputs across different LLM backends Closes #264
This commit is contained in:
parent
fa4d01c23a
commit
18ac47f5d6
|
|
@ -1,11 +1,21 @@
|
|||
from .base_agent import BaseAgent
|
||||
from .benchmark import BenchmarkReport, BenchmarkResult, LLMBackend, benchmark_agent, benchmark_agents
|
||||
from .registry import AgentRegistry
|
||||
from .utils.agent_utils import create_msg_delete
|
||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||
from .utils.memory import FinancialSituationMemory
|
||||
from .utils.schemas import AgentInput, AgentOutput, PriceTargets
|
||||
|
||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||
from .analysts.market_analyst import create_market_analyst
|
||||
from .analysts.news_analyst import create_news_analyst
|
||||
from .analysts.social_media_analyst import create_social_media_analyst
|
||||
from .analysts.base_analysts import (
|
||||
FundamentalsAgent,
|
||||
SentimentAgent,
|
||||
NewsAgent,
|
||||
TechnicalAgent,
|
||||
)
|
||||
|
||||
from .researchers.bear_researcher import create_bear_researcher
|
||||
from .researchers.bull_researcher import create_bull_researcher
|
||||
|
|
@ -20,8 +30,22 @@ from .managers.portfolio_manager import create_portfolio_manager
|
|||
from .trader.trader import create_trader
|
||||
|
||||
__all__ = [
|
||||
"AgentRegistry",
|
||||
"BaseAgent",
|
||||
"BenchmarkReport",
|
||||
"BenchmarkResult",
|
||||
"LLMBackend",
|
||||
"benchmark_agent",
|
||||
"benchmark_agents",
|
||||
"FundamentalsAgent",
|
||||
"SentimentAgent",
|
||||
"NewsAgent",
|
||||
"TechnicalAgent",
|
||||
"FinancialSituationMemory",
|
||||
"AgentState",
|
||||
"AgentInput",
|
||||
"AgentOutput",
|
||||
"PriceTargets",
|
||||
"create_msg_delete",
|
||||
"InvestDebateState",
|
||||
"RiskDebateState",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,109 @@
|
|||
"""BaseAgent implementations for the four analyst types.
|
||||
|
||||
Each class wraps the existing analyst logic behind the standardized
|
||||
``BaseAgent.analyze(AgentInput) -> AgentOutput`` contract while the
|
||||
original ``create_*`` factory functions remain unchanged for LangGraph
|
||||
node compatibility.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from tradingagents.agents.base_agent import BaseAgent
|
||||
from tradingagents.agents.utils.schemas import AgentInput, AgentOutput
|
||||
|
||||
# Shared prompt that asks the LLM to return a JSON matching AgentOutput.
|
||||
_STRUCTURED_SUFFIX = (
|
||||
"\n\nAfter your analysis, provide a final JSON object with these exact keys:\n"
|
||||
'- "rating": one of "BUY", "OVERWEIGHT", "HOLD", "UNDERWEIGHT", "SELL"\n'
|
||||
'- "confidence": float 0.0-1.0\n'
|
||||
'- "thesis": one-paragraph summary\n'
|
||||
'- "risk_factors": list of strings\n'
|
||||
"Return ONLY the JSON object, no other text."
|
||||
)
|
||||
|
||||
|
||||
def _invoke_structured(llm, role_prompt: str, agent_input: AgentInput) -> AgentOutput:
|
||||
"""Ask *llm* to produce an ``AgentOutput`` via structured output."""
|
||||
full_prompt = (
|
||||
f"{role_prompt}\n\n"
|
||||
f"Ticker: {agent_input.ticker}\n"
|
||||
f"Date: {agent_input.date}\n"
|
||||
)
|
||||
if agent_input.context:
|
||||
for k, v in agent_input.context.items():
|
||||
full_prompt += f"\n--- {k} ---\n{v}\n"
|
||||
|
||||
full_prompt += _STRUCTURED_SUFFIX
|
||||
|
||||
structured_llm = llm.with_structured_output(AgentOutput)
|
||||
return structured_llm.invoke([HumanMessage(content=full_prompt)])
|
||||
|
||||
|
||||
class FundamentalsAgent(BaseAgent):
|
||||
"""Standardized fundamentals analyst."""
|
||||
|
||||
name: str = "fundamentals_analyst"
|
||||
|
||||
def __init__(self, llm) -> None:
|
||||
self.llm = llm
|
||||
|
||||
def analyze(self, agent_input: AgentInput) -> AgentOutput:
|
||||
return _invoke_structured(
|
||||
self.llm,
|
||||
"You are a fundamentals analyst. Evaluate the company's financial health "
|
||||
"using balance sheets, cash flow, income statements, and key ratios.",
|
||||
agent_input,
|
||||
)
|
||||
|
||||
|
||||
class SentimentAgent(BaseAgent):
|
||||
"""Standardized sentiment / social-media analyst."""
|
||||
|
||||
name: str = "sentiment_analyst"
|
||||
|
||||
def __init__(self, llm) -> None:
|
||||
self.llm = llm
|
||||
|
||||
def analyze(self, agent_input: AgentInput) -> AgentOutput:
|
||||
return _invoke_structured(
|
||||
self.llm,
|
||||
"You are a sentiment analyst. Evaluate public sentiment from social media, "
|
||||
"news headlines, and community discussions about the company.",
|
||||
agent_input,
|
||||
)
|
||||
|
||||
|
||||
class NewsAgent(BaseAgent):
|
||||
"""Standardized news analyst."""
|
||||
|
||||
name: str = "news_analyst"
|
||||
|
||||
def __init__(self, llm) -> None:
|
||||
self.llm = llm
|
||||
|
||||
def analyze(self, agent_input: AgentInput) -> AgentOutput:
|
||||
return _invoke_structured(
|
||||
self.llm,
|
||||
"You are a news analyst. Evaluate recent news, macroeconomic events, "
|
||||
"and geopolitical developments relevant to the company.",
|
||||
agent_input,
|
||||
)
|
||||
|
||||
|
||||
class TechnicalAgent(BaseAgent):
|
||||
"""Standardized technical / market analyst."""
|
||||
|
||||
name: str = "technical_analyst"
|
||||
|
||||
def __init__(self, llm) -> None:
|
||||
self.llm = llm
|
||||
|
||||
def analyze(self, agent_input: AgentInput) -> AgentOutput:
|
||||
return _invoke_structured(
|
||||
self.llm,
|
||||
"You are a technical analyst. Evaluate price action, volume, moving averages, "
|
||||
"MACD, RSI, Bollinger Bands, and other technical indicators.",
|
||||
agent_input,
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Abstract base class for trading agents with a standardized analyze contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from .utils.schemas import AgentInput, AgentOutput
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""Base class all trading agents must implement.
|
||||
|
||||
Subclasses provide ``analyze`` which accepts an :class:`AgentInput` and
|
||||
returns an :class:`AgentOutput`, ensuring a uniform contract across every
|
||||
agent in the system.
|
||||
"""
|
||||
|
||||
name: str = "unnamed_agent"
|
||||
|
||||
@abstractmethod
|
||||
def analyze(self, agent_input: AgentInput) -> AgentOutput:
|
||||
"""Run analysis and return a standardized output."""
|
||||
...
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
"""Agent benchmarking: compare outputs across different LLM backends."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.agents.base_agent import BaseAgent
|
||||
from tradingagents.agents.utils.schemas import AgentInput, AgentOutput
|
||||
from tradingagents.llm_clients import create_llm_client
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
"""Result of a single agent run against one LLM backend."""
|
||||
|
||||
agent_name: str
|
||||
provider: str
|
||||
model: str
|
||||
output: AgentOutput | None
|
||||
elapsed_seconds: float
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkReport:
|
||||
"""Aggregated results from benchmarking one or more agents across backends."""
|
||||
|
||||
results: list[BenchmarkResult] = field(default_factory=list)
|
||||
|
||||
def summary(self) -> list[dict[str, Any]]:
|
||||
"""Return a list of dicts summarising each result for easy comparison."""
|
||||
rows: list[dict[str, Any]] = []
|
||||
for r in self.results:
|
||||
row: dict[str, Any] = {
|
||||
"agent": r.agent_name,
|
||||
"provider": r.provider,
|
||||
"model": r.model,
|
||||
"elapsed_s": round(r.elapsed_seconds, 2),
|
||||
}
|
||||
if r.error:
|
||||
row["error"] = r.error
|
||||
elif r.output:
|
||||
row["rating"] = r.output.rating
|
||||
row["confidence"] = r.output.confidence
|
||||
row["thesis_len"] = len(r.output.thesis)
|
||||
row["risk_factors"] = len(r.output.risk_factors)
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMBackend:
|
||||
"""Describes an LLM backend to benchmark against."""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
base_url: str | None = None
|
||||
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _make_llm(backend: LLMBackend) -> Any:
|
||||
"""Create a LangChain LLM from a backend spec."""
|
||||
client = create_llm_client(
|
||||
provider=backend.provider,
|
||||
model=backend.model,
|
||||
base_url=backend.base_url,
|
||||
**backend.kwargs,
|
||||
)
|
||||
return client.get_llm()
|
||||
|
||||
|
||||
def benchmark_agent(
|
||||
agent_cls: type[BaseAgent],
|
||||
agent_input: AgentInput,
|
||||
backends: list[LLMBackend],
|
||||
) -> BenchmarkReport:
|
||||
"""Run *agent_cls* with *agent_input* across each backend and collect results.
|
||||
|
||||
Args:
|
||||
agent_cls: A ``BaseAgent`` subclass whose ``__init__`` accepts a single
|
||||
``llm`` positional argument.
|
||||
agent_input: The standardized input to feed every agent instance.
|
||||
backends: LLM backends to compare.
|
||||
|
||||
Returns:
|
||||
A :class:`BenchmarkReport` with one :class:`BenchmarkResult` per backend.
|
||||
"""
|
||||
report = BenchmarkReport()
|
||||
for backend in backends:
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
llm = _make_llm(backend)
|
||||
agent = agent_cls(llm)
|
||||
output = agent.analyze(agent_input)
|
||||
elapsed = time.monotonic() - t0
|
||||
report.results.append(
|
||||
BenchmarkResult(
|
||||
agent_name=agent.name,
|
||||
provider=backend.provider,
|
||||
model=backend.model,
|
||||
output=output,
|
||||
elapsed_seconds=elapsed,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
elapsed = time.monotonic() - t0
|
||||
report.results.append(
|
||||
BenchmarkResult(
|
||||
agent_name=agent_cls.name if hasattr(agent_cls, "name") else agent_cls.__name__,
|
||||
provider=backend.provider,
|
||||
model=backend.model,
|
||||
output=None,
|
||||
elapsed_seconds=elapsed,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
return report
|
||||
|
||||
|
||||
def benchmark_agents(
|
||||
agent_classes: list[type[BaseAgent]],
|
||||
agent_input: AgentInput,
|
||||
backends: list[LLMBackend],
|
||||
) -> BenchmarkReport:
|
||||
"""Run multiple agent types across multiple backends.
|
||||
|
||||
Convenience wrapper that calls :func:`benchmark_agent` for each class and
|
||||
merges the results into a single report.
|
||||
"""
|
||||
merged = BenchmarkReport()
|
||||
for cls in agent_classes:
|
||||
report = benchmark_agent(cls, agent_input, backends)
|
||||
merged.results.extend(report.results)
|
||||
return merged
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
"""AgentRegistry for pluggable agent discovery."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
"""Registry that maps agent names to their factory callables.
|
||||
|
||||
Usage::
|
||||
|
||||
registry = AgentRegistry()
|
||||
registry.register("fundamentals", FundamentalsAgent, llm=my_llm)
|
||||
agent = registry.get("fundamentals")
|
||||
output = agent.analyze(agent_input)
|
||||
|
||||
Agents can be registered either as a pre-built instance or as a class
|
||||
(with optional ``**kwargs`` forwarded to the constructor on first access).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._factories: dict[str, Callable[[], BaseAgent]] = {}
|
||||
self._instances: dict[str, BaseAgent] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
agent: type[BaseAgent] | BaseAgent,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Register an agent class or instance under *name*.
|
||||
|
||||
If *agent* is a class, ``kwargs`` are forwarded to its constructor
|
||||
when :meth:`get` is called. If it is already an instance, it is
|
||||
stored directly.
|
||||
"""
|
||||
if isinstance(agent, BaseAgent):
|
||||
self._instances[name] = agent
|
||||
elif isinstance(agent, type) and issubclass(agent, BaseAgent):
|
||||
self._factories[name] = lambda: agent(**kwargs)
|
||||
else:
|
||||
raise TypeError(f"Expected BaseAgent subclass or instance, got {type(agent)}")
|
||||
|
||||
def get(self, name: str) -> BaseAgent:
|
||||
"""Return the agent registered under *name*, instantiating lazily if needed."""
|
||||
if name not in self._instances:
|
||||
if name not in self._factories:
|
||||
raise KeyError(f"No agent registered under '{name}'")
|
||||
self._instances[name] = self._factories.pop(name)()
|
||||
return self._instances[name]
|
||||
|
||||
def list(self) -> list[str]:
|
||||
"""Return sorted list of all registered agent names."""
|
||||
return sorted({*self._factories, *self._instances})
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self._factories or name in self._instances
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.list())
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
"""Standardized input/output schemas for the generic agent interface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AgentInput(BaseModel):
|
||||
"""Standardized input contract for any trading agent."""
|
||||
|
||||
ticker: str
|
||||
date: str
|
||||
context: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Optional context keyed by: market_data, news, fundamentals, sentiment, technical_indicators",
|
||||
)
|
||||
|
||||
|
||||
class PriceTargets(BaseModel):
|
||||
"""Entry, target, and stop-loss price levels."""
|
||||
|
||||
entry: float
|
||||
target: float
|
||||
stop_loss: float
|
||||
|
||||
|
||||
class AgentOutput(BaseModel):
|
||||
"""Standardized output contract for any trading agent."""
|
||||
|
||||
rating: Literal["BUY", "OVERWEIGHT", "HOLD", "UNDERWEIGHT", "SELL"]
|
||||
confidence: float = Field(ge=0.0, le=1.0)
|
||||
price_targets: PriceTargets | None = None
|
||||
thesis: str
|
||||
risk_factors: list[str] = Field(default_factory=list)
|
||||
Loading…
Reference in New Issue