From eb34152d558d560c1efab20726dc9abc979ee362 Mon Sep 17 00:00:00 2001 From: Kevin Bruton Date: Tue, 30 Sep 2025 11:39:07 +0200 Subject: [PATCH] feat: Add WebSockets - add better error handling and retry behaviour --- README.md | 63 +++++ memory_store/chroma.sqlite3 | Bin 163840 -> 163840 bytes .../agents/managers/research_manager.py | 3 +- tradingagents/agents/managers/risk_manager.py | 3 +- .../agents/managers/trade_planner.py | 5 +- .../agents/researchers/bear_researcher.py | 4 +- .../agents/researchers/bull_researcher.py | 4 +- .../agents/risk_mgmt/aggresive_debator.py | 4 +- .../agents/risk_mgmt/conservative_debator.py | 4 +- .../agents/risk_mgmt/neutral_debator.py | 4 +- tradingagents/agents/trader/trader.py | 4 +- tradingagents/agents/utils/safe_llm.py | 92 +++++++ webapp/main.py | 248 ++++++++++++++++-- webapp/templates/index.html | 181 ++++++++++--- 14 files changed, 539 insertions(+), 80 deletions(-) create mode 100644 tradingagents/agents/utils/safe_llm.py diff --git a/README.md b/README.md index 2f65bbb6..1c815244 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,23 @@ If you leave these fields blank or select "No Open Position", the system will ge 5. Enter a company symbol (e.g., `AAPL`) in the configuration form and click "Start Process" to begin the analysis. 6. (Optional) If you have an open position, select Long/Short and enter existing stop-loss / take-profit so the final decision can include management guidance. +### Real-Time Updates (WebSockets) + +The web frontend now uses a WebSocket channel (`/ws`) for real-time status and progress updates instead of relying solely on periodic HTTP polling. + +Benefits: +- Lower latency updates as each agent completes +- Reduced network overhead vs. 2s polling +- Automatic retry with exponential backoff if the socket drops +- Graceful fallback to legacy polling if WebSockets are unavailable + +Client behavior: +- On connection, the server sends an `init` payload with the current execution tree and progress. +- Subsequent incremental updates are sent as `status_update` messages. +- When you click an item, the existing HTMX request still works; alternatively, the client can request content over the socket using `{ "action": "get_content", "item_id": "..." }`. + +If you need to disable WebSockets (e.g., for debugging a proxy), you can block the `/ws` path and the client will automatically revert to polling. + ### Rendered Reports (Markdown Support) Agent-generated reports (analysis summaries, debate histories, plans, and risk assessments) are produced in Markdown. The web frontend now renders these Markdown documents as styled HTML instead of showing raw markup. This includes support for: @@ -191,6 +208,52 @@ Agent-generated reports (analysis summaries, debate histories, plans, and risk a Security: Markdown is sanitized server‑side using `bleach` to strip unsafe tags/attributes while preserving semantic structure. If you need to extend allowed tags (e.g., to permit additional formatting), modify `ALLOWED_TAGS` / `ALLOWED_ATTRIBUTES` in `webapp/main.py`. +### LLM Invocation Reliability (Automatic Retry Layer) + +Many agent nodes perform JSON-heavy LLM calls that can occasionally fail due to transient network issues (timeouts, dropped connections) or incomplete JSON payloads returned by the provider. To reduce user-facing errors and noisy red states in the execution tree, TradingAgents wraps model calls with a lightweight exponential backoff retry helper. 
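+
+At the call sites, agent nodes swap the direct `llm.invoke(...)` call for the wrapper. A minimal sketch of the drop-in pattern used across the agent modules (here `llm` and `prompt` stand for whatever objects the node already builds):
+
+```python
+from tradingagents.agents.utils.safe_llm import safe_invoke_llm
+
+# previously: response = llm.invoke(prompt)
+response = safe_invoke_llm(llm, prompt)  # retried on transient decode/network errors
+```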
+ +Core implementation: `safe_invoke_llm` in `tradingagents/agents/utils/safe_llm.py`. + +Default behavior: +- Retries up to 4 attempts (configurable) on a targeted set of transient exceptions. +- Backoff: exponential (base 0.75s) with ±30% jitter, capped at 8s. +- Immediate propagation for non-transient errors (logical / prompt / auth failures aren’t retried). + +Transient exception classes handled: +- `json.JSONDecodeError` (malformed or truncated JSON) +- `httpx.TimeoutException` +- `httpx.ConnectError` +- `httpx.NetworkError` (if available in the installed httpx version) +- Heuristic: any exception message containing both `Expecting value` and `json` (covers provider-specific wrappers) + +Why this matters: +- Prevents single flaky decode from aborting an entire multi-agent debate or risk evaluation phase. +- Smooths over brief provider-side instabilities and network blips without user intervention. +- Reduces false-negative failure attribution in the UI. + +Customization: +You can supply a custom `LLMRetryConfig` if a node needs different resilience parameters: +```python +from tradingagents.agents.utils.safe_llm import safe_invoke_llm, LLMRetryConfig + +cfg = LLMRetryConfig(max_attempts=6, base_delay=0.5, max_delay=10.0, jitter=0.25) +response = safe_invoke_llm(llm, prompt, cfg) +``` + +Disabling or tightening: +- To effectively disable retries for debugging, set `max_attempts=1`. +- For latency-sensitive quick-think paths, you can lower `max_attempts` or `max_delay`. + +Logging & observability (future enhancement): +- Currently, retries are silent except for aggregate timing impact. If you need visibility, wrap `safe_invoke_llm` and add structured logging around each attempt. + +Edge cases not retried: +- Authentication / quota errors +- Deterministic validation failures in downstream parsing +- Prompt formatting errors (these should be fixed at the source) + +If you encounter a failure pattern you believe should be considered transient, you can extend `TRANSIENT_EXCEPTION_TYPES` inside `safe_llm.py`. + ## TradingAgents Package diff --git a/memory_store/chroma.sqlite3 b/memory_store/chroma.sqlite3 index 51ccff21d2d98b421f5b463edccc1e335215332e..3e345e0142d2e350d8a357ca21b0a939e401ca34 100644 GIT binary patch delta 82 zcmZo@;A&{#njp>SI#I@%(RE|O5`AVHzAKa24f478UhqBUyT^Bv@5*LFg%CbTW*cS( ePG)NmVFe;AL4*a6Xnt$g{??9h`&&CEVMhSMz!jwc delta 45 zcmZo@;A&{#njp>SFj2;t(P3l45`AWKzPFRv4e~b|DwOjzzqM dict: @@ -43,8 +44,7 @@ Based on your analysis, provide the stop-loss and take-profit levels in a JSON f The stop-loss level is mandatory. The take-profit level is optional. Do not provide any other information or explanation. ''' - - response = llm.invoke(prompt) + response = safe_invoke_llm(llm, prompt) try: levels = json.loads(response.content) @@ -54,7 +54,6 @@ Do not provide any other information or explanation. 
stop_loss = None take_profit = None - return { "stop_loss": stop_loss, "take_profit": take_profit, diff --git a/tradingagents/agents/researchers/bear_researcher.py b/tradingagents/agents/researchers/bear_researcher.py index 5af7db39..196c7c82 100644 --- a/tradingagents/agents/researchers/bear_researcher.py +++ b/tradingagents/agents/researchers/bear_researcher.py @@ -1,6 +1,7 @@ from langchain_core.messages import AIMessage import time import json +from tradingagents.agents.utils.safe_llm import safe_invoke_llm def create_bear_researcher(llm, memory): @@ -53,8 +54,7 @@ Last bull argument: {current_response} Reflections from similar situations and lessons learned: {past_memory_str} Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past. """ - - response = llm.invoke(prompt) + response = safe_invoke_llm(llm, prompt) argument = f"Bear Analyst: {response.content}" diff --git a/tradingagents/agents/researchers/bull_researcher.py b/tradingagents/agents/researchers/bull_researcher.py index f3b1098c..64c544f2 100644 --- a/tradingagents/agents/researchers/bull_researcher.py +++ b/tradingagents/agents/researchers/bull_researcher.py @@ -1,6 +1,7 @@ from langchain_core.messages import AIMessage import time import json +from tradingagents.agents.utils.safe_llm import safe_invoke_llm def create_bull_researcher(llm, memory): @@ -51,8 +52,7 @@ Last bear argument: {current_response} Reflections from similar situations and lessons learned: {past_memory_str} Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past. """ - - response = llm.invoke(prompt) + response = safe_invoke_llm(llm, prompt) argument = f"Bull Analyst: {response.content}" diff --git a/tradingagents/agents/risk_mgmt/aggresive_debator.py b/tradingagents/agents/risk_mgmt/aggresive_debator.py index f07599d7..918f5030 100644 --- a/tradingagents/agents/risk_mgmt/aggresive_debator.py +++ b/tradingagents/agents/risk_mgmt/aggresive_debator.py @@ -1,5 +1,6 @@ import time import json +from tradingagents.agents.utils.safe_llm import safe_invoke_llm def create_risky_debator(llm): @@ -41,8 +42,7 @@ Company Fundamentals Report: {fundamentals_report} Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point. Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. 
Output conversationally as if you are speaking without any special formatting.""" - - response = llm.invoke(prompt) + response = safe_invoke_llm(llm, prompt) argument = f"Risky Analyst: {response.content}" diff --git a/tradingagents/agents/risk_mgmt/conservative_debator.py b/tradingagents/agents/risk_mgmt/conservative_debator.py index a17cc66e..7867743d 100644 --- a/tradingagents/agents/risk_mgmt/conservative_debator.py +++ b/tradingagents/agents/risk_mgmt/conservative_debator.py @@ -1,6 +1,7 @@ from langchain_core.messages import AIMessage import time import json +from tradingagents.agents.utils.safe_llm import safe_invoke_llm def create_safe_debator(llm): @@ -42,8 +43,7 @@ Company Fundamentals Report: {fundamentals_report} Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point. Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting.""" - - response = llm.invoke(prompt) + response = safe_invoke_llm(llm, prompt) argument = f"Safe Analyst: {response.content}" diff --git a/tradingagents/agents/risk_mgmt/neutral_debator.py b/tradingagents/agents/risk_mgmt/neutral_debator.py index 6bc6eef2..3d2b4452 100644 --- a/tradingagents/agents/risk_mgmt/neutral_debator.py +++ b/tradingagents/agents/risk_mgmt/neutral_debator.py @@ -1,5 +1,6 @@ import time import json +from tradingagents.agents.utils.safe_llm import safe_invoke_llm def create_neutral_debator(llm): @@ -41,8 +42,7 @@ Company Fundamentals Report: {fundamentals_report} Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point. Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. 
Output conversationally as if you are speaking without any special formatting.""" - - response = llm.invoke(prompt) + response = safe_invoke_llm(llm, prompt) argument = f"Neutral Analyst: {response.content}" diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index 9a0aa51d..0db12ab6 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -1,6 +1,7 @@ import functools import time import json +from tradingagents.agents.utils.safe_llm import safe_invoke_llm def create_trader(llm, memory): @@ -48,8 +49,7 @@ Your output should always be in markdown format.""", }, context, ] - - result = llm.invoke(messages) + result = safe_invoke_llm(llm, messages) return { "messages": [result], diff --git a/tradingagents/agents/utils/safe_llm.py b/tradingagents/agents/utils/safe_llm.py new file mode 100644 index 00000000..bd9220d1 --- /dev/null +++ b/tradingagents/agents/utils/safe_llm.py @@ -0,0 +1,92 @@ +import time +import random +from typing import Any, Callable, Sequence, Union +import json + +# Define which exceptions to treat as transient +try: + import httpx # type: ignore +except Exception: # pragma: no cover + httpx = None # fallback if not installed (but project includes it transitively) + +TRANSIENT_EXCEPTION_TYPES = [] +if httpx: + TRANSIENT_EXCEPTION_TYPES.extend([ + httpx.TimeoutException, + httpx.ConnectError, + httpx.NetworkError if hasattr(httpx, 'NetworkError') else Exception, # broad fallback + ]) + +# Always include JSON decode errors +from json import JSONDecodeError +TRANSIENT_EXCEPTION_TYPES.append(JSONDecodeError) + + +class LLMRetryConfig: + def __init__( + self, + max_attempts: int = 4, + base_delay: float = 0.75, + max_delay: float = 8.0, + jitter: float = 0.3, + ): + self.max_attempts = max_attempts + self.base_delay = base_delay + self.max_delay = max_delay + self.jitter = jitter + + +def _compute_backoff(attempt: int, cfg: LLMRetryConfig) -> float: + # Exponential backoff with jitter + delay = min(cfg.base_delay * (2 ** (attempt - 1)), cfg.max_delay) + if cfg.jitter: + delta = delay * cfg.jitter + delay = random.uniform(delay - delta, delay + delta) + return max(0.05, delay) + + +def safe_invoke_llm(llm: Any, payload: Union[str, Sequence[dict]], cfg: LLMRetryConfig | None = None): + """Invoke an LLM with retries for transient decode/network errors. + + Parameters + ---------- + llm : Any + LangChain-compatible LLM/chat model with an .invoke() method. + payload : str | list + Prompt string or messages sequence. + cfg : LLMRetryConfig | None + Retry configuration (defaults sensible for API use). + + Returns + ------- + result : Any + Model response from final successful attempt. + + Raises + ------ + Exception + The last raised exception if all attempts fail. 
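+
+    Examples
+    --------
+    A minimal sketch; ``llm`` is any chat model exposing ``.invoke()``::
+
+        response = safe_invoke_llm(llm, "Summarize the report")
+        response = safe_invoke_llm(llm, messages, LLMRetryConfig(max_attempts=2))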
+ """ + if cfg is None: + cfg = LLMRetryConfig() + + attempts = 0 + last_error: Exception | None = None + while attempts < cfg.max_attempts: + attempts += 1 + try: + return llm.invoke(payload) + except Exception as e: # noqa: BLE001 + is_transient = isinstance(e, tuple(TRANSIENT_EXCEPTION_TYPES)) + # Some OpenAI / router errors wrap JSON decode text; heuristic fallback + if not is_transient and 'Expecting value' in str(e) and 'json' in str(e).lower(): + is_transient = True + if attempts >= cfg.max_attempts or not is_transient: + raise + last_error = e + delay = _compute_backoff(attempts, cfg) + time.sleep(delay) + # Should not reach here; safeguard + if last_error: + raise last_error + raise RuntimeError('safe_invoke_llm: exhausted without exception context') diff --git a/webapp/main.py b/webapp/main.py index e8234296..bfa5520f 100644 --- a/webapp/main.py +++ b/webapp/main.py @@ -1,5 +1,5 @@ -from fastapi import FastAPI, Request, Form, BackgroundTasks, HTTPException -from fastapi.responses import HTMLResponse +from fastapi import FastAPI, Request, Form, BackgroundTasks, HTTPException, WebSocket, WebSocketDisconnect +from fastapi.responses import HTMLResponse, JSONResponse from fastapi.staticfiles import StaticFiles import jinja2 import markdown as md @@ -7,6 +7,7 @@ import bleach import os from typing import Dict, Any import threading +import asyncio import time from dotenv import load_dotenv @@ -31,6 +32,60 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph app = FastAPI() +# Main event loop reference (captured at startup) so threads can schedule coroutines +MAIN_EVENT_LOOP: asyncio.AbstractEventLoop | None = None + +@app.on_event("startup") +async def _capture_loop(): + global MAIN_EVENT_LOOP + MAIN_EVENT_LOOP = asyncio.get_event_loop() + +# ============================================== +# WebSocket Connection Management +# ============================================== +class ConnectionManager: + """Tracks active websocket connections and allows broadcast of messages.""" + def __init__(self): + self._connections: set[WebSocket] = set() + self._lock = threading.Lock() + + async def connect(self, websocket: WebSocket): + await websocket.accept() + with self._lock: + self._connections.add(websocket) + + def disconnect_sync(self, websocket: WebSocket): + # Called from sync context in finally blocks + with self._lock: + if websocket in self._connections: + self._connections.remove(websocket) + + async def disconnect(self, websocket: WebSocket): + with self._lock: + if websocket in self._connections: + self._connections.remove(websocket) + try: + await websocket.close() + except Exception: + pass + + async def broadcast_json(self, payload: dict): + """Broadcast JSON payload to all active connections, pruning dead ones.""" + to_remove = [] + with self._lock: + conns = list(self._connections) + for ws in conns: + try: + await ws.send_json(payload) + except Exception: + to_remove.append(ws) + if to_remove: + with self._lock: + for ws in to_remove: + self._connections.discard(ws) + +manager = ConnectionManager() + # In-memory storage for the process state # Using a lock for thread-safe access to app_state app_state_lock = threading.Lock() @@ -80,6 +135,27 @@ def render_markdown(value: str) -> str: jinja_env.filters['markdown'] = render_markdown +async def _broadcast_status_locked_unlocked(): + """Helper to broadcast status updates using existing helper logic.""" + status_updates = {} + def extract_status_info(items): + for item in items: + status_updates[item["id"]] = { + 
"status": item["status"], + "status_icon": get_status_icon(item["status"]) + } + if item.get("children"): + extract_status_info(item["children"]) + with app_state_lock: + extract_status_info(app_state.get("execution_tree", [])) + payload = { + "type": "status_update", + "status_updates": status_updates, + "overall_progress": app_state.get("overall_progress", 0), + "overall_status": app_state.get("overall_status", "idle") + } + await manager.broadcast_json(payload) + def update_execution_state(state: Dict[str, Any]): """Callback function to update the app_state based on LangGraph's state.""" print(f"📡 Callback received state keys: {list(state.keys())}") @@ -191,6 +267,14 @@ def update_execution_state(state: Dict[str, Any]): print(f"📊 Progress updated: {app_state['overall_progress']}% ({completed_agents}/{total_agents} agents)") + # Fire-and-forget broadcast using main loop even when we're in a worker thread + try: + if MAIN_EVENT_LOOP and not MAIN_EVENT_LOOP.is_closed(): + asyncio.run_coroutine_threadsafe(_broadcast_status_locked_unlocked(), MAIN_EVENT_LOOP) + except Exception as _e: + # Silently ignore broadcast issues; optionally log + pass + def initialize_complete_execution_tree(): """Initialize the complete execution tree with all agents in pending state.""" return [ @@ -333,6 +417,30 @@ def find_agent_in_tree(agent_id: str, tree: list): return agent return None +def mark_agent_error(agent_id: str, error_message: str): + """Mark a specific agent (and its phase) as error with provided message.""" + execution_tree = app_state.get("execution_tree", []) + target_agent = find_agent_in_tree(agent_id, execution_tree) + if not target_agent: + return False + # Mark agent + target_agent["status"] = "error" + target_agent["content"] = f"❌ {target_agent['name']} - Error encountered\n\n{error_message}" + # Mark any children as error for clarity + for child in target_agent.get("children", []): + if child["status"] != "completed": + child["status"] = "error" + if not child.get("content"): + child["content"] = "Error in parent agent" + # Mark containing phase error + for phase in execution_tree: + if phase.get("children") and any(c is target_agent for c in phase["children"]): + phase["status"] = "error" + if not phase.get("content") or "Error" not in phase["content"]: + phase["content"] = f"❌ {phase['name']} - Error in {target_agent['name']}" + break + return True + def find_item_by_id(item_id: str, items: list): """Find an item by ID in a list of items.""" for item in items: @@ -548,24 +656,62 @@ def run_trading_process(company_symbol: str, config: Dict[str, Any]): except Exception as e: import traceback error_detail = traceback.format_exc() + # Attempt to extract agent name from traceback (LangGraph style: "During task with name 'Risk Judge'") + import re + agent_name = None + m = re.search(r"During task with name '([^']+)'", error_detail) + if m: + agent_name = m.group(1) + # Map human-readable agent name to internal agent_id used in tree + name_to_id = { + "Market Analyst": "market_analyst", + "Social Analyst": "social_analyst", + "News Analyst": "news_analyst", + "Fundamentals Analyst": "fundamentals_analyst", + "Bull Researcher": "bull_researcher", + "Bear Researcher": "bear_researcher", + "Research Manager": "research_manager", + "Trade Planner": "trade_planner", + "Trader": "trader", + "Risky Analyst": "risky_analyst", + "Neutral Analyst": "neutral_analyst", + "Safe Analyst": "safe_analyst", + "Risk Judge": "risk_judge" + } + mapped_agent_id = name_to_id.get(agent_name) if agent_name else None 
with app_state_lock: app_state["overall_status"] = "error" app_state["overall_progress"] = 100 - if app_state["execution_tree"]: + # Mark specific agent if identified; else attach error to root phase (first) + if mapped_agent_id and mark_agent_error(mapped_agent_id, f"Error during execution: {str(e)}"): + pass + elif app_state["execution_tree"]: app_state["execution_tree"][0]["status"] = "error" app_state["execution_tree"][0]["content"] = f"Error during execution: {str(e)}\n\n{error_detail}" - # Add a specific error item to the tree + # Always append detailed error node for inspection app_state["execution_tree"].append({ "id": "error", - "name": "Process Error", + "name": f"Process Error{(' - ' + agent_name) if agent_name else ''}", "status": "error", "content": f"Error during execution: {str(e)}\n\n{error_detail}", "children": [], "timestamp": time.time() }) + # Immediate broadcast of error state + if MAIN_EVENT_LOOP and not MAIN_EVENT_LOOP.is_closed(): + try: + asyncio.run_coroutine_threadsafe(_broadcast_status_locked_unlocked(), MAIN_EVENT_LOOP) + except Exception: + pass finally: with app_state_lock: app_state["process_running"] = False + # Final broadcast to push terminal status + if MAIN_EVENT_LOOP and not MAIN_EVENT_LOOP.is_closed(): + try: + asyncio.run_coroutine_threadsafe(_broadcast_status_locked_unlocked(), MAIN_EVENT_LOOP) + except Exception: + pass @app.get("/", response_class=HTMLResponse) @@ -577,7 +723,7 @@ async def read_root(): @app.post("/start", response_class=HTMLResponse) async def start_process( - background_tasks: BackgroundTasks, + background_tasks: BackgroundTasks, # kept for backward compat; no longer used for long task company_symbol: str = Form(...), llm_provider: str = Form(...), quick_think_llm: str = Form(...), @@ -655,7 +801,9 @@ async def start_process( # Initialize execution tree with complete structure app_state["execution_tree"] = initialize_complete_execution_tree() - background_tasks.add_task(run_trading_process, company_symbol, app_state["config"]) + # Launch heavy propagation in its own daemon thread so FastAPI loop remains responsive for websockets + worker = threading.Thread(target=run_trading_process, args=(company_symbol, app_state["config"]), daemon=True) + worker.start() template = jinja_env.get_template("_partials/left_panel.html") return template.render(tree=app_state["execution_tree"], app_state=app_state) @@ -669,27 +817,75 @@ async def get_status(): @app.get("/status-updates") async def get_status_updates(): - """Return only the status updates as JSON for targeted updates.""" + """Legacy endpoint for polling (kept as fallback).""" + status_updates = {} + def extract_status_info(items): + for item in items: + status_updates[item["id"]] = { + "status": item["status"], + "status_icon": get_status_icon(item["status"]) + } + if item.get("children"): + extract_status_info(item["children"]) with app_state_lock: - status_updates = {} - - def extract_status_info(items, prefix=""): - for item in items: - item_id = item["id"] - status_updates[item_id] = { - "status": item["status"], - "status_icon": get_status_icon(item["status"]) - } - if item.get("children"): - extract_status_info(item["children"]) - - extract_status_info(app_state["execution_tree"]) - - return { + extract_status_info(app_state.get("execution_tree", [])) + return JSONResponse({ "status_updates": status_updates, - "overall_progress": app_state["overall_progress"], - "overall_status": app_state["overall_status"] - } + "overall_progress": app_state.get("overall_progress", 0), + 
"overall_status": app_state.get("overall_status", "idle") + }) + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + """Primary realtime channel. + Client should expect messages of forms: + {"type": "status_update", ...} - incremental + {"type": "init", execution_tree: [...], overall_progress, overall_status} + {"type": "content", item_id, html} + {"type": "error", message} + Client can send: {"action": "subscribe"} (ignored) or {"action": "get_content", "item_id": id} + """ + await manager.connect(websocket) + try: + # Send initial snapshot + with app_state_lock: + init_payload = { + "type": "init", + "execution_tree_html": jinja_env.get_template("_partials/left_panel.html").render(tree=app_state.get("execution_tree", []), app_state=app_state), + "overall_progress": app_state.get("overall_progress", 0), + "overall_status": app_state.get("overall_status", "idle"), + } + await websocket.send_json(init_payload) + + while True: + data = await websocket.receive_json() + action = data.get("action") + if action == "ping": + await websocket.send_json({"type": "pong"}) + elif action == "get_content": + item_id = data.get("item_id") + if not item_id: + await websocket.send_json({"type": "error", "message": "item_id required"}) + continue + with app_state_lock: + item = find_item_in_tree(item_id, app_state.get("execution_tree", [])) + if item: + html = jinja_env.get_template("_partials/right_panel.html").render(content=item.get("content", "No content available.")) + await websocket.send_json({"type": "content", "item_id": item_id, "html": html}) + else: + await websocket.send_json({"type": "error", "message": f"Item {item_id} not found"}) + else: + # ignore or echo + await websocket.send_json({"type": "ack", "received": action}) + except WebSocketDisconnect: + manager.disconnect_sync(websocket) + except Exception as e: + # Attempt to notify client + try: + await websocket.send_json({"type": "error", "message": str(e)}) + except Exception: + pass + manager.disconnect_sync(websocket) def get_status_icon(status: str) -> str: """Get the status icon for a given status.""" diff --git a/webapp/templates/index.html b/webapp/templates/index.html index 51d8033a..de850dec 100644 --- a/webapp/templates/index.html +++ b/webapp/templates/index.html @@ -94,6 +94,132 @@