feat: enhance data persistence with DualReportStore for local and MongoDB storage; update report store creation logic
This commit is contained in:
parent
7c02e0c76c
commit
5f0a52f8e6
|
|
@ -3,6 +3,7 @@ from typing import Dict, Any, List, AsyncGenerator
|
|||
import logging
|
||||
import uuid
|
||||
import time
|
||||
import os
|
||||
from agent_os.backend.store import runs
|
||||
from agent_os.backend.dependencies import get_current_user
|
||||
from agent_os.backend.services.langgraph_engine import LangGraphEngine, NODE_TO_PHASE
|
||||
|
|
@ -269,15 +270,62 @@ async def reset_portfolio_stage(
|
|||
return {"deleted": deleted, "date": date, "portfolio_id": portfolio_id}
|
||||
|
||||
|
||||
def _get_mongo_col():
|
||||
"""Return the run_events collection if MongoDB is configured."""
|
||||
uri = os.getenv("TRADINGAGENTS_MONGO_URI")
|
||||
db_name = os.getenv("TRADINGAGENTS_MONGO_DB", "tradingagents")
|
||||
if uri:
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
client = MongoClient(uri)
|
||||
return client[db_name]["run_events"]
|
||||
except Exception:
|
||||
logger.warning("Failed to connect to MongoDB for historical events")
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_runs(user: dict = Depends(get_current_user)):
|
||||
# Filter by user in production
|
||||
return list(runs.values())
|
||||
all_runs = dict(runs)
|
||||
|
||||
# Supplement with historical metadata from MongoDB if available
|
||||
col = _get_mongo_col()
|
||||
if col is not None:
|
||||
try:
|
||||
# Fetch unique run_ids from the last 7 days (simplified)
|
||||
# In a real app, we'd have a separate 'runs' collection for metadata.
|
||||
# Here we use the events collection and group by run_id.
|
||||
pipeline = [
|
||||
{"$match": {"type": "log", "agent": "SYSTEM"}}, # Filter for start logs
|
||||
{"$sort": {"ts": -1}},
|
||||
{"$group": {
|
||||
"_id": "$run_id",
|
||||
"id": {"$first": "$run_id"},
|
||||
"type": {"$first": "$type"},
|
||||
"created_at": {"$first": "$ts"},
|
||||
# Status is harder to get from events without a dedicated meta doc
|
||||
}},
|
||||
{"$limit": 50}
|
||||
]
|
||||
for doc in col.aggregate(pipeline):
|
||||
rid = doc["id"]
|
||||
if rid not in all_runs:
|
||||
all_runs[rid] = {
|
||||
"id": rid,
|
||||
"type": doc.get("type", "unknown"),
|
||||
"status": "historical",
|
||||
"created_at": doc.get("created_at", 0),
|
||||
"user_id": "anonymous",
|
||||
}
|
||||
except Exception:
|
||||
logger.warning("Failed to fetch historical runs from MongoDB")
|
||||
|
||||
return list(all_runs.values())
|
||||
|
||||
@router.get("/{run_id}")
|
||||
async def get_run_status(run_id: str, user: dict = Depends(get_current_user)):
|
||||
if run_id not in runs:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
if run_id in runs:
|
||||
run = runs[run_id]
|
||||
# Lazy-load events from disk if they were not kept in memory
|
||||
if (
|
||||
|
|
@ -296,3 +344,25 @@ async def get_run_status(run_id: str, user: dict = Depends(get_current_user)):
|
|||
except Exception:
|
||||
logger.warning("Failed to lazy-load events for run=%s", run_id)
|
||||
return run
|
||||
|
||||
# Not in memory — try MongoDB
|
||||
col = _get_mongo_col()
|
||||
if col is not None:
|
||||
try:
|
||||
cursor = col.find({"run_id": run_id}).sort("ts", 1)
|
||||
events = list(cursor)
|
||||
if events:
|
||||
# Remove MongoDB _id for JSON serialization
|
||||
for e in events:
|
||||
e.pop("_id", None)
|
||||
return {
|
||||
"id": run_id,
|
||||
"status": "historical",
|
||||
"events": events,
|
||||
"type": events[0].get("type", "unknown") if events else "unknown",
|
||||
"created_at": events[0].get("ts", 0) if events else 0,
|
||||
}
|
||||
except Exception:
|
||||
logger.warning("Failed to fetch historical run %s from MongoDB", run_id)
|
||||
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
|
|
|||
|
|
@ -175,9 +175,11 @@ class LangGraphEngine:
|
|||
# Run logger lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _start_run_logger(self, run_id: str) -> RunLogger:
|
||||
def _start_run_logger(self, run_id: str, flow_id: str | None = None) -> RunLogger:
|
||||
"""Create and register a ``RunLogger`` for the given run."""
|
||||
rl = RunLogger()
|
||||
uri = self.config.get("mongo_uri")
|
||||
db = self.config.get("mongo_db") or "tradingagents"
|
||||
rl = RunLogger(run_id=run_id, mongo_uri=uri, mongo_db=db, flow_id=flow_id)
|
||||
self._run_loggers[run_id] = rl
|
||||
set_run_logger(rl)
|
||||
return rl
|
||||
|
|
@ -209,7 +211,8 @@ class LangGraphEngine:
|
|||
short_rid = generate_run_id()
|
||||
store = create_report_store(run_id=short_rid)
|
||||
|
||||
rl = self._start_run_logger(run_id)
|
||||
flow_id = params.get("flow_id")
|
||||
rl = self._start_run_logger(run_id, flow_id=flow_id)
|
||||
scan_config = {**self.config}
|
||||
if params.get("max_tickers"):
|
||||
scan_config["max_auto_tickers"] = int(params["max_tickers"])
|
||||
|
|
@ -330,11 +333,11 @@ class LangGraphEngine:
|
|||
short_rid = generate_run_id()
|
||||
store = create_report_store(run_id=short_rid)
|
||||
|
||||
rl = self._start_run_logger(run_id)
|
||||
flow_id = params.get("flow_id")
|
||||
rl = self._start_run_logger(run_id, flow_id=flow_id)
|
||||
|
||||
logger.info("Starting PIPELINE run=%s ticker=%s date=%s rid=%s", run_id, ticker, date, short_rid)
|
||||
|
||||
logger.info(
|
||||
"Starting PIPELINE run=%s ticker=%s date=%s rid=%s", run_id, ticker, date, short_rid
|
||||
)
|
||||
yield self._system_log(f"Starting analysis pipeline for {ticker} on {date}")
|
||||
|
||||
graph_wrapper = TradingAgentsGraph(
|
||||
|
|
@ -471,7 +474,8 @@ class LangGraphEngine:
|
|||
# A reader store with no run_id resolves to the latest run for loading
|
||||
reader_store = create_report_store()
|
||||
|
||||
rl = self._start_run_logger(run_id)
|
||||
flow_id = params.get("flow_id")
|
||||
rl = self._start_run_logger(run_id, flow_id=flow_id)
|
||||
|
||||
logger.info(
|
||||
"Starting PORTFOLIO run=%s portfolio=%s date=%s rid=%s",
|
||||
|
|
@ -818,23 +822,27 @@ class LangGraphEngine:
|
|||
"""Run the full auto pipeline: scan → pipeline → portfolio."""
|
||||
date = params.get("date", time.strftime("%Y-%m-%d"))
|
||||
force = params.get("force", False)
|
||||
flow_id = params.get("flow_id") or str(uuid.uuid4())
|
||||
|
||||
# Use a reader store (no run_id) for skip-if-exists checks.
|
||||
# Each sub-phase (run_scan, run_pipeline, run_portfolio) creates
|
||||
# its own writer store with a fresh run_id internally.
|
||||
store = create_report_store()
|
||||
|
||||
self._start_run_logger(run_id) # auto-run's own logger; sub-phases create their own
|
||||
self._start_run_logger(run_id, flow_id=flow_id) # auto-run's own logger; sub-phases create their own
|
||||
|
||||
logger.info("Starting AUTO run=%s date=%s force=%s", run_id, date, force)
|
||||
yield self._system_log(f"Starting full auto workflow for {date} (force={force})")
|
||||
logger.info("Starting AUTO run=%s flow=%s date=%s force=%s", run_id, flow_id, date, force)
|
||||
yield self._system_log(f"Starting full auto workflow for {date} (force={force}, flow={flow_id})")
|
||||
|
||||
# Phase 1: Market scan
|
||||
yield self._system_log("Phase 1/3: Running market scan…")
|
||||
if not force and store.load_scan(date):
|
||||
yield self._system_log(f"Phase 1: Macro scan for {date} already exists, skipping.")
|
||||
else:
|
||||
async for evt in self.run_scan(f"{run_id}_scan", {"date": date}):
|
||||
scan_params = {"date": date, "flow_id": flow_id}
|
||||
if params.get("max_tickers"):
|
||||
scan_params["max_tickers"] = params["max_tickers"]
|
||||
async for evt in self.run_scan(f"{run_id}_scan", scan_params):
|
||||
yield evt
|
||||
|
||||
# Phase 2: Pipeline analysis — get tickers from scan report + portfolio holdings
|
||||
|
|
@ -917,7 +925,8 @@ class LangGraphEngine:
|
|||
)
|
||||
try:
|
||||
async for evt in self.run_pipeline(
|
||||
f"{run_id}_pipeline_{ticker}", {"ticker": ticker, "date": date}
|
||||
f"{run_id}_pipeline_{ticker}",
|
||||
{"ticker": ticker, "date": date, "flow_id": flow_id},
|
||||
):
|
||||
await pipeline_queue.put(evt)
|
||||
except Exception as exc:
|
||||
|
|
@ -943,7 +952,7 @@ class LangGraphEngine:
|
|||
try:
|
||||
async for evt in self.run_pipeline(
|
||||
f"{run_id}_fallback_{ticker}",
|
||||
{"ticker": ticker, "date": date},
|
||||
{"ticker": ticker, "date": date, "flow_id": flow_id},
|
||||
):
|
||||
await pipeline_queue.put(evt)
|
||||
except Exception as fallback_exc:
|
||||
|
|
@ -993,6 +1002,7 @@ class LangGraphEngine:
|
|||
# Phase 3: Portfolio management
|
||||
yield self._system_log("Phase 3/3: Running portfolio manager…")
|
||||
portfolio_params = {k: v for k, v in params.items() if k != "ticker"}
|
||||
portfolio_params["flow_id"] = flow_id
|
||||
portfolio_id = params.get("portfolio_id", "main_portfolio")
|
||||
|
||||
# Check if portfolio stage is fully complete (execution result exists)
|
||||
|
|
@ -1018,7 +1028,7 @@ class LangGraphEngine:
|
|||
yield evt
|
||||
|
||||
logger.info("Completed AUTO run=%s", run_id)
|
||||
self._finish_run_logger(run_id, get_daily_dir(date))
|
||||
self._finish_run_logger(run_id, get_daily_dir(date, run_id=flow_id[:8]))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Report helpers
|
||||
|
|
|
|||
|
|
@ -18,15 +18,23 @@ from tradingagents.agents.utils.json_utils import extract_json
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_pm_decision_agent(llm):
|
||||
def create_pm_decision_agent(llm, config: dict | None = None):
|
||||
"""Create a PM decision agent node.
|
||||
|
||||
Args:
|
||||
llm: A LangChain chat model instance (deep_think recommended).
|
||||
config: Portfolio configuration dictionary containing constraints.
|
||||
|
||||
Returns:
|
||||
A node function ``pm_decision_node(state)`` compatible with LangGraph.
|
||||
"""
|
||||
cfg = config or {}
|
||||
constraints_str = (
|
||||
f"- Max position size: {cfg.get('max_position_pct', 0.15):.0%}\n"
|
||||
f"- Max sector exposure: {cfg.get('max_sector_pct', 0.35):.0%}\n"
|
||||
f"- Minimum cash reserve: {cfg.get('min_cash_pct', 0.05):.0%}\n"
|
||||
f"- Max total positions: {cfg.get('max_positions', 15)}\n"
|
||||
)
|
||||
|
||||
def pm_decision_node(state):
|
||||
analysis_date = state.get("analysis_date") or ""
|
||||
|
|
@ -35,7 +43,10 @@ def create_pm_decision_agent(llm):
|
|||
holding_reviews_str = state.get("holding_reviews") or "{}"
|
||||
prioritized_candidates_str = state.get("prioritized_candidates") or "[]"
|
||||
|
||||
context = f"""## Portfolio Data
|
||||
context = f"""## Portfolio Constraints
|
||||
{constraints_str}
|
||||
|
||||
## Portfolio Data
|
||||
{portfolio_data_str}
|
||||
|
||||
## Risk Metrics
|
||||
|
|
@ -50,8 +61,13 @@ def create_pm_decision_agent(llm):
|
|||
|
||||
system_message = (
|
||||
"You are a portfolio manager making final investment decisions. "
|
||||
"Given the risk metrics, holding reviews, and prioritized investment candidates, "
|
||||
"Given the constraints, risk metrics, holding reviews, and prioritized investment candidates, "
|
||||
"produce a structured JSON investment decision. "
|
||||
"## CONSTRAINTS COMPLIANCE:\n"
|
||||
"You MUST ensure your suggested buys and position sizes adhere to the portfolio constraints. "
|
||||
"If a high-conviction candidate would exceed the max position size or sector limit, "
|
||||
"YOU MUST adjust the suggested 'shares' downward to fit within the limit. "
|
||||
"Do not suggest buys that you know will be rejected by the risk engine.\n\n"
|
||||
"Consider: reducing risk where metrics are poor, acting on SELL recommendations, "
|
||||
"and adding positions in high-conviction candidates that pass constraints. "
|
||||
"For every BUY you MUST set a stop_loss price (maximum acceptable loss level, "
|
||||
|
|
@ -69,7 +85,7 @@ def create_pm_decision_agent(llm):
|
|||
' "risk_summary": "..."\n'
|
||||
"}\n\n"
|
||||
"IMPORTANT: Output ONLY valid JSON. Start your response with '{' and end with '}'. "
|
||||
"Do NOT use markdown code fences. Do NOT include any explanation before or after the JSON.\n\n"
|
||||
"Do NOT use markdown code fences. Do NOT include any explanation or preamble before or after the JSON.\n\n"
|
||||
f"{context}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -48,12 +48,13 @@ class PortfolioGraph:
|
|||
mid_llm = self._create_llm("mid_think")
|
||||
deep_llm = self._create_llm("deep_think")
|
||||
|
||||
portfolio_config = self._get_portfolio_config()
|
||||
|
||||
agents = {
|
||||
"review_holdings": create_holding_reviewer(mid_llm),
|
||||
"pm_decision": create_pm_decision_agent(deep_llm),
|
||||
"pm_decision": create_pm_decision_agent(deep_llm, config=portfolio_config),
|
||||
}
|
||||
|
||||
portfolio_config = self._get_portfolio_config()
|
||||
setup = PortfolioGraphSetup(agents, repo=self._repo, config=portfolio_config)
|
||||
self.graph = setup.setup_graph()
|
||||
|
||||
|
|
|
|||
|
|
@ -60,11 +60,29 @@ class RunLogger:
|
|||
events: Thread-safe list of all recorded events.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
run_id: str | None = None,
|
||||
mongo_uri: str | None = None,
|
||||
mongo_db: str | None = "tradingagents",
|
||||
flow_id: str | None = None,
|
||||
) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self.events: list[_Event] = []
|
||||
self.callback = _LLMCallbackHandler(self)
|
||||
self._start = time.time()
|
||||
self.run_id = run_id
|
||||
self.flow_id = flow_id
|
||||
self._mongo_col = None
|
||||
|
||||
if mongo_uri and run_id:
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
client = MongoClient(mongo_uri)
|
||||
self._mongo_col = client[mongo_db]["run_events"]
|
||||
_py_logger.info("RunLogger: persisting events to MongoDB (run_id=%s, flow_id=%s)", run_id, flow_id)
|
||||
except Exception as exc:
|
||||
_py_logger.warning("RunLogger: MongoDB connection failed: %s", exc)
|
||||
|
||||
# -- public helpers to record events from non-callback code ----------------
|
||||
|
||||
|
|
@ -197,6 +215,16 @@ class RunLogger:
|
|||
self.events.append(evt)
|
||||
_py_logger.debug("%s | %s", evt.kind, json.dumps(evt.data))
|
||||
|
||||
if self._mongo_col is not None and self.run_id:
|
||||
try:
|
||||
doc = evt.to_dict()
|
||||
doc["run_id"] = self.run_id
|
||||
if self.flow_id:
|
||||
doc["flow_id"] = self.flow_id
|
||||
self._mongo_col.insert_one(doc)
|
||||
except Exception as exc:
|
||||
_py_logger.warning("RunLogger: MongoDB insert failed: %s", exc)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# LangChain callback handler — captures LLM call details
|
||||
|
|
@ -225,11 +253,24 @@ class _LLMCallbackHandler(BaseCallbackHandler):
|
|||
model = _extract_model(serialized, kwargs)
|
||||
agent = kwargs.get("name") or serialized.get("name") or _extract_graph_node(kwargs)
|
||||
key = str(run_id) if run_id else str(id(messages))
|
||||
|
||||
# Capture prompt content
|
||||
prompt = ""
|
||||
try:
|
||||
if messages and isinstance(messages[0], list):
|
||||
# batched messages
|
||||
prompt = "\n\n".join([str(m.content) if hasattr(m, "content") else str(m) for m in messages[0]])
|
||||
else:
|
||||
prompt = "\n\n".join([str(m.content) if hasattr(m, "content") else str(m) for m in messages])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with self._lock:
|
||||
self._inflight[key] = {
|
||||
"model": model,
|
||||
"agent": agent or "",
|
||||
"t0": time.time(),
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
# -- legacy LLM start (completion-style) -----------------------------------
|
||||
|
|
@ -250,6 +291,7 @@ class _LLMCallbackHandler(BaseCallbackHandler):
|
|||
"model": model,
|
||||
"agent": agent or "",
|
||||
"t0": time.time(),
|
||||
"prompt": prompts[0] if prompts else "",
|
||||
}
|
||||
|
||||
# -- LLM end ---------------------------------------------------------------
|
||||
|
|
@ -262,10 +304,13 @@ class _LLMCallbackHandler(BaseCallbackHandler):
|
|||
tokens_in = 0
|
||||
tokens_out = 0
|
||||
model_from_response = ""
|
||||
full_response = ""
|
||||
try:
|
||||
generation = response.generations[0][0]
|
||||
full_response = generation.text if hasattr(generation, "text") else str(generation)
|
||||
if hasattr(generation, "message"):
|
||||
msg = generation.message
|
||||
full_response = msg.content if hasattr(msg, "content") else full_response
|
||||
if isinstance(msg, AIMessage) and hasattr(msg, "usage_metadata") and msg.usage_metadata:
|
||||
tokens_in = msg.usage_metadata.get("input_tokens", 0)
|
||||
tokens_out = msg.usage_metadata.get("output_tokens", 0)
|
||||
|
|
@ -277,6 +322,7 @@ class _LLMCallbackHandler(BaseCallbackHandler):
|
|||
model = model_from_response or (meta["model"] if meta else "unknown")
|
||||
agent = meta["agent"] if meta else ""
|
||||
duration_ms = (time.time() - meta["t0"]) * 1000 if meta else 0
|
||||
prompt = meta["prompt"] if meta else ""
|
||||
|
||||
evt = _Event(
|
||||
kind="llm",
|
||||
|
|
@ -287,6 +333,8 @@ class _LLMCallbackHandler(BaseCallbackHandler):
|
|||
"tokens_in": tokens_in,
|
||||
"tokens_out": tokens_out,
|
||||
"duration_ms": round(duration_ms, 1),
|
||||
"prompt": prompt,
|
||||
"response": full_response,
|
||||
},
|
||||
)
|
||||
self._rl._append(evt)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,175 @@
|
|||
"""Dual report store that persists to both local filesystem and MongoDB.
|
||||
|
||||
Delegates all save_* calls to both a :class:`ReportStore` and a
|
||||
:class:`MongoReportStore`. Load methods prioritize the MongoDB store if
|
||||
available, otherwise fall back to the filesystem.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from tradingagents.portfolio.report_store import ReportStore
|
||||
from tradingagents.portfolio.mongo_report_store import MongoReportStore
|
||||
|
||||
|
||||
class DualReportStore:
|
||||
"""Report store that writes to two backends simultaneously."""
|
||||
|
||||
def __init__(self, local_store: ReportStore, mongo_store: MongoReportStore) -> None:
|
||||
self._local = local_store
|
||||
self._mongo = mongo_store
|
||||
|
||||
@property
|
||||
def run_id(self) -> str | None:
|
||||
return self._local.run_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Macro Scan
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_scan(self, date: str, data: dict[str, Any]) -> Any:
|
||||
# local returns Path, mongo returns str (_id)
|
||||
local_result = self._local.save_scan(date, data)
|
||||
self._mongo.save_scan(date, data)
|
||||
return local_result
|
||||
|
||||
def load_scan(self, date: str) -> dict[str, Any] | None:
|
||||
return self._mongo.load_scan(date) or self._local.load_scan(date)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Per-Ticker Analysis
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_analysis(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
|
||||
local_result = self._local.save_analysis(date, ticker, data)
|
||||
self._mongo.save_analysis(date, ticker, data)
|
||||
return local_result
|
||||
|
||||
def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||
return self._mongo.load_analysis(date, ticker) or self._local.load_analysis(date, ticker)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Holding Reviews
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_holding_review(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
|
||||
local_result = self._local.save_holding_review(date, ticker, data)
|
||||
self._mongo.save_holding_review(date, ticker, data)
|
||||
return local_result
|
||||
|
||||
def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||
return self._mongo.load_holding_review(date, ticker) or self._local.load_holding_review(date, ticker)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Risk Metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_risk_metrics(self, date: str, portfolio_id: str, data: dict[str, Any]) -> Any:
|
||||
local_result = self._local.save_risk_metrics(date, portfolio_id, data)
|
||||
self._mongo.save_risk_metrics(date, portfolio_id, data)
|
||||
return local_result
|
||||
|
||||
def load_risk_metrics(self, date: str, portfolio_id: str) -> dict[str, Any] | None:
|
||||
return self._mongo.load_risk_metrics(date, portfolio_id) or self._local.load_risk_metrics(date, portfolio_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# PM Decisions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_pm_decision(
|
||||
self,
|
||||
date: str,
|
||||
portfolio_id: str,
|
||||
data: dict[str, Any],
|
||||
markdown: str | None = None,
|
||||
) -> Any:
|
||||
local_result = self._local.save_pm_decision(date, portfolio_id, data, markdown=markdown)
|
||||
self._mongo.save_pm_decision(date, portfolio_id, data, markdown=markdown)
|
||||
return local_result
|
||||
|
||||
def load_pm_decision(self, date: str, portfolio_id: str) -> dict[str, Any] | None:
|
||||
return self._mongo.load_pm_decision(date, portfolio_id) or self._local.load_pm_decision(date, portfolio_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Execution Results
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_execution_result(self, date: str, portfolio_id: str, data: dict[str, Any]) -> Any:
|
||||
local_result = self._local.save_execution_result(date, portfolio_id, data)
|
||||
self._mongo.save_execution_result(date, portfolio_id, data)
|
||||
return local_result
|
||||
|
||||
def load_execution_result(self, date: str, portfolio_id: str) -> dict[str, Any] | None:
|
||||
return self._mongo.load_execution_result(date, portfolio_id) or self._local.load_execution_result(date, portfolio_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Run Meta / Events persistence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_run_meta(self, date: str, data: dict[str, Any]) -> Any:
|
||||
local_result = self._local.save_run_meta(date, data)
|
||||
self._mongo.save_run_meta(date, data)
|
||||
return local_result
|
||||
|
||||
def load_run_meta(self, date: str) -> dict[str, Any] | None:
|
||||
return self._mongo.load_run_meta(date) or self._local.load_run_meta(date)
|
||||
|
||||
def save_run_events(self, date: str, events: list[dict[str, Any]]) -> Any:
|
||||
local_result = self._local.save_run_events(date, events)
|
||||
self._mongo.save_run_events(date, events)
|
||||
return local_result
|
||||
|
||||
def load_run_events(self, date: str) -> list[dict[str, Any]]:
|
||||
mongo_events = self._mongo.load_run_events(date)
|
||||
if mongo_events:
|
||||
return mongo_events
|
||||
return self._local.load_run_events(date)
|
||||
|
||||
def list_run_metas(self) -> list[dict[str, Any]]:
|
||||
mongo_metas = self._mongo.list_run_metas()
|
||||
if mongo_metas:
|
||||
return mongo_metas
|
||||
return self._local.list_run_metas()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Analyst / Trader Checkpoints
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_analysts_checkpoint(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
|
||||
local_result = self._local.save_analysts_checkpoint(date, ticker, data)
|
||||
self._mongo.save_analysts_checkpoint(date, ticker, data)
|
||||
return local_result
|
||||
|
||||
def load_analysts_checkpoint(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||
return self._mongo.load_analysts_checkpoint(date, ticker) or self._local.load_analysts_checkpoint(date, ticker)
|
||||
|
||||
def save_trader_checkpoint(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
|
||||
local_result = self._local.save_trader_checkpoint(date, ticker, data)
|
||||
self._mongo.save_trader_checkpoint(date, ticker, data)
|
||||
return local_result
|
||||
|
||||
def load_trader_checkpoint(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||
return self._mongo.load_trader_checkpoint(date, ticker) or self._local.load_trader_checkpoint(date, ticker)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Utility
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def clear_portfolio_stage(self, date: str, portfolio_id: str) -> list[str]:
|
||||
local_deleted = self._local.clear_portfolio_stage(date, portfolio_id)
|
||||
self._mongo.clear_portfolio_stage(date, portfolio_id)
|
||||
return local_deleted
|
||||
|
||||
def list_pm_decisions(self, portfolio_id: str) -> list[Any]:
|
||||
# Mongo returns dicts, Local returns Paths. Prefer Mongo for rich data.
|
||||
mongo_results = self._mongo.list_pm_decisions(portfolio_id)
|
||||
if mongo_results:
|
||||
return mongo_results
|
||||
return self._local.list_pm_decisions(portfolio_id)
|
||||
|
||||
def list_analyses_for_date(self, date: str) -> list[str]:
|
||||
# Both return list[str]
|
||||
return list(set(self._mongo.list_analyses_for_date(date)) | set(self._local.list_analyses_for_date(date)))
|
||||
|
|
@ -18,6 +18,7 @@ import os
|
|||
from typing import Union
|
||||
|
||||
from tradingagents.portfolio.report_store import ReportStore
|
||||
from tradingagents.portfolio.dual_report_store import DualReportStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -28,13 +29,13 @@ def create_report_store(
|
|||
base_dir: str | None = None,
|
||||
mongo_uri: str | None = None,
|
||||
mongo_db: str | None = None,
|
||||
) -> Union[ReportStore, "MongoReportStore"]: # noqa: F821
|
||||
) -> Union[ReportStore, "MongoReportStore", DualReportStore]: # noqa: F821
|
||||
"""Create and return the appropriate report store.
|
||||
|
||||
Resolution order for the backend:
|
||||
|
||||
1. If *mongo_uri* is passed explicitly, use MongoDB.
|
||||
2. If ``TRADINGAGENTS_MONGO_URI`` env var is set, use MongoDB.
|
||||
1. If *mongo_uri* is passed explicitly, use DualReportStore.
|
||||
2. If ``TRADINGAGENTS_MONGO_URI`` env var is set, use DualReportStore.
|
||||
3. Fall back to the filesystem :class:`ReportStore`.
|
||||
|
||||
Args:
|
||||
|
|
@ -44,32 +45,33 @@ def create_report_store(
|
|||
mongo_db: MongoDB database name (default ``"tradingagents"``).
|
||||
|
||||
Returns:
|
||||
A store instance (either ``ReportStore`` or ``MongoReportStore``).
|
||||
A store instance (either ``ReportStore`` or ``DualReportStore``).
|
||||
"""
|
||||
uri = mongo_uri or os.getenv("TRADINGAGENTS_MONGO_URI", "")
|
||||
db = mongo_db or os.getenv("TRADINGAGENTS_MONGO_DB", "tradingagents")
|
||||
|
||||
# Filesystem instance (always created as part of Dual or as standalone)
|
||||
_base = base_dir or os.getenv("PORTFOLIO_DATA_DIR") or os.getenv(
|
||||
"TRADINGAGENTS_REPORTS_DIR", "reports"
|
||||
)
|
||||
local_store = ReportStore(base_dir=_base, run_id=run_id)
|
||||
|
||||
if uri:
|
||||
try:
|
||||
from tradingagents.portfolio.mongo_report_store import MongoReportStore
|
||||
|
||||
store = MongoReportStore(
|
||||
mongo_store = MongoReportStore(
|
||||
connection_string=uri,
|
||||
db_name=db,
|
||||
run_id=run_id,
|
||||
)
|
||||
# ensure_indexes() is called automatically in __init__
|
||||
logger.info("Using MongoDB report store (db=%s, run_id=%s)", db, run_id)
|
||||
return store
|
||||
logger.info("Using Dual report store (local + MongoDB db=%s, run_id=%s)", db, run_id)
|
||||
return DualReportStore(local_store, mongo_store)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"MongoDB connection failed — falling back to filesystem store",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Filesystem fallback
|
||||
_base = base_dir or os.getenv("PORTFOLIO_DATA_DIR") or os.getenv(
|
||||
"TRADINGAGENTS_REPORTS_DIR", "reports"
|
||||
)
|
||||
logger.info("Using filesystem report store (base=%s, run_id=%s)", _base, run_id)
|
||||
return ReportStore(base_dir=_base, run_id=run_id)
|
||||
return local_store
|
||||
|
|
|
|||
Loading…
Reference in New Issue