From 5f0a52f8e6d8e6ae7c6a5e6e044dbe00a8b7e684 Mon Sep 17 00:00:00 2001
From: Ahmet Guzererler
Date: Wed, 25 Mar 2026 19:04:36 +0100
Subject: [PATCH] feat: enhance data persistence with DualReportStore for local and MongoDB storage; update report store creation logic

---
 agent_os/backend/routes/runs.py                | 108 +++++++++--
 agent_os/backend/services/langgraph_engine.py  |  40 ++--
 .../agents/portfolio/pm_decision_agent.py      |  24 ++-
 tradingagents/graph/portfolio_graph.py         |   5 +-
 tradingagents/observability.py                 |  50 ++++-
 tradingagents/portfolio/dual_report_store.py   | 175 ++++++++++++++++++
 tradingagents/portfolio/store_factory.py       |  28 +--
 7 files changed, 376 insertions(+), 54 deletions(-)
 create mode 100644 tradingagents/portfolio/dual_report_store.py

diff --git a/agent_os/backend/routes/runs.py b/agent_os/backend/routes/runs.py
index a2719506..5836f3f9 100644
--- a/agent_os/backend/routes/runs.py
+++ b/agent_os/backend/routes/runs.py
@@ -3,6 +3,7 @@ from typing import Dict, Any, List, AsyncGenerator
 import logging
 import uuid
 import time
+import os
 from agent_os.backend.store import runs
 from agent_os.backend.dependencies import get_current_user
 from agent_os.backend.services.langgraph_engine import LangGraphEngine, NODE_TO_PHASE
@@ -269,30 +270,99 @@ async def reset_portfolio_stage(
     return {"deleted": deleted, "date": date, "portfolio_id": portfolio_id}


+def _get_mongo_col():
+    """Return the run_events collection if MongoDB is configured."""
+    uri = os.getenv("TRADINGAGENTS_MONGO_URI")
+    db_name = os.getenv("TRADINGAGENTS_MONGO_DB", "tradingagents")
+    if uri:
+        try:
+            from pymongo import MongoClient
+            client = MongoClient(uri)
+            return client[db_name]["run_events"]
+        except Exception:
+            logger.warning("Failed to connect to MongoDB for historical events")
+    return None
+
+
 @router.get("/")
 async def list_runs(user: dict = Depends(get_current_user)):
     # Filter by user in production
-    return list(runs.values())
+    all_runs = dict(runs)
+
+    # Supplement with historical metadata from MongoDB if available
+    col = _get_mongo_col()
+    if col is not None:
+        try:
+            # Fetch the most recent unique run_ids from the events collection (simplified).
+            # In a real app, we'd have a separate 'runs' collection for metadata.
+            # Here we use the events collection and group by run_id.
+            pipeline = [
+                {"$match": {"type": "log", "agent": "SYSTEM"}},  # Filter for start logs
+                {"$sort": {"ts": -1}},
+                {"$group": {
+                    "_id": "$run_id",
+                    "id": {"$first": "$run_id"},
+                    "type": {"$first": "$type"},
+                    "created_at": {"$first": "$ts"},
+                    # Status is harder to get from events without a dedicated meta doc
+                }},
+                {"$limit": 50}
+            ]
+            for doc in col.aggregate(pipeline):
+                rid = doc["id"]
+                if rid not in all_runs:
+                    all_runs[rid] = {
+                        "id": rid,
+                        "type": doc.get("type", "unknown"),
+                        "status": "historical",
+                        "created_at": doc.get("created_at", 0),
+                        "user_id": "anonymous",
+                    }
+        except Exception:
+            logger.warning("Failed to fetch historical runs from MongoDB")
+
+    return list(all_runs.values())


 @router.get("/{run_id}")
 async def get_run_status(run_id: str, user: dict = Depends(get_current_user)):
-    if run_id not in runs:
-        raise HTTPException(status_code=404, detail="Run not found")
-    run = runs[run_id]
-    # Lazy-load events from disk if they were not kept in memory
-    if (
-        not run.get("events")
-        and run.get("status") in ("completed", "failed")
-    ):
+    if run_id in runs:
+        run = runs[run_id]
+        # Lazy-load events from disk if they were not kept in memory
+        if (
+            not run.get("events")
+            and run.get("status") in ("completed", "failed")
+        ):
+            try:
+                from tradingagents.portfolio.store_factory import create_report_store
+                short_rid = run.get("short_rid") or run_id[:8]
+                store = create_report_store(run_id=short_rid)
+                date = (run.get("params") or {}).get("date", "")
+                if date:
+                    events = store.load_run_events(date)
+                    if events:
+                        run["events"] = events
+            except Exception:
+                logger.warning("Failed to lazy-load events for run=%s", run_id)
+        return run
+
+    # Not in memory — try MongoDB
+    col = _get_mongo_col()
+    if col is not None:
         try:
-            from tradingagents.portfolio.store_factory import create_report_store
-            short_rid = run.get("short_rid") or run_id[:8]
-            store = create_report_store(run_id=short_rid)
-            date = (run.get("params") or {}).get("date", "")
-            if date:
-                events = store.load_run_events(date)
-                if events:
-                    run["events"] = events
+            cursor = col.find({"run_id": run_id}).sort("ts", 1)
+            events = list(cursor)
+            if events:
+                # Remove MongoDB _id for JSON serialization
+                for e in events:
+                    e.pop("_id", None)
+                return {
+                    "id": run_id,
+                    "status": "historical",
+                    "events": events,
+                    "type": events[0].get("type", "unknown") if events else "unknown",
+                    "created_at": events[0].get("ts", 0) if events else 0,
+                }
         except Exception:
-            logger.warning("Failed to lazy-load events for run=%s", run_id)
-    return run
+            logger.warning("Failed to fetch historical run %s from MongoDB", run_id)
+
+    raise HTTPException(status_code=404, detail="Run not found")
diff --git a/agent_os/backend/services/langgraph_engine.py b/agent_os/backend/services/langgraph_engine.py
index d3212a37..2d4d5fe2 100644
--- a/agent_os/backend/services/langgraph_engine.py
+++ b/agent_os/backend/services/langgraph_engine.py
@@ -175,9 +175,11 @@ class LangGraphEngine:
     # Run logger lifecycle
     # ------------------------------------------------------------------

-    def _start_run_logger(self, run_id: str) -> RunLogger:
+    def _start_run_logger(self, run_id: str, flow_id: str | None = None) -> RunLogger:
         """Create and register a ``RunLogger`` for the given run."""
-        rl = RunLogger()
+        uri = self.config.get("mongo_uri")
+        db = self.config.get("mongo_db") or "tradingagents"
+        rl = RunLogger(run_id=run_id, mongo_uri=uri, mongo_db=db, flow_id=flow_id)
         self._run_loggers[run_id] = rl
         set_run_logger(rl)
         return rl
@@ -209,7 +211,8 @@ class LangGraphEngine:
         short_rid = generate_run_id()
         store = create_report_store(run_id=short_rid)

-        rl = self._start_run_logger(run_id)
+        flow_id = params.get("flow_id")
+        rl = self._start_run_logger(run_id, flow_id=flow_id)
         scan_config = {**self.config}
         if params.get("max_tickers"):
             scan_config["max_auto_tickers"] = int(params["max_tickers"])
@@ -330,11 +333,11 @@ class LangGraphEngine:
         short_rid = generate_run_id()
         store = create_report_store(run_id=short_rid)

-        rl = self._start_run_logger(run_id)
+        flow_id = params.get("flow_id")
+        rl = self._start_run_logger(run_id, flow_id=flow_id)
+
+        logger.info("Starting PIPELINE run=%s ticker=%s date=%s rid=%s", run_id, ticker, date, short_rid)

-        logger.info(
-            "Starting PIPELINE run=%s ticker=%s date=%s rid=%s", run_id, ticker, date, short_rid
-        )
         yield self._system_log(f"Starting analysis pipeline for {ticker} on {date}")

         graph_wrapper = TradingAgentsGraph(
@@ -471,7 +474,8 @@ class LangGraphEngine:
         # A reader store with no run_id resolves to the latest run for loading
         reader_store = create_report_store()

-        rl = self._start_run_logger(run_id)
+        flow_id = params.get("flow_id")
+        rl = self._start_run_logger(run_id, flow_id=flow_id)

         logger.info(
             "Starting PORTFOLIO run=%s portfolio=%s date=%s rid=%s",
@@ -818,23 +822,27 @@ class LangGraphEngine:
         """Run the full auto pipeline: scan → pipeline → portfolio."""

         date = params.get("date", time.strftime("%Y-%m-%d"))
         force = params.get("force", False)
+        flow_id = params.get("flow_id") or str(uuid.uuid4())

         # Use a reader store (no run_id) for skip-if-exists checks.
         # Each sub-phase (run_scan, run_pipeline, run_portfolio) creates
         # its own writer store with a fresh run_id internally.
         store = create_report_store()
-        self._start_run_logger(run_id)  # auto-run's own logger; sub-phases create their own
+        self._start_run_logger(run_id, flow_id=flow_id)  # auto-run's own logger; sub-phases create their own

-        logger.info("Starting AUTO run=%s date=%s force=%s", run_id, date, force)
-        yield self._system_log(f"Starting full auto workflow for {date} (force={force})")
+        logger.info("Starting AUTO run=%s flow=%s date=%s force=%s", run_id, flow_id, date, force)
+        yield self._system_log(f"Starting full auto workflow for {date} (force={force}, flow={flow_id})")

         # Phase 1: Market scan
         yield self._system_log("Phase 1/3: Running market scan…")
         if not force and store.load_scan(date):
             yield self._system_log(f"Phase 1: Macro scan for {date} already exists, skipping.")
         else:
-            async for evt in self.run_scan(f"{run_id}_scan", {"date": date}):
+            scan_params = {"date": date, "flow_id": flow_id}
+            if params.get("max_tickers"):
+                scan_params["max_tickers"] = params["max_tickers"]
+            async for evt in self.run_scan(f"{run_id}_scan", scan_params):
                 yield evt

         # Phase 2: Pipeline analysis — get tickers from scan report + portfolio holdings
@@ -917,7 +925,8 @@ class LangGraphEngine:
             )
             try:
                 async for evt in self.run_pipeline(
-                    f"{run_id}_pipeline_{ticker}", {"ticker": ticker, "date": date}
+                    f"{run_id}_pipeline_{ticker}",
+                    {"ticker": ticker, "date": date, "flow_id": flow_id},
                 ):
                     await pipeline_queue.put(evt)
             except Exception as exc:
@@ -943,7 +952,7 @@ class LangGraphEngine:
                 try:
                     async for evt in self.run_pipeline(
                         f"{run_id}_fallback_{ticker}",
-                        {"ticker": ticker, "date": date},
+                        {"ticker": ticker, "date": date, "flow_id": flow_id},
                     ):
                         await pipeline_queue.put(evt)
                 except Exception as fallback_exc:
@@ -993,6 +1002,7 @@ class LangGraphEngine:
         # Phase 3: Portfolio management
         yield self._system_log("Phase 3/3: Running portfolio manager…")
         portfolio_params = {k: v for k, v in params.items() if k != "ticker"}
+        portfolio_params["flow_id"] = flow_id
         portfolio_id = params.get("portfolio_id", "main_portfolio")

         # Check if portfolio stage is fully complete (execution result exists)
@@ -1018,7 +1028,7 @@ class LangGraphEngine:
                 yield evt

         logger.info("Completed AUTO run=%s", run_id)
-        self._finish_run_logger(run_id, get_daily_dir(date))
+        self._finish_run_logger(run_id, get_daily_dir(date, run_id=flow_id[:8]))

     # ------------------------------------------------------------------
     # Report helpers
diff --git a/tradingagents/agents/portfolio/pm_decision_agent.py b/tradingagents/agents/portfolio/pm_decision_agent.py
index c83ec672..7feb6820 100644
--- a/tradingagents/agents/portfolio/pm_decision_agent.py
+++ b/tradingagents/agents/portfolio/pm_decision_agent.py
@@ -18,15 +18,23 @@ from tradingagents.agents.utils.json_utils import extract_json
 logger = logging.getLogger(__name__)


-def create_pm_decision_agent(llm):
+def create_pm_decision_agent(llm, config: dict | None = None):
     """Create a PM decision agent node.

     Args:
         llm: A LangChain chat model instance (deep_think recommended).
+        config: Portfolio configuration dictionary containing constraints.

     Returns:
         A node function ``pm_decision_node(state)`` compatible with LangGraph.
     """
+    cfg = config or {}
+    constraints_str = (
+        f"- Max position size: {cfg.get('max_position_pct', 0.15):.0%}\n"
+        f"- Max sector exposure: {cfg.get('max_sector_pct', 0.35):.0%}\n"
+        f"- Minimum cash reserve: {cfg.get('min_cash_pct', 0.05):.0%}\n"
+        f"- Max total positions: {cfg.get('max_positions', 15)}\n"
+    )

     def pm_decision_node(state):
         analysis_date = state.get("analysis_date") or ""
@@ -35,7 +43,10 @@ def create_pm_decision_agent(llm):
         holding_reviews_str = state.get("holding_reviews") or "{}"
         prioritized_candidates_str = state.get("prioritized_candidates") or "[]"

-        context = f"""## Portfolio Data
+        context = f"""## Portfolio Constraints
+{constraints_str}
+
+## Portfolio Data
 {portfolio_data_str}

 ## Risk Metrics
@@ -50,8 +61,13 @@ def create_pm_decision_agent(llm):

         system_message = (
             "You are a portfolio manager making final investment decisions. "
-            "Given the risk metrics, holding reviews, and prioritized investment candidates, "
+            "Given the constraints, risk metrics, holding reviews, and prioritized investment candidates, "
             "produce a structured JSON investment decision. "
+            "## CONSTRAINTS COMPLIANCE:\n"
+            "You MUST ensure your suggested buys and position sizes adhere to the portfolio constraints. "
+            "If a high-conviction candidate would exceed the max position size or sector limit, "
+            "YOU MUST adjust the suggested 'shares' downward to fit within the limit. "
+            "Do not suggest buys that you know will be rejected by the risk engine.\n\n"
             "Consider: reducing risk where metrics are poor, acting on SELL recommendations, "
             "and adding positions in high-conviction candidates that pass constraints. "
             "For every BUY you MUST set a stop_loss price (maximum acceptable loss level, "
@@ -69,7 +85,7 @@ def create_pm_decision_agent(llm):
             ' "risk_summary": "..."\n'
             "}\n\n"
             "IMPORTANT: Output ONLY valid JSON. Start your response with '{' and end with '}'. "
-            "Do NOT use markdown code fences. Do NOT include any explanation before or after the JSON.\n\n"
+            "Do NOT use markdown code fences. Do NOT include any explanation or preamble before or after the JSON.\n\n"
             f"{context}"
         )

diff --git a/tradingagents/graph/portfolio_graph.py b/tradingagents/graph/portfolio_graph.py
index 7d91af84..583bbc34 100644
--- a/tradingagents/graph/portfolio_graph.py
+++ b/tradingagents/graph/portfolio_graph.py
@@ -48,12 +48,13 @@ class PortfolioGraph:
         mid_llm = self._create_llm("mid_think")
         deep_llm = self._create_llm("deep_think")

+        portfolio_config = self._get_portfolio_config()
+
         agents = {
             "review_holdings": create_holding_reviewer(mid_llm),
-            "pm_decision": create_pm_decision_agent(deep_llm),
+            "pm_decision": create_pm_decision_agent(deep_llm, config=portfolio_config),
         }

-        portfolio_config = self._get_portfolio_config()
         setup = PortfolioGraphSetup(agents, repo=self._repo, config=portfolio_config)
         self.graph = setup.setup_graph()

diff --git a/tradingagents/observability.py b/tradingagents/observability.py
index 3cb90a7f..3f7e8e3e 100644
--- a/tradingagents/observability.py
+++ b/tradingagents/observability.py
@@ -60,11 +60,29 @@ class RunLogger:
         events: Thread-safe list of all recorded events.
     """

-    def __init__(self) -> None:
+    def __init__(
+        self,
+        run_id: str | None = None,
+        mongo_uri: str | None = None,
+        mongo_db: str | None = "tradingagents",
+        flow_id: str | None = None,
+    ) -> None:
         self._lock = threading.Lock()
         self.events: list[_Event] = []
         self.callback = _LLMCallbackHandler(self)
         self._start = time.time()
+        self.run_id = run_id
+        self.flow_id = flow_id
+        self._mongo_col = None
+
+        if mongo_uri and run_id:
+            try:
+                from pymongo import MongoClient
+                client = MongoClient(mongo_uri)
+                self._mongo_col = client[mongo_db]["run_events"]
+                _py_logger.info("RunLogger: persisting events to MongoDB (run_id=%s, flow_id=%s)", run_id, flow_id)
+            except Exception as exc:
+                _py_logger.warning("RunLogger: MongoDB connection failed: %s", exc)

     # -- public helpers to record events from non-callback code ----------------

@@ -197,6 +215,16 @@ class RunLogger:
         self.events.append(evt)
         _py_logger.debug("%s | %s", evt.kind, json.dumps(evt.data))

+        if self._mongo_col is not None and self.run_id:
+            try:
+                doc = evt.to_dict()
+                doc["run_id"] = self.run_id
+                if self.flow_id:
+                    doc["flow_id"] = self.flow_id
+                self._mongo_col.insert_one(doc)
+            except Exception as exc:
+                _py_logger.warning("RunLogger: MongoDB insert failed: %s", exc)
+

 # ──────────────────────────────────────────────────────────────────────────────
 # LangChain callback handler — captures LLM call details
@@ -225,11 +253,24 @@ class _LLMCallbackHandler(BaseCallbackHandler):
         model = _extract_model(serialized, kwargs)
         agent = kwargs.get("name") or serialized.get("name") or _extract_graph_node(kwargs)
         key = str(run_id) if run_id else str(id(messages))
+
+        # Capture prompt content
+        prompt = ""
+        try:
+            if messages and isinstance(messages[0], list):
+                # batched messages
+                prompt = "\n\n".join([str(m.content) if hasattr(m, "content") else str(m) for m in messages[0]])
+            else:
+                prompt = "\n\n".join([str(m.content) if hasattr(m, "content") else str(m) for m in messages])
+        except Exception:
+            pass
+
         with self._lock:
             self._inflight[key] = {
                 "model": model,
                 "agent": agent or "",
                 "t0": time.time(),
+                "prompt": prompt,
             }

     # -- legacy LLM start (completion-style) -----------------------------------
@@ -250,6 +291,7 @@ class _LLMCallbackHandler(BaseCallbackHandler):
                 "model": model,
                 "agent": agent or "",
                 "t0": time.time(),
+                "prompt": prompts[0] if prompts else "",
             }

     # -- LLM end ---------------------------------------------------------------
@@ -262,10 +304,13 @@ class _LLMCallbackHandler(BaseCallbackHandler):
         tokens_in = 0
         tokens_out = 0
         model_from_response = ""
+        full_response = ""
         try:
             generation = response.generations[0][0]
+            full_response = generation.text if hasattr(generation, "text") else str(generation)
             if hasattr(generation, "message"):
                 msg = generation.message
+                full_response = msg.content if hasattr(msg, "content") else full_response
                 if isinstance(msg, AIMessage) and hasattr(msg, "usage_metadata") and msg.usage_metadata:
                     tokens_in = msg.usage_metadata.get("input_tokens", 0)
                     tokens_out = msg.usage_metadata.get("output_tokens", 0)
@@ -277,6 +322,7 @@ class _LLMCallbackHandler(BaseCallbackHandler):
         model = model_from_response or (meta["model"] if meta else "unknown")
         agent = meta["agent"] if meta else ""
         duration_ms = (time.time() - meta["t0"]) * 1000 if meta else 0
+        prompt = meta["prompt"] if meta else ""

         evt = _Event(
             kind="llm",
@@ -287,6 +333,8 @@ class _LLMCallbackHandler(BaseCallbackHandler):
                 "tokens_in": tokens_in,
                 "tokens_out": tokens_out,
                 "duration_ms": round(duration_ms, 1),
+                "prompt": prompt,
+                "response": full_response,
             },
         )
         self._rl._append(evt)
diff --git a/tradingagents/portfolio/dual_report_store.py b/tradingagents/portfolio/dual_report_store.py
new file mode 100644
index 00000000..3f953ace
--- /dev/null
+++ b/tradingagents/portfolio/dual_report_store.py
@@ -0,0 +1,175 @@
+"""Dual report store that persists to both local filesystem and MongoDB.
+
+Delegates all save_* calls to both a :class:`ReportStore` and a
+:class:`MongoReportStore`. Load methods prioritize the MongoDB store if
+available, otherwise fall back to the filesystem.
+"""
+
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from pathlib import Path
+    from tradingagents.portfolio.report_store import ReportStore
+    from tradingagents.portfolio.mongo_report_store import MongoReportStore
+
+
+class DualReportStore:
+    """Report store that writes to two backends simultaneously."""
+
+    def __init__(self, local_store: ReportStore, mongo_store: MongoReportStore) -> None:
+        self._local = local_store
+        self._mongo = mongo_store
+
+    @property
+    def run_id(self) -> str | None:
+        return self._local.run_id
+
+    # ------------------------------------------------------------------
+    # Macro Scan
+    # ------------------------------------------------------------------
+
+    def save_scan(self, date: str, data: dict[str, Any]) -> Any:
+        # local returns Path, mongo returns str (_id)
+        local_result = self._local.save_scan(date, data)
+        self._mongo.save_scan(date, data)
+        return local_result
+
+    def load_scan(self, date: str) -> dict[str, Any] | None:
+        return self._mongo.load_scan(date) or self._local.load_scan(date)
+
+    # ------------------------------------------------------------------
+    # Per-Ticker Analysis
+    # ------------------------------------------------------------------
+
+    def save_analysis(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
+        local_result = self._local.save_analysis(date, ticker, data)
+        self._mongo.save_analysis(date, ticker, data)
+        return local_result
+
+    def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None:
+        return self._mongo.load_analysis(date, ticker) or self._local.load_analysis(date, ticker)
+
+    # ------------------------------------------------------------------
+    # Holding Reviews
+    # ------------------------------------------------------------------
+
+    def save_holding_review(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
+        local_result = self._local.save_holding_review(date, ticker, data)
+        self._mongo.save_holding_review(date, ticker, data)
+        return local_result
+
+    def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None:
+        return self._mongo.load_holding_review(date, ticker) or self._local.load_holding_review(date, ticker)
+
+    # ------------------------------------------------------------------
+    # Risk Metrics
+    # ------------------------------------------------------------------
+
+    def save_risk_metrics(self, date: str, portfolio_id: str, data: dict[str, Any]) -> Any:
+        local_result = self._local.save_risk_metrics(date, portfolio_id, data)
+        self._mongo.save_risk_metrics(date, portfolio_id, data)
+        return local_result
+
+    def load_risk_metrics(self, date: str, portfolio_id: str) -> dict[str, Any] | None:
+        return self._mongo.load_risk_metrics(date, portfolio_id) or self._local.load_risk_metrics(date, portfolio_id)
+
+    # ------------------------------------------------------------------
+    # PM Decisions
+    # ------------------------------------------------------------------
+
+    def save_pm_decision(
+        self,
+        date: str,
+        portfolio_id: str,
+        data: dict[str, Any],
+        markdown: str | None = None,
+    ) -> Any:
+        local_result = self._local.save_pm_decision(date, portfolio_id, data, markdown=markdown)
+        self._mongo.save_pm_decision(date, portfolio_id, data, markdown=markdown)
+        return local_result
+
+    def load_pm_decision(self, date: str, portfolio_id: str) -> dict[str, Any] | None:
+        return self._mongo.load_pm_decision(date, portfolio_id) or self._local.load_pm_decision(date, portfolio_id)
+
+    # ------------------------------------------------------------------
+    # Execution Results
+    # ------------------------------------------------------------------
+
+    def save_execution_result(self, date: str, portfolio_id: str, data: dict[str, Any]) -> Any:
+        local_result = self._local.save_execution_result(date, portfolio_id, data)
+        self._mongo.save_execution_result(date, portfolio_id, data)
+        return local_result
+
+    def load_execution_result(self, date: str, portfolio_id: str) -> dict[str, Any] | None:
+        return self._mongo.load_execution_result(date, portfolio_id) or self._local.load_execution_result(date, portfolio_id)
+
+    # ------------------------------------------------------------------
+    # Run Meta / Events persistence
+    # ------------------------------------------------------------------
+
+    def save_run_meta(self, date: str, data: dict[str, Any]) -> Any:
+        local_result = self._local.save_run_meta(date, data)
+        self._mongo.save_run_meta(date, data)
+        return local_result
+
+    def load_run_meta(self, date: str) -> dict[str, Any] | None:
+        return self._mongo.load_run_meta(date) or self._local.load_run_meta(date)
+
+    def save_run_events(self, date: str, events: list[dict[str, Any]]) -> Any:
+        local_result = self._local.save_run_events(date, events)
+        self._mongo.save_run_events(date, events)
+        return local_result
+
+    def load_run_events(self, date: str) -> list[dict[str, Any]]:
+        mongo_events = self._mongo.load_run_events(date)
+        if mongo_events:
+            return mongo_events
+        return self._local.load_run_events(date)
+
+    def list_run_metas(self) -> list[dict[str, Any]]:
+        mongo_metas = self._mongo.list_run_metas()
+        if mongo_metas:
+            return mongo_metas
+        return self._local.list_run_metas()
+
+    # ------------------------------------------------------------------
+    # Analyst / Trader Checkpoints
+    # ------------------------------------------------------------------
+
+    def save_analysts_checkpoint(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
+        local_result = self._local.save_analysts_checkpoint(date, ticker, data)
+        self._mongo.save_analysts_checkpoint(date, ticker, data)
+        return local_result
+
+    def load_analysts_checkpoint(self, date: str, ticker: str) -> dict[str, Any] | None:
+        return self._mongo.load_analysts_checkpoint(date, ticker) or self._local.load_analysts_checkpoint(date, ticker)
+
+    def save_trader_checkpoint(self, date: str, ticker: str, data: dict[str, Any]) -> Any:
+        local_result = self._local.save_trader_checkpoint(date, ticker, data)
+        self._mongo.save_trader_checkpoint(date, ticker, data)
+        return local_result
+
+    def load_trader_checkpoint(self, date: str, ticker: str) -> dict[str, Any] | None:
+        return self._mongo.load_trader_checkpoint(date, ticker) or self._local.load_trader_checkpoint(date, ticker)
+
+    # ------------------------------------------------------------------
+    # Utility
+    # ------------------------------------------------------------------
+
+    def clear_portfolio_stage(self, date: str, portfolio_id: str) -> list[str]:
+        local_deleted = self._local.clear_portfolio_stage(date, portfolio_id)
+        self._mongo.clear_portfolio_stage(date, portfolio_id)
+        return local_deleted
+
+    def list_pm_decisions(self, portfolio_id: str) -> list[Any]:
+        # Mongo returns dicts, Local returns Paths. Prefer Mongo for rich data.
+        mongo_results = self._mongo.list_pm_decisions(portfolio_id)
+        if mongo_results:
+            return mongo_results
+        return self._local.list_pm_decisions(portfolio_id)
+
+    def list_analyses_for_date(self, date: str) -> list[str]:
+        # Both return list[str]
+        return list(set(self._mongo.list_analyses_for_date(date)) | set(self._local.list_analyses_for_date(date)))
diff --git a/tradingagents/portfolio/store_factory.py b/tradingagents/portfolio/store_factory.py
index 858c1a4f..34810f98 100644
--- a/tradingagents/portfolio/store_factory.py
+++ b/tradingagents/portfolio/store_factory.py
@@ -18,6 +18,7 @@ import os
 from typing import Union

 from tradingagents.portfolio.report_store import ReportStore
+from tradingagents.portfolio.dual_report_store import DualReportStore

 logger = logging.getLogger(__name__)

@@ -28,13 +29,13 @@ def create_report_store(
     base_dir: str | None = None,
     mongo_uri: str | None = None,
     mongo_db: str | None = None,
-) -> Union[ReportStore, "MongoReportStore"]:  # noqa: F821
+) -> Union[ReportStore, "MongoReportStore", DualReportStore]:  # noqa: F821
     """Create and return the appropriate report store.

     Resolution order for the backend:

-    1. If *mongo_uri* is passed explicitly, use MongoDB.
-    2. If ``TRADINGAGENTS_MONGO_URI`` env var is set, use MongoDB.
+    1. If *mongo_uri* is passed explicitly, use DualReportStore.
+    2. If ``TRADINGAGENTS_MONGO_URI`` env var is set, use DualReportStore.
     3. Fall back to the filesystem :class:`ReportStore`.

     Args:
@@ -44,32 +45,33 @@ def create_report_store(
         mongo_db: MongoDB database name (default ``"tradingagents"``).

     Returns:
-        A store instance (either ``ReportStore`` or ``MongoReportStore``).
+        A store instance (either ``ReportStore`` or ``DualReportStore``).
     """
     uri = mongo_uri or os.getenv("TRADINGAGENTS_MONGO_URI", "")
     db = mongo_db or os.getenv("TRADINGAGENTS_MONGO_DB", "tradingagents")

+    # Filesystem instance (always created as part of Dual or as standalone)
+    _base = base_dir or os.getenv("PORTFOLIO_DATA_DIR") or os.getenv(
+        "TRADINGAGENTS_REPORTS_DIR", "reports"
+    )
+    local_store = ReportStore(base_dir=_base, run_id=run_id)
+
     if uri:
         try:
             from tradingagents.portfolio.mongo_report_store import MongoReportStore
-            store = MongoReportStore(
+            mongo_store = MongoReportStore(
                 connection_string=uri,
                 db_name=db,
                 run_id=run_id,
             )
-            # ensure_indexes() is called automatically in __init__
-            logger.info("Using MongoDB report store (db=%s, run_id=%s)", db, run_id)
-            return store
+            logger.info("Using Dual report store (local + MongoDB db=%s, run_id=%s)", db, run_id)
+            return DualReportStore(local_store, mongo_store)
         except Exception:
             logger.warning(
                 "MongoDB connection failed — falling back to filesystem store",
                 exc_info=True,
             )

-    # Filesystem fallback
-    _base = base_dir or os.getenv("PORTFOLIO_DATA_DIR") or os.getenv(
-        "TRADINGAGENTS_REPORTS_DIR", "reports"
-    )
     logger.info("Using filesystem report store (base=%s, run_id=%s)", _base, run_id)
-    return ReportStore(base_dir=_base, run_id=run_id)
+    return local_store