feat: enhance data persistence with DualReportStore for local and MongoDB storage; update report store creation logic
This commit is contained in:
parent
7c02e0c76c
commit
5f0a52f8e6
|
|
@ -3,6 +3,7 @@ from typing import Dict, Any, List, AsyncGenerator
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
from agent_os.backend.store import runs
|
from agent_os.backend.store import runs
|
||||||
from agent_os.backend.dependencies import get_current_user
|
from agent_os.backend.dependencies import get_current_user
|
||||||
from agent_os.backend.services.langgraph_engine import LangGraphEngine, NODE_TO_PHASE
|
from agent_os.backend.services.langgraph_engine import LangGraphEngine, NODE_TO_PHASE
|
||||||
|
|
@ -269,15 +270,62 @@ async def reset_portfolio_stage(
|
||||||
return {"deleted": deleted, "date": date, "portfolio_id": portfolio_id}
|
return {"deleted": deleted, "date": date, "portfolio_id": portfolio_id}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mongo_col():
|
||||||
|
"""Return the run_events collection if MongoDB is configured."""
|
||||||
|
uri = os.getenv("TRADINGAGENTS_MONGO_URI")
|
||||||
|
db_name = os.getenv("TRADINGAGENTS_MONGO_DB", "tradingagents")
|
||||||
|
if uri:
|
||||||
|
try:
|
||||||
|
from pymongo import MongoClient
|
||||||
|
client = MongoClient(uri)
|
||||||
|
return client[db_name]["run_events"]
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to connect to MongoDB for historical events")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
async def list_runs(user: dict = Depends(get_current_user)):
|
async def list_runs(user: dict = Depends(get_current_user)):
|
||||||
# Filter by user in production
|
# Filter by user in production
|
||||||
return list(runs.values())
|
all_runs = dict(runs)
|
||||||
|
|
||||||
|
# Supplement with historical metadata from MongoDB if available
|
||||||
|
col = _get_mongo_col()
|
||||||
|
if col is not None:
|
||||||
|
try:
|
||||||
|
# Fetch unique run_ids from the last 7 days (simplified)
|
||||||
|
# In a real app, we'd have a separate 'runs' collection for metadata.
|
||||||
|
# Here we use the events collection and group by run_id.
|
||||||
|
pipeline = [
|
||||||
|
{"$match": {"type": "log", "agent": "SYSTEM"}}, # Filter for start logs
|
||||||
|
{"$sort": {"ts": -1}},
|
||||||
|
{"$group": {
|
||||||
|
"_id": "$run_id",
|
||||||
|
"id": {"$first": "$run_id"},
|
||||||
|
"type": {"$first": "$type"},
|
||||||
|
"created_at": {"$first": "$ts"},
|
||||||
|
# Status is harder to get from events without a dedicated meta doc
|
||||||
|
}},
|
||||||
|
{"$limit": 50}
|
||||||
|
]
|
||||||
|
for doc in col.aggregate(pipeline):
|
||||||
|
rid = doc["id"]
|
||||||
|
if rid not in all_runs:
|
||||||
|
all_runs[rid] = {
|
||||||
|
"id": rid,
|
||||||
|
"type": doc.get("type", "unknown"),
|
||||||
|
"status": "historical",
|
||||||
|
"created_at": doc.get("created_at", 0),
|
||||||
|
"user_id": "anonymous",
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to fetch historical runs from MongoDB")
|
||||||
|
|
||||||
|
return list(all_runs.values())
|
||||||
|
|
||||||
@router.get("/{run_id}")
|
@router.get("/{run_id}")
|
||||||
async def get_run_status(run_id: str, user: dict = Depends(get_current_user)):
|
async def get_run_status(run_id: str, user: dict = Depends(get_current_user)):
|
||||||
if run_id not in runs:
|
if run_id in runs:
|
||||||
raise HTTPException(status_code=404, detail="Run not found")
|
|
||||||
run = runs[run_id]
|
run = runs[run_id]
|
||||||
# Lazy-load events from disk if they were not kept in memory
|
# Lazy-load events from disk if they were not kept in memory
|
||||||
if (
|
if (
|
||||||
|
|
@ -296,3 +344,25 @@ async def get_run_status(run_id: str, user: dict = Depends(get_current_user)):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to lazy-load events for run=%s", run_id)
|
logger.warning("Failed to lazy-load events for run=%s", run_id)
|
||||||
return run
|
return run
|
||||||
|
|
||||||
|
# Not in memory — try MongoDB
|
||||||
|
col = _get_mongo_col()
|
||||||
|
if col is not None:
|
||||||
|
try:
|
||||||
|
cursor = col.find({"run_id": run_id}).sort("ts", 1)
|
||||||
|
events = list(cursor)
|
||||||
|
if events:
|
||||||
|
# Remove MongoDB _id for JSON serialization
|
||||||
|
for e in events:
|
||||||
|
e.pop("_id", None)
|
||||||
|
return {
|
||||||
|
"id": run_id,
|
||||||
|
"status": "historical",
|
||||||
|
"events": events,
|
||||||
|
"type": events[0].get("type", "unknown") if events else "unknown",
|
||||||
|
"created_at": events[0].get("ts", 0) if events else 0,
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to fetch historical run %s from MongoDB", run_id)
|
||||||
|
|
||||||
|
raise HTTPException(status_code=404, detail="Run not found")
|
||||||
|
|
|
||||||
|
|
@ -175,9 +175,11 @@ class LangGraphEngine:
|
||||||
# Run logger lifecycle
|
# Run logger lifecycle
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _start_run_logger(self, run_id: str) -> RunLogger:
|
def _start_run_logger(self, run_id: str, flow_id: str | None = None) -> RunLogger:
|
||||||
"""Create and register a ``RunLogger`` for the given run."""
|
"""Create and register a ``RunLogger`` for the given run."""
|
||||||
rl = RunLogger()
|
uri = self.config.get("mongo_uri")
|
||||||
|
db = self.config.get("mongo_db") or "tradingagents"
|
||||||
|
rl = RunLogger(run_id=run_id, mongo_uri=uri, mongo_db=db, flow_id=flow_id)
|
||||||
self._run_loggers[run_id] = rl
|
self._run_loggers[run_id] = rl
|
||||||
set_run_logger(rl)
|
set_run_logger(rl)
|
||||||
return rl
|
return rl
|
||||||
|
|
@ -209,7 +211,8 @@ class LangGraphEngine:
|
||||||
short_rid = generate_run_id()
|
short_rid = generate_run_id()
|
||||||
store = create_report_store(run_id=short_rid)
|
store = create_report_store(run_id=short_rid)
|
||||||
|
|
||||||
rl = self._start_run_logger(run_id)
|
flow_id = params.get("flow_id")
|
||||||
|
rl = self._start_run_logger(run_id, flow_id=flow_id)
|
||||||
scan_config = {**self.config}
|
scan_config = {**self.config}
|
||||||
if params.get("max_tickers"):
|
if params.get("max_tickers"):
|
||||||
scan_config["max_auto_tickers"] = int(params["max_tickers"])
|
scan_config["max_auto_tickers"] = int(params["max_tickers"])
|
||||||
|
|
@ -330,11 +333,11 @@ class LangGraphEngine:
|
||||||
short_rid = generate_run_id()
|
short_rid = generate_run_id()
|
||||||
store = create_report_store(run_id=short_rid)
|
store = create_report_store(run_id=short_rid)
|
||||||
|
|
||||||
rl = self._start_run_logger(run_id)
|
flow_id = params.get("flow_id")
|
||||||
|
rl = self._start_run_logger(run_id, flow_id=flow_id)
|
||||||
|
|
||||||
|
logger.info("Starting PIPELINE run=%s ticker=%s date=%s rid=%s", run_id, ticker, date, short_rid)
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Starting PIPELINE run=%s ticker=%s date=%s rid=%s", run_id, ticker, date, short_rid
|
|
||||||
)
|
|
||||||
yield self._system_log(f"Starting analysis pipeline for {ticker} on {date}")
|
yield self._system_log(f"Starting analysis pipeline for {ticker} on {date}")
|
||||||
|
|
||||||
graph_wrapper = TradingAgentsGraph(
|
graph_wrapper = TradingAgentsGraph(
|
||||||
|
|
@ -471,7 +474,8 @@ class LangGraphEngine:
|
||||||
# A reader store with no run_id resolves to the latest run for loading
|
# A reader store with no run_id resolves to the latest run for loading
|
||||||
reader_store = create_report_store()
|
reader_store = create_report_store()
|
||||||
|
|
||||||
rl = self._start_run_logger(run_id)
|
flow_id = params.get("flow_id")
|
||||||
|
rl = self._start_run_logger(run_id, flow_id=flow_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Starting PORTFOLIO run=%s portfolio=%s date=%s rid=%s",
|
"Starting PORTFOLIO run=%s portfolio=%s date=%s rid=%s",
|
||||||
|
|
@ -818,23 +822,27 @@ class LangGraphEngine:
|
||||||
"""Run the full auto pipeline: scan → pipeline → portfolio."""
|
"""Run the full auto pipeline: scan → pipeline → portfolio."""
|
||||||
date = params.get("date", time.strftime("%Y-%m-%d"))
|
date = params.get("date", time.strftime("%Y-%m-%d"))
|
||||||
force = params.get("force", False)
|
force = params.get("force", False)
|
||||||
|
flow_id = params.get("flow_id") or str(uuid.uuid4())
|
||||||
|
|
||||||
# Use a reader store (no run_id) for skip-if-exists checks.
|
# Use a reader store (no run_id) for skip-if-exists checks.
|
||||||
# Each sub-phase (run_scan, run_pipeline, run_portfolio) creates
|
# Each sub-phase (run_scan, run_pipeline, run_portfolio) creates
|
||||||
# its own writer store with a fresh run_id internally.
|
# its own writer store with a fresh run_id internally.
|
||||||
store = create_report_store()
|
store = create_report_store()
|
||||||
|
|
||||||
self._start_run_logger(run_id) # auto-run's own logger; sub-phases create their own
|
self._start_run_logger(run_id, flow_id=flow_id) # auto-run's own logger; sub-phases create their own
|
||||||
|
|
||||||
logger.info("Starting AUTO run=%s date=%s force=%s", run_id, date, force)
|
logger.info("Starting AUTO run=%s flow=%s date=%s force=%s", run_id, flow_id, date, force)
|
||||||
yield self._system_log(f"Starting full auto workflow for {date} (force={force})")
|
yield self._system_log(f"Starting full auto workflow for {date} (force={force}, flow={flow_id})")
|
||||||
|
|
||||||
# Phase 1: Market scan
|
# Phase 1: Market scan
|
||||||
yield self._system_log("Phase 1/3: Running market scan…")
|
yield self._system_log("Phase 1/3: Running market scan…")
|
||||||
if not force and store.load_scan(date):
|
if not force and store.load_scan(date):
|
||||||
yield self._system_log(f"Phase 1: Macro scan for {date} already exists, skipping.")
|
yield self._system_log(f"Phase 1: Macro scan for {date} already exists, skipping.")
|
||||||
else:
|
else:
|
||||||
async for evt in self.run_scan(f"{run_id}_scan", {"date": date}):
|
scan_params = {"date": date, "flow_id": flow_id}
|
||||||
|
if params.get("max_tickers"):
|
||||||
|
scan_params["max_tickers"] = params["max_tickers"]
|
||||||
|
async for evt in self.run_scan(f"{run_id}_scan", scan_params):
|
||||||
yield evt
|
yield evt
|
||||||
|
|
||||||
# Phase 2: Pipeline analysis — get tickers from scan report + portfolio holdings
|
# Phase 2: Pipeline analysis — get tickers from scan report + portfolio holdings
|
||||||
|
|
@ -917,7 +925,8 @@ class LangGraphEngine:
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
async for evt in self.run_pipeline(
|
async for evt in self.run_pipeline(
|
||||||
f"{run_id}_pipeline_{ticker}", {"ticker": ticker, "date": date}
|
f"{run_id}_pipeline_{ticker}",
|
||||||
|
{"ticker": ticker, "date": date, "flow_id": flow_id},
|
||||||
):
|
):
|
||||||
await pipeline_queue.put(evt)
|
await pipeline_queue.put(evt)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|
@ -943,7 +952,7 @@ class LangGraphEngine:
|
||||||
try:
|
try:
|
||||||
async for evt in self.run_pipeline(
|
async for evt in self.run_pipeline(
|
||||||
f"{run_id}_fallback_{ticker}",
|
f"{run_id}_fallback_{ticker}",
|
||||||
{"ticker": ticker, "date": date},
|
{"ticker": ticker, "date": date, "flow_id": flow_id},
|
||||||
):
|
):
|
||||||
await pipeline_queue.put(evt)
|
await pipeline_queue.put(evt)
|
||||||
except Exception as fallback_exc:
|
except Exception as fallback_exc:
|
||||||
|
|
@ -993,6 +1002,7 @@ class LangGraphEngine:
|
||||||
# Phase 3: Portfolio management
|
# Phase 3: Portfolio management
|
||||||
yield self._system_log("Phase 3/3: Running portfolio manager…")
|
yield self._system_log("Phase 3/3: Running portfolio manager…")
|
||||||
portfolio_params = {k: v for k, v in params.items() if k != "ticker"}
|
portfolio_params = {k: v for k, v in params.items() if k != "ticker"}
|
||||||
|
portfolio_params["flow_id"] = flow_id
|
||||||
portfolio_id = params.get("portfolio_id", "main_portfolio")
|
portfolio_id = params.get("portfolio_id", "main_portfolio")
|
||||||
|
|
||||||
# Check if portfolio stage is fully complete (execution result exists)
|
# Check if portfolio stage is fully complete (execution result exists)
|
||||||
|
|
@ -1018,7 +1028,7 @@ class LangGraphEngine:
|
||||||
yield evt
|
yield evt
|
||||||
|
|
||||||
logger.info("Completed AUTO run=%s", run_id)
|
logger.info("Completed AUTO run=%s", run_id)
|
||||||
self._finish_run_logger(run_id, get_daily_dir(date))
|
self._finish_run_logger(run_id, get_daily_dir(date, run_id=flow_id[:8]))
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Report helpers
|
# Report helpers
|
||||||
|
|
|
||||||
|
|
@ -18,15 +18,23 @@ from tradingagents.agents.utils.json_utils import extract_json
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_pm_decision_agent(llm):
|
def create_pm_decision_agent(llm, config: dict | None = None):
|
||||||
"""Create a PM decision agent node.
|
"""Create a PM decision agent node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
llm: A LangChain chat model instance (deep_think recommended).
|
llm: A LangChain chat model instance (deep_think recommended).
|
||||||
|
config: Portfolio configuration dictionary containing constraints.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A node function ``pm_decision_node(state)`` compatible with LangGraph.
|
A node function ``pm_decision_node(state)`` compatible with LangGraph.
|
||||||
"""
|
"""
|
||||||
|
cfg = config or {}
|
||||||
|
constraints_str = (
|
||||||
|
f"- Max position size: {cfg.get('max_position_pct', 0.15):.0%}\n"
|
||||||
|
f"- Max sector exposure: {cfg.get('max_sector_pct', 0.35):.0%}\n"
|
||||||
|
f"- Minimum cash reserve: {cfg.get('min_cash_pct', 0.05):.0%}\n"
|
||||||
|
f"- Max total positions: {cfg.get('max_positions', 15)}\n"
|
||||||
|
)
|
||||||
|
|
||||||
def pm_decision_node(state):
|
def pm_decision_node(state):
|
||||||
analysis_date = state.get("analysis_date") or ""
|
analysis_date = state.get("analysis_date") or ""
|
||||||
|
|
@ -35,7 +43,10 @@ def create_pm_decision_agent(llm):
|
||||||
holding_reviews_str = state.get("holding_reviews") or "{}"
|
holding_reviews_str = state.get("holding_reviews") or "{}"
|
||||||
prioritized_candidates_str = state.get("prioritized_candidates") or "[]"
|
prioritized_candidates_str = state.get("prioritized_candidates") or "[]"
|
||||||
|
|
||||||
context = f"""## Portfolio Data
|
context = f"""## Portfolio Constraints
|
||||||
|
{constraints_str}
|
||||||
|
|
||||||
|
## Portfolio Data
|
||||||
{portfolio_data_str}
|
{portfolio_data_str}
|
||||||
|
|
||||||
## Risk Metrics
|
## Risk Metrics
|
||||||
|
|
@ -50,8 +61,13 @@ def create_pm_decision_agent(llm):
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"You are a portfolio manager making final investment decisions. "
|
"You are a portfolio manager making final investment decisions. "
|
||||||
"Given the risk metrics, holding reviews, and prioritized investment candidates, "
|
"Given the constraints, risk metrics, holding reviews, and prioritized investment candidates, "
|
||||||
"produce a structured JSON investment decision. "
|
"produce a structured JSON investment decision. "
|
||||||
|
"## CONSTRAINTS COMPLIANCE:\n"
|
||||||
|
"You MUST ensure your suggested buys and position sizes adhere to the portfolio constraints. "
|
||||||
|
"If a high-conviction candidate would exceed the max position size or sector limit, "
|
||||||
|
"YOU MUST adjust the suggested 'shares' downward to fit within the limit. "
|
||||||
|
"Do not suggest buys that you know will be rejected by the risk engine.\n\n"
|
||||||
"Consider: reducing risk where metrics are poor, acting on SELL recommendations, "
|
"Consider: reducing risk where metrics are poor, acting on SELL recommendations, "
|
||||||
"and adding positions in high-conviction candidates that pass constraints. "
|
"and adding positions in high-conviction candidates that pass constraints. "
|
||||||
"For every BUY you MUST set a stop_loss price (maximum acceptable loss level, "
|
"For every BUY you MUST set a stop_loss price (maximum acceptable loss level, "
|
||||||
|
|
@ -69,7 +85,7 @@ def create_pm_decision_agent(llm):
|
||||||
' "risk_summary": "..."\n'
|
' "risk_summary": "..."\n'
|
||||||
"}\n\n"
|
"}\n\n"
|
||||||
"IMPORTANT: Output ONLY valid JSON. Start your response with '{' and end with '}'. "
|
"IMPORTANT: Output ONLY valid JSON. Start your response with '{' and end with '}'. "
|
||||||
"Do NOT use markdown code fences. Do NOT include any explanation before or after the JSON.\n\n"
|
"Do NOT use markdown code fences. Do NOT include any explanation or preamble before or after the JSON.\n\n"
|
||||||
f"{context}"
|
f"{context}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -48,12 +48,13 @@ class PortfolioGraph:
|
||||||
mid_llm = self._create_llm("mid_think")
|
mid_llm = self._create_llm("mid_think")
|
||||||
deep_llm = self._create_llm("deep_think")
|
deep_llm = self._create_llm("deep_think")
|
||||||
|
|
||||||
|
portfolio_config = self._get_portfolio_config()
|
||||||
|
|
||||||
agents = {
|
agents = {
|
||||||
"review_holdings": create_holding_reviewer(mid_llm),
|
"review_holdings": create_holding_reviewer(mid_llm),
|
||||||
"pm_decision": create_pm_decision_agent(deep_llm),
|
"pm_decision": create_pm_decision_agent(deep_llm, config=portfolio_config),
|
||||||
}
|
}
|
||||||
|
|
||||||
portfolio_config = self._get_portfolio_config()
|
|
||||||
setup = PortfolioGraphSetup(agents, repo=self._repo, config=portfolio_config)
|
setup = PortfolioGraphSetup(agents, repo=self._repo, config=portfolio_config)
|
||||||
self.graph = setup.setup_graph()
|
self.graph = setup.setup_graph()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,11 +60,29 @@ class RunLogger:
|
||||||
events: Thread-safe list of all recorded events.
|
events: Thread-safe list of all recorded events.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
run_id: str | None = None,
|
||||||
|
mongo_uri: str | None = None,
|
||||||
|
mongo_db: str | None = "tradingagents",
|
||||||
|
flow_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self.events: list[_Event] = []
|
self.events: list[_Event] = []
|
||||||
self.callback = _LLMCallbackHandler(self)
|
self.callback = _LLMCallbackHandler(self)
|
||||||
self._start = time.time()
|
self._start = time.time()
|
||||||
|
self.run_id = run_id
|
||||||
|
self.flow_id = flow_id
|
||||||
|
self._mongo_col = None
|
||||||
|
|
||||||
|
if mongo_uri and run_id:
|
||||||
|
try:
|
||||||
|
from pymongo import MongoClient
|
||||||
|
client = MongoClient(mongo_uri)
|
||||||
|
self._mongo_col = client[mongo_db]["run_events"]
|
||||||
|
_py_logger.info("RunLogger: persisting events to MongoDB (run_id=%s, flow_id=%s)", run_id, flow_id)
|
||||||
|
except Exception as exc:
|
||||||
|
_py_logger.warning("RunLogger: MongoDB connection failed: %s", exc)
|
||||||
|
|
||||||
# -- public helpers to record events from non-callback code ----------------
|
# -- public helpers to record events from non-callback code ----------------
|
||||||
|
|
||||||
|
|
@ -197,6 +215,16 @@ class RunLogger:
|
||||||
self.events.append(evt)
|
self.events.append(evt)
|
||||||
_py_logger.debug("%s | %s", evt.kind, json.dumps(evt.data))
|
_py_logger.debug("%s | %s", evt.kind, json.dumps(evt.data))
|
||||||
|
|
||||||
|
if self._mongo_col is not None and self.run_id:
|
||||||
|
try:
|
||||||
|
doc = evt.to_dict()
|
||||||
|
doc["run_id"] = self.run_id
|
||||||
|
if self.flow_id:
|
||||||
|
doc["flow_id"] = self.flow_id
|
||||||
|
self._mongo_col.insert_one(doc)
|
||||||
|
except Exception as exc:
|
||||||
|
_py_logger.warning("RunLogger: MongoDB insert failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────────
|
||||||
# LangChain callback handler — captures LLM call details
|
# LangChain callback handler — captures LLM call details
|
||||||
|
|
@ -225,11 +253,24 @@ class _LLMCallbackHandler(BaseCallbackHandler):
|
||||||
model = _extract_model(serialized, kwargs)
|
model = _extract_model(serialized, kwargs)
|
||||||
agent = kwargs.get("name") or serialized.get("name") or _extract_graph_node(kwargs)
|
agent = kwargs.get("name") or serialized.get("name") or _extract_graph_node(kwargs)
|
||||||
key = str(run_id) if run_id else str(id(messages))
|
key = str(run_id) if run_id else str(id(messages))
|
||||||
|
|
||||||
|
# Capture prompt content
|
||||||
|
prompt = ""
|
||||||
|
try:
|
||||||
|
if messages and isinstance(messages[0], list):
|
||||||
|
# batched messages
|
||||||
|
prompt = "\n\n".join([str(m.content) if hasattr(m, "content") else str(m) for m in messages[0]])
|
||||||
|
else:
|
||||||
|
prompt = "\n\n".join([str(m.content) if hasattr(m, "content") else str(m) for m in messages])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._inflight[key] = {
|
self._inflight[key] = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"agent": agent or "",
|
"agent": agent or "",
|
||||||
"t0": time.time(),
|
"t0": time.time(),
|
||||||
|
"prompt": prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
# -- legacy LLM start (completion-style) -----------------------------------
|
# -- legacy LLM start (completion-style) -----------------------------------
|
||||||
|
|
@ -250,6 +291,7 @@ class _LLMCallbackHandler(BaseCallbackHandler):
|
||||||
"model": model,
|
"model": model,
|
||||||
"agent": agent or "",
|
"agent": agent or "",
|
||||||
"t0": time.time(),
|
"t0": time.time(),
|
||||||
|
"prompt": prompts[0] if prompts else "",
|
||||||
}
|
}
|
||||||
|
|
||||||
# -- LLM end ---------------------------------------------------------------
|
# -- LLM end ---------------------------------------------------------------
|
||||||
|
|
@ -262,10 +304,13 @@ class _LLMCallbackHandler(BaseCallbackHandler):
|
||||||
tokens_in = 0
|
tokens_in = 0
|
||||||
tokens_out = 0
|
tokens_out = 0
|
||||||
model_from_response = ""
|
model_from_response = ""
|
||||||
|
full_response = ""
|
||||||
try:
|
try:
|
||||||
generation = response.generations[0][0]
|
generation = response.generations[0][0]
|
||||||
|
full_response = generation.text if hasattr(generation, "text") else str(generation)
|
||||||
if hasattr(generation, "message"):
|
if hasattr(generation, "message"):
|
||||||
msg = generation.message
|
msg = generation.message
|
||||||
|
full_response = msg.content if hasattr(msg, "content") else full_response
|
||||||
if isinstance(msg, AIMessage) and hasattr(msg, "usage_metadata") and msg.usage_metadata:
|
if isinstance(msg, AIMessage) and hasattr(msg, "usage_metadata") and msg.usage_metadata:
|
||||||
tokens_in = msg.usage_metadata.get("input_tokens", 0)
|
tokens_in = msg.usage_metadata.get("input_tokens", 0)
|
||||||
tokens_out = msg.usage_metadata.get("output_tokens", 0)
|
tokens_out = msg.usage_metadata.get("output_tokens", 0)
|
||||||
|
|
@ -277,6 +322,7 @@ class _LLMCallbackHandler(BaseCallbackHandler):
|
||||||
model = model_from_response or (meta["model"] if meta else "unknown")
|
model = model_from_response or (meta["model"] if meta else "unknown")
|
||||||
agent = meta["agent"] if meta else ""
|
agent = meta["agent"] if meta else ""
|
||||||
duration_ms = (time.time() - meta["t0"]) * 1000 if meta else 0
|
duration_ms = (time.time() - meta["t0"]) * 1000 if meta else 0
|
||||||
|
prompt = meta["prompt"] if meta else ""
|
||||||
|
|
||||||
evt = _Event(
|
evt = _Event(
|
||||||
kind="llm",
|
kind="llm",
|
||||||
|
|
@ -287,6 +333,8 @@ class _LLMCallbackHandler(BaseCallbackHandler):
|
||||||
"tokens_in": tokens_in,
|
"tokens_in": tokens_in,
|
||||||
"tokens_out": tokens_out,
|
"tokens_out": tokens_out,
|
||||||
"duration_ms": round(duration_ms, 1),
|
"duration_ms": round(duration_ms, 1),
|
||||||
|
"prompt": prompt,
|
||||||
|
"response": full_response,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self._rl._append(evt)
|
self._rl._append(evt)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,175 @@
|
||||||
|
"""Dual report store that persists to both local filesystem and MongoDB.
|
||||||
|
|
||||||
|
Delegates all save_* calls to both a :class:`ReportStore` and a
|
||||||
|
:class:`MongoReportStore`. Load methods prioritize the MongoDB store if
|
||||||
|
available, otherwise fall back to the filesystem.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
from tradingagents.portfolio.report_store import ReportStore
|
||||||
|
from tradingagents.portfolio.mongo_report_store import MongoReportStore
|
||||||
|
|
||||||
|
|
||||||
|
class DualReportStore:
|
||||||
|
"""Report store that writes to two backends simultaneously."""
|
||||||
|
|
||||||
|
def __init__(self, local_store: ReportStore, mongo_store: MongoReportStore) -> None:
|
||||||
|
self._local = local_store
|
||||||
|
self._mongo = mongo_store
|
||||||
|
|
||||||
|
@property
|
||||||
|
def run_id(self) -> str | None:
|
||||||
|
return self._local.run_id
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Macro Scan
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_scan(self, date: str, data: dict[str, Any]) -> Any:
|
||||||
|
# local returns Path, mongo returns str (_id)
|
||||||
|
local_result = self._local.save_scan(date, data)
|
||||||
|
self._mongo.save_scan(date, data)
|
||||||
|
return local_result
|
||||||
|
|
||||||
|
def load_scan(self, date: str) -> dict[str, Any] | None:
|
||||||
|
return self._mongo.load_scan(date) or self._local.load_scan(date)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Per-Ticker Analysis
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_analysis(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
|
||||||
|
local_result = self._local.save_analysis(date, ticker, data)
|
||||||
|
self._mongo.save_analysis(date, ticker, data)
|
||||||
|
return local_result
|
||||||
|
|
||||||
|
def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||||
|
return self._mongo.load_analysis(date, ticker) or self._local.load_analysis(date, ticker)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Holding Reviews
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_holding_review(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
|
||||||
|
local_result = self._local.save_holding_review(date, ticker, data)
|
||||||
|
self._mongo.save_holding_review(date, ticker, data)
|
||||||
|
return local_result
|
||||||
|
|
||||||
|
def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||||
|
return self._mongo.load_holding_review(date, ticker) or self._local.load_holding_review(date, ticker)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Risk Metrics
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_risk_metrics(self, date: str, portfolio_id: str, data: dict[str, Any]) -> Any:
|
||||||
|
local_result = self._local.save_risk_metrics(date, portfolio_id, data)
|
||||||
|
self._mongo.save_risk_metrics(date, portfolio_id, data)
|
||||||
|
return local_result
|
||||||
|
|
||||||
|
def load_risk_metrics(self, date: str, portfolio_id: str) -> dict[str, Any] | None:
|
||||||
|
return self._mongo.load_risk_metrics(date, portfolio_id) or self._local.load_risk_metrics(date, portfolio_id)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# PM Decisions
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_pm_decision(
|
||||||
|
self,
|
||||||
|
date: str,
|
||||||
|
portfolio_id: str,
|
||||||
|
data: dict[str, Any],
|
||||||
|
markdown: str | None = None,
|
||||||
|
) -> Any:
|
||||||
|
local_result = self._local.save_pm_decision(date, portfolio_id, data, markdown=markdown)
|
||||||
|
self._mongo.save_pm_decision(date, portfolio_id, data, markdown=markdown)
|
||||||
|
return local_result
|
||||||
|
|
||||||
|
def load_pm_decision(self, date: str, portfolio_id: str) -> dict[str, Any] | None:
|
||||||
|
return self._mongo.load_pm_decision(date, portfolio_id) or self._local.load_pm_decision(date, portfolio_id)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Execution Results
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_execution_result(self, date: str, portfolio_id: str, data: dict[str, Any]) -> Any:
|
||||||
|
local_result = self._local.save_execution_result(date, portfolio_id, data)
|
||||||
|
self._mongo.save_execution_result(date, portfolio_id, data)
|
||||||
|
return local_result
|
||||||
|
|
||||||
|
def load_execution_result(self, date: str, portfolio_id: str) -> dict[str, Any] | None:
|
||||||
|
return self._mongo.load_execution_result(date, portfolio_id) or self._local.load_execution_result(date, portfolio_id)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Run Meta / Events persistence
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_run_meta(self, date: str, data: dict[str, Any]) -> Any:
|
||||||
|
local_result = self._local.save_run_meta(date, data)
|
||||||
|
self._mongo.save_run_meta(date, data)
|
||||||
|
return local_result
|
||||||
|
|
||||||
|
def load_run_meta(self, date: str) -> dict[str, Any] | None:
|
||||||
|
return self._mongo.load_run_meta(date) or self._local.load_run_meta(date)
|
||||||
|
|
||||||
|
def save_run_events(self, date: str, events: list[dict[str, Any]]) -> Any:
|
||||||
|
local_result = self._local.save_run_events(date, events)
|
||||||
|
self._mongo.save_run_events(date, events)
|
||||||
|
return local_result
|
||||||
|
|
||||||
|
def load_run_events(self, date: str) -> list[dict[str, Any]]:
|
||||||
|
mongo_events = self._mongo.load_run_events(date)
|
||||||
|
if mongo_events:
|
||||||
|
return mongo_events
|
||||||
|
return self._local.load_run_events(date)
|
||||||
|
|
||||||
|
def list_run_metas(self) -> list[dict[str, Any]]:
|
||||||
|
mongo_metas = self._mongo.list_run_metas()
|
||||||
|
if mongo_metas:
|
||||||
|
return mongo_metas
|
||||||
|
return self._local.list_run_metas()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Analyst / Trader Checkpoints
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_analysts_checkpoint(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
|
||||||
|
local_result = self._local.save_analysts_checkpoint(date, ticker, data)
|
||||||
|
self._mongo.save_analysts_checkpoint(date, ticker, data)
|
||||||
|
return local_result
|
||||||
|
|
||||||
|
def load_analysts_checkpoint(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||||
|
return self._mongo.load_analysts_checkpoint(date, ticker) or self._local.load_analysts_checkpoint(date, ticker)
|
||||||
|
|
||||||
|
def save_trader_checkpoint(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
    """Save a trader checkpoint to both backends; return the local result."""
    outcome = self._local.save_trader_checkpoint(date, ticker, data)
    self._mongo.save_trader_checkpoint(date, ticker, data)
    return outcome
|
||||||
|
|
||||||
|
def load_trader_checkpoint(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||||
|
return self._mongo.load_trader_checkpoint(date, ticker) or self._local.load_trader_checkpoint(date, ticker)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Utility
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def clear_portfolio_stage(self, date: str, portfolio_id: str) -> list[str]:
    """Clear a portfolio stage in both backends.

    Only the local backend's list of deleted entries is returned; the
    MongoDB backend's return value is discarded.
    """
    removed = self._local.clear_portfolio_stage(date, portfolio_id)
    self._mongo.clear_portfolio_stage(date, portfolio_id)
    return removed
|
||||||
|
|
||||||
|
def list_pm_decisions(self, portfolio_id: str) -> list[Any]:
    """List PM decisions for a portfolio.

    Mongo returns dicts while the local store returns Paths, so a
    non-empty MongoDB listing is preferred for its richer data.
    """
    return self._mongo.list_pm_decisions(portfolio_id) or self._local.list_pm_decisions(portfolio_id)
|
||||||
|
|
||||||
|
def list_analyses_for_date(self, date: str) -> list[str]:
    """Return the union of analyses known to either backend for *date*.

    Each backend may know entries the other does not (e.g. runs recorded
    before MongoDB was configured), so both listings are merged and
    de-duplicated. The result is sorted so callers see a deterministic
    order instead of the arbitrary ordering of ``list(set(...))``.
    """
    merged = set(self._mongo.list_analyses_for_date(date))
    merged.update(self._local.list_analyses_for_date(date))
    return sorted(merged)
|
||||||
|
|
@ -18,6 +18,7 @@ import os
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from tradingagents.portfolio.report_store import ReportStore
|
from tradingagents.portfolio.report_store import ReportStore
|
||||||
|
from tradingagents.portfolio.dual_report_store import DualReportStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -28,13 +29,13 @@ def create_report_store(
|
||||||
base_dir: str | None = None,
|
base_dir: str | None = None,
|
||||||
mongo_uri: str | None = None,
|
mongo_uri: str | None = None,
|
||||||
mongo_db: str | None = None,
|
mongo_db: str | None = None,
|
||||||
) -> Union[ReportStore, "MongoReportStore"]: # noqa: F821
|
) -> Union[ReportStore, "MongoReportStore", DualReportStore]: # noqa: F821
|
||||||
"""Create and return the appropriate report store.
|
"""Create and return the appropriate report store.
|
||||||
|
|
||||||
Resolution order for the backend:
|
Resolution order for the backend:
|
||||||
|
|
||||||
1. If *mongo_uri* is passed explicitly, use MongoDB.
|
1. If *mongo_uri* is passed explicitly, use DualReportStore.
|
||||||
2. If ``TRADINGAGENTS_MONGO_URI`` env var is set, use MongoDB.
|
2. If ``TRADINGAGENTS_MONGO_URI`` env var is set, use DualReportStore.
|
||||||
3. Fall back to the filesystem :class:`ReportStore`.
|
3. Fall back to the filesystem :class:`ReportStore`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -44,32 +45,33 @@ def create_report_store(
|
||||||
mongo_db: MongoDB database name (default ``"tradingagents"``).
|
mongo_db: MongoDB database name (default ``"tradingagents"``).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A store instance (either ``ReportStore`` or ``MongoReportStore``).
|
A store instance (either ``ReportStore`` or ``DualReportStore``).
|
||||||
"""
|
"""
|
||||||
uri = mongo_uri or os.getenv("TRADINGAGENTS_MONGO_URI", "")
|
uri = mongo_uri or os.getenv("TRADINGAGENTS_MONGO_URI", "")
|
||||||
db = mongo_db or os.getenv("TRADINGAGENTS_MONGO_DB", "tradingagents")
|
db = mongo_db or os.getenv("TRADINGAGENTS_MONGO_DB", "tradingagents")
|
||||||
|
|
||||||
|
# Filesystem instance (always created as part of Dual or as standalone)
|
||||||
|
_base = base_dir or os.getenv("PORTFOLIO_DATA_DIR") or os.getenv(
|
||||||
|
"TRADINGAGENTS_REPORTS_DIR", "reports"
|
||||||
|
)
|
||||||
|
local_store = ReportStore(base_dir=_base, run_id=run_id)
|
||||||
|
|
||||||
if uri:
|
if uri:
|
||||||
try:
|
try:
|
||||||
from tradingagents.portfolio.mongo_report_store import MongoReportStore
|
from tradingagents.portfolio.mongo_report_store import MongoReportStore
|
||||||
|
|
||||||
store = MongoReportStore(
|
mongo_store = MongoReportStore(
|
||||||
connection_string=uri,
|
connection_string=uri,
|
||||||
db_name=db,
|
db_name=db,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
)
|
)
|
||||||
# ensure_indexes() is called automatically in __init__
|
logger.info("Using Dual report store (local + MongoDB db=%s, run_id=%s)", db, run_id)
|
||||||
logger.info("Using MongoDB report store (db=%s, run_id=%s)", db, run_id)
|
return DualReportStore(local_store, mongo_store)
|
||||||
return store
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"MongoDB connection failed — falling back to filesystem store",
|
"MongoDB connection failed — falling back to filesystem store",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filesystem fallback
|
|
||||||
_base = base_dir or os.getenv("PORTFOLIO_DATA_DIR") or os.getenv(
|
|
||||||
"TRADINGAGENTS_REPORTS_DIR", "reports"
|
|
||||||
)
|
|
||||||
logger.info("Using filesystem report store (base=%s, run_id=%s)", _base, run_id)
|
logger.info("Using filesystem report store (base=%s, run_id=%s)", _base, run_id)
|
||||||
return ReportStore(base_dir=_base, run_id=run_id)
|
return local_store
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue