diff --git a/PROJECT_HANDOVER.md b/PROJECT_HANDOVER.md index 6212c630..aaceb891 100644 --- a/PROJECT_HANDOVER.md +++ b/PROJECT_HANDOVER.md @@ -82,6 +82,6 @@ python run_ningde.py # 宁德时代 ## API配置 -- API Key: Read from a local environment variable; do not commit secrets +- API Key: 从本地环境变量读取(不要提交到仓库) - Base URL: `https://api.minimaxi.com/anthropic` -- Model: `MiniMax-M2.7-highspeed` \ No newline at end of file +- Model: `MiniMax-M2.7-highspeed` diff --git a/orchestrator/profile_stage_chain.py b/orchestrator/profile_stage_chain.py index 284e88c9..332c701d 100644 --- a/orchestrator/profile_stage_chain.py +++ b/orchestrator/profile_stage_chain.py @@ -1,10 +1,12 @@ from __future__ import annotations +import _thread import argparse import json -import signal +import threading import time from collections import defaultdict +from contextlib import contextmanager from datetime import datetime, timezone from pathlib import Path @@ -58,6 +60,27 @@ class _ProfileTimeout(Exception): pass +@contextmanager +def _overall_timeout_guard(seconds: int): + timed_out = threading.Event() + timer: threading.Timer | None = None + + def interrupt_main() -> None: + timed_out.set() + _thread.interrupt_main() + + if seconds > 0: + timer = threading.Timer(seconds, interrupt_main) + timer.daemon = True + timer.start() + + try: + yield timed_out + finally: + if timer is not None: + timer.cancel() + + def _jsonable(value): if isinstance(value, (str, int, float, bool)) or value is None: return value @@ -121,6 +144,8 @@ def build_trace_payload( if exception_type is not None: payload["exception_type"] = exception_type return payload + + def main() -> None: args = build_parser().parse_args() selected_analysts = [item.strip() for item in args.selected_analysts.split(",") if item.strip()] @@ -151,40 +176,40 @@ def main() -> None: dump_dir.mkdir(parents=True, exist_ok=True) dump_path = dump_dir / f"{args.ticker.replace('/', '_')}_{args.date}_{run_id}.json" - def alarm_handler(signum, frame): - raise _ProfileTimeout(f"profiling timeout after {args.overall_timeout}s") - - signal.signal(signal.SIGALRM, alarm_handler) - signal.alarm(args.overall_timeout) - try: - for event in graph.graph.stream(state, stream_mode="updates", config=config_kwargs): - now = time.monotonic() - nodes = list(event.keys()) - phases = sorted({_PHASE_MAP.get(node, "unknown") for node in nodes}) - llm_kinds = sorted({_LLM_KIND_MAP.get(node, "unknown") for node in nodes}) - delta = round(now - last_at, 3) - research_status, degraded_reason, history_len, response_len = _extract_research_state(event) - entry = { - "run_id": run_id, - "nodes": nodes, - "phases": phases, - "llm_kinds": llm_kinds, - "start_at": round(last_at - started_at, 3), - "end_at": round(now - started_at, 3), - "elapsed_ms": int(delta * 1000), - "selected_analysts": selected_analysts, - "analysis_prompt_style": args.analysis_prompt_style, - "research_status": research_status, - "degraded_reason": degraded_reason, - "history_len": history_len, - "response_len": response_len, - } - node_timings.append(entry) - raw_events.append(_jsonable(event)) - for phase in phases: - phase_totals[phase] += delta - last_at = now + with _overall_timeout_guard(args.overall_timeout) as timed_out: + try: + for event in graph.graph.stream(state, stream_mode="updates", config=config_kwargs): + now = time.monotonic() + nodes = list(event.keys()) + phases = sorted({_PHASE_MAP.get(node, "unknown") for node in nodes}) + llm_kinds = sorted({_LLM_KIND_MAP.get(node, "unknown") for node in nodes}) + delta = round(now - last_at, 3) + research_status, degraded_reason, history_len, response_len = _extract_research_state(event) + entry = { + "run_id": run_id, + "nodes": nodes, + "phases": phases, + "llm_kinds": llm_kinds, + "start_at": round(last_at - started_at, 3), + "end_at": round(now - started_at, 3), + "elapsed_ms": int(delta * 1000), + "selected_analysts": selected_analysts, + "analysis_prompt_style": args.analysis_prompt_style, + "research_status": research_status, + "degraded_reason": degraded_reason, + "history_len": history_len, + "response_len": response_len, + } + node_timings.append(entry) + raw_events.append(_jsonable(event)) + for phase in phases: + phase_totals[phase] += delta + last_at = now + except KeyboardInterrupt: + if timed_out.is_set(): + raise _ProfileTimeout(f"profiling timeout after {args.overall_timeout}s") from None + raise payload = { "status": "ok", @@ -212,8 +237,6 @@ def main() -> None: "dump_path": str(dump_path), "raw_events": raw_events, } - finally: - signal.alarm(0) dump_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2)) print(json.dumps(payload, ensure_ascii=False, indent=2)) diff --git a/orchestrator/quant_runner.py b/orchestrator/quant_runner.py index c7a0a02b..21c54c4c 100644 --- a/orchestrator/quant_runner.py +++ b/orchestrator/quant_runner.py @@ -12,6 +12,7 @@ from orchestrator.config import OrchestratorConfig from orchestrator.contracts.error_taxonomy import ReasonCode from orchestrator.contracts.result_contract import Signal, build_error_signal from orchestrator.market_calendar import is_non_trading_day +from tradingagents.dataflows.stockstats_utils import yf_retry logger = logging.getLogger(__name__) @@ -48,7 +49,15 @@ class QuantRunner: start_str = start_dt.strftime("%Y-%m-%d") end_exclusive = (end_dt + timedelta(days=1)).strftime("%Y-%m-%d") - df = yf.download(ticker, start=start_str, end=end_exclusive, progress=False, auto_adjust=True) + df = yf_retry( + lambda: yf.download( + ticker, + start=start_str, + end=end_exclusive, + progress=False, + auto_adjust=True, + ) + ) if df.empty: logger.warning("No price data for %s between %s and %s", ticker, start_str, date) if is_non_trading_day(ticker, end_dt.date()): diff --git a/orchestrator/tests/test_profile_stage_chain.py b/orchestrator/tests/test_profile_stage_chain.py index b362b747..66b4db52 100644 --- a/orchestrator/tests/test_profile_stage_chain.py +++ b/orchestrator/tests/test_profile_stage_chain.py @@ -1,4 +1,5 @@ import json +from contextlib import contextmanager from datetime import datetime as real_datetime, timezone from pathlib import Path @@ -95,9 +96,13 @@ def test_main_writes_trace_payload_with_research_provenance(monkeypatch, tmp_pat monkeypatch.setattr(profile_stage_chain, "TradingAgentsGraph", _FakeTradingAgentsGraph) monkeypatch.setattr(profile_stage_chain, "Propagator", _FakePropagator) monkeypatch.setattr(profile_stage_chain.time, "monotonic", lambda: next(monotonic_points)) - monkeypatch.setattr(profile_stage_chain.signal, "signal", lambda *args, **kwargs: None) - monkeypatch.setattr(profile_stage_chain.signal, "alarm", lambda *args, **kwargs: None) monkeypatch.setattr(profile_stage_chain, "datetime", _FixedDateTime) + + @contextmanager + def fake_guard(_seconds): + yield profile_stage_chain.threading.Event() + + monkeypatch.setattr(profile_stage_chain, "_overall_timeout_guard", fake_guard) monkeypatch.setattr( "sys.argv", [ @@ -161,3 +166,51 @@ def test_main_writes_trace_payload_with_research_provenance(monkeypatch, tmp_pat dump_path = Path(output["dump_path"]) assert dump_path.exists() assert json.loads(dump_path.read_text()) == output + + +class _KeyboardInterruptGraph: + def __init__(self, *, selected_analysts, config): + self.graph = self + + def stream(self, state, stream_mode, config): + raise KeyboardInterrupt + yield + + +def test_main_reports_cross_platform_timeout(monkeypatch, tmp_path, capsys): + monkeypatch.setattr(profile_stage_chain, "TradingAgentsGraph", _KeyboardInterruptGraph) + monkeypatch.setattr(profile_stage_chain, "Propagator", _FakePropagator) + monkeypatch.setattr(profile_stage_chain, "datetime", _FixedDateTime) + + @contextmanager + def timed_out_guard(seconds): + event = profile_stage_chain.threading.Event() + event.set() + yield event + + monkeypatch.setattr(profile_stage_chain, "_overall_timeout_guard", timed_out_guard) + monkeypatch.setattr( + "sys.argv", + [ + "profile_stage_chain.py", + "--ticker", + "AAPL", + "--date", + "2026-04-11", + "--selected-analysts", + "market,social", + "--analysis-prompt-style", + "balanced", + "--overall-timeout", + "1", + "--dump-dir", + str(tmp_path), + ], + ) + + profile_stage_chain.main() + + output = json.loads(capsys.readouterr().out) + assert output["status"] == "error" + assert output["exception_type"] == "_ProfileTimeout" + assert output["error"] == "profiling timeout after 1s" diff --git a/orchestrator/tests/test_quant_runner.py b/orchestrator/tests/test_quant_runner.py index a6f26551..3b99bddc 100644 --- a/orchestrator/tests/test_quant_runner.py +++ b/orchestrator/tests/test_quant_runner.py @@ -183,3 +183,19 @@ def test_get_signal_marks_partial_data_when_required_columns_missing(runner, mon assert signal.degraded is True assert signal.reason_code == ReasonCode.PARTIAL_DATA.value assert signal.metadata["data_quality"]["state"] == "partial_data" + + +def test_get_signal_uses_yf_retry_wrapper(runner, monkeypatch): + calls = [] + + def fake_retry(func, max_retries=3, base_delay=2.0): + calls.append((max_retries, base_delay)) + return pd.DataFrame() + + monkeypatch.setattr("orchestrator.quant_runner.yf_retry", fake_retry) + monkeypatch.setattr("orchestrator.quant_runner.is_non_trading_day", lambda *_args, **_kwargs: False) + + signal = runner.get_signal("AAPL", "2024-01-02") + + assert calls == [(3, 2.0)] + assert signal.reason_code == ReasonCode.QUANT_NO_DATA.value diff --git a/tradingagents/agents/utils/decision_utils.py b/tradingagents/agents/utils/decision_utils.py new file mode 100644 index 00000000..2209d2ef --- /dev/null +++ b/tradingagents/agents/utils/decision_utils.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import re +from typing import Any, Iterable + +CANONICAL_RATINGS = ("BUY", "OVERWEIGHT", "HOLD", "UNDERWEIGHT", "SELL") +_RATING_PATTERN = re.compile( + r"\b(BUY|OVERWEIGHT|HOLD|UNDERWEIGHT|SELL)\b", + re.IGNORECASE, +) + + +def extract_rating(text: str) -> str | None: + match = _RATING_PATTERN.search(str(text or "")) + if not match: + return None + return match.group(1).upper() + + +def _normalize_report_text(rating: str, rating_source: str, report_text: str) -> str: + body = str(report_text or "").strip() or "No narrative provided." + return ( + "## Normalized Portfolio Decision\n" + f"- Rating: {rating}\n" + f"- Rating Source: {rating_source}\n\n" + f"{body}" + ) + + +def build_structured_decision( + text: str, + *, + fallback_candidates: Iterable[tuple[str, str]] = (), + default_rating: str = "HOLD", + peer_context_mode: str = "UNSPECIFIED", + context_usage: dict[str, Any] | None = None, +) -> dict[str, Any]: + warnings: list[str] = [] + rating_source = "direct" + rating = extract_rating(text) + source_text = str(text or "") + + if rating is None: + for candidate_name, candidate_text in fallback_candidates: + rating = extract_rating(candidate_text) + if rating is not None: + rating_source = candidate_name + source_text = str(candidate_text or "") + warnings.append(f"rating_inferred_from:{candidate_name}") + break + + if rating is None: + rating = str(default_rating or "HOLD").upper() + rating_source = "default" + warnings.append("rating_defaulted") + + usage = context_usage or {} + hold_subtype = "UNSPECIFIED" if rating == "HOLD" else "N/A" + + return { + "rating": rating, + "hold_subtype": hold_subtype, + "rating_source": rating_source, + "report_text": _normalize_report_text(rating, rating_source, source_text), + "warnings": warnings, + "portfolio_context_used": bool(usage.get("portfolio_context")), + "peer_context_used": bool(usage.get("peer_context")), + "peer_context_mode": str(peer_context_mode or "UNSPECIFIED"), + } diff --git a/tradingagents/agents/utils/subagent_runner.py b/tradingagents/agents/utils/subagent_runner.py new file mode 100644 index 00000000..afecfb5a --- /dev/null +++ b/tradingagents/agents/utils/subagent_runner.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import time +from concurrent.futures import ThreadPoolExecutor, TimeoutError +from typing import Any + + +def _invoke_dimension(llm, dimension: str, prompt: str) -> dict[str, Any]: + started_at = time.monotonic() + try: + response = llm.invoke(prompt) + content = response.content if hasattr(response, "content") else str(response) + return { + "dimension": dimension, + "content": str(content).strip(), + "ok": True, + "error": None, + "elapsed_s": round(time.monotonic() - started_at, 3), + } + except Exception as exc: + return { + "dimension": dimension, + "content": "", + "ok": False, + "error": str(exc), + "elapsed_s": round(time.monotonic() - started_at, 3), + } + + +def run_parallel_subagents( + *, + llm, + dimension_configs: list[dict[str, Any]], + timeout_per_subagent: float = 25.0, + max_workers: int = 4, +) -> list[dict[str, Any]]: + if not dimension_configs: + return [] + + executor = ThreadPoolExecutor(max_workers=max_workers) + futures = { + executor.submit( + _invoke_dimension, + llm, + config["dimension"], + config["prompt"], + ): config["dimension"] + for config in dimension_configs + } + + results: list[dict[str, Any]] = [] + try: + for future, dimension in futures.items(): + try: + results.append(future.result(timeout=timeout_per_subagent)) + except TimeoutError: + results.append( + { + "dimension": dimension, + "content": "", + "ok": False, + "error": "timeout", + "elapsed_s": round(timeout_per_subagent, 3), + } + ) + finally: + executor.shutdown(wait=False, cancel_futures=True) + + return results + + +def synthesize_subagent_results( + subagent_results: list[dict[str, Any]], + *, + max_chars_per_result: int = 200, +) -> tuple[str, dict[str, Any]]: + lines: list[str] = [] + timings: dict[str, float] = {} + failures: dict[str, str] = {} + + for result in subagent_results: + dimension = str(result.get("dimension") or "unknown") + timings[dimension] = float(result.get("elapsed_s") or 0.0) + + content = str(result.get("content") or "").strip() + if not result.get("ok"): + failure_reason = str(result.get("error") or "unknown error") + failures[dimension] = failure_reason + content = f"[UNAVAILABLE: {failure_reason}]" + + if len(content) > max_chars_per_result: + content = f"{content[:max_chars_per_result - 3]}..." + + lines.append(f"[{dimension.upper()}]\n{content or '[NO OUTPUT]'}") + + return "\n\n".join(lines), { + "subagent_timings": timings, + "failures": failures, + } diff --git a/tradingagents/dataflows/stockstats_utils.py b/tradingagents/dataflows/stockstats_utils.py index 63d5ddf6..06e191ba 100644 --- a/tradingagents/dataflows/stockstats_utils.py +++ b/tradingagents/dataflows/stockstats_utils.py @@ -1,5 +1,6 @@ import time import logging +import threading import pandas as pd import yfinance as yf @@ -11,6 +12,16 @@ import os from .config import get_config logger = logging.getLogger(__name__) +_fallback_session_local = threading.local() + + +def _get_fallback_session() -> requests.Session: + session = getattr(_fallback_session_local, "session", None) + if session is None: + session = requests.Session() + session.trust_env = False + _fallback_session_local.session = session + return session def _symbol_to_tencent_code(symbol: str) -> str: @@ -24,8 +35,7 @@ def _symbol_to_tencent_code(symbol: str) -> str: def _fetch_tencent_ohlcv(symbol: str, start_date: str, end_date: str) -> pd.DataFrame: """Fallback daily OHLCV fetch for A-shares via Tencent.""" - session = requests.Session() - session.trust_env = False + session = _get_fallback_session() response = session.get( "https://web.ifzq.gtimg.cn/appstock/app/fqkline/get", params={ @@ -72,8 +82,7 @@ def _symbol_to_eastmoney_secid(symbol: str) -> str: def _fetch_eastmoney_ohlcv(symbol: str, start_date: str, end_date: str) -> pd.DataFrame: """Fallback daily OHLCV fetch for A-shares via Eastmoney.""" - session = requests.Session() - session.trust_env = False + session = _get_fallback_session() url = "https://push2his.eastmoney.com/api/qt/stock/kline/get" response = session.get( url, diff --git a/tradingagents/tests/test_stockstats_utils.py b/tradingagents/tests/test_stockstats_utils.py new file mode 100644 index 00000000..3e2e168b --- /dev/null +++ b/tradingagents/tests/test_stockstats_utils.py @@ -0,0 +1,22 @@ +import threading + +from tradingagents.dataflows import stockstats_utils + + +def test_get_fallback_session_reuses_session_in_same_thread(monkeypatch): + created = [] + + class FakeSession: + def __init__(self): + self.trust_env = True + created.append(self) + + monkeypatch.setattr(stockstats_utils, "_fallback_session_local", threading.local()) + monkeypatch.setattr(stockstats_utils.requests, "Session", FakeSession) + + first = stockstats_utils._get_fallback_session() + second = stockstats_utils._get_fallback_session() + + assert first is second + assert len(created) == 1 + assert first.trust_env is False diff --git a/web_dashboard/backend/api/portfolio.py b/web_dashboard/backend/api/portfolio.py index 25594686..05e1c9a7 100644 --- a/web_dashboard/backend/api/portfolio.py +++ b/web_dashboard/backend/api/portfolio.py @@ -2,8 +2,8 @@ Portfolio API — 自选股、持仓、每日建议 """ import asyncio -import fcntl import json +import os import uuid from datetime import datetime from pathlib import Path @@ -11,6 +11,34 @@ from typing import Optional import yfinance +try: + import fcntl +except ImportError: # pragma: no cover - exercised on Windows + import msvcrt + + class _FcntlCompat: + LOCK_SH = 1 + LOCK_EX = 2 + LOCK_UN = 8 + + @staticmethod + def flock(fd: int, operation: int) -> None: + os.lseek(fd, 0, os.SEEK_SET) + if operation == _FcntlCompat.LOCK_UN: + try: + msvcrt.locking(fd, msvcrt.LK_UNLCK, 1) + except OSError: + return + return + + if os.fstat(fd).st_size == 0: + os.write(fd, b"\0") + os.lseek(fd, 0, os.SEEK_SET) + + msvcrt.locking(fd, msvcrt.LK_LOCK, 1) + + fcntl = _FcntlCompat() + # Data directory DATA_DIR = Path(__file__).parent.parent.parent / "data" DATA_DIR.mkdir(parents=True, exist_ok=True) @@ -153,7 +181,7 @@ def _fetch_price(ticker: str) -> float | None: async def _fetch_price_throttled(ticker: str) -> float | None: """Fetch price with semaphore throttling.""" async with _yfinance_semaphore: - return _fetch_price(ticker) + return await asyncio.to_thread(_fetch_price, ticker) async def get_positions(account: Optional[str] = None) -> list: diff --git a/web_dashboard/backend/tests/test_portfolio_api.py b/web_dashboard/backend/tests/test_portfolio_api.py index e6c00d22..a1780a4b 100644 --- a/web_dashboard/backend/tests/test_portfolio_api.py +++ b/web_dashboard/backend/tests/test_portfolio_api.py @@ -1,12 +1,9 @@ """ Tests for portfolio API — covers critical security and correctness fixes. """ +import asyncio import json -import os -import tempfile -import pytest from pathlib import Path -from unittest.mock import patch class TestRemovePositionMassDeletion: @@ -261,3 +258,28 @@ class TestConstants: assert "MAX_CONCURRENT_YFINANCE_REQUESTS" in content assert "asyncio.Semaphore(MAX_CONCURRENT_YFINANCE_REQUESTS)" in content + + def test_portfolio_locking_has_windows_fallback(self): + portfolio_path = Path(__file__).parent.parent / "api" / "portfolio.py" + content = portfolio_path.read_text() + + assert "except ImportError" in content + assert "msvcrt" in content + + +class TestAsyncPriceFetch: + def test_fetch_price_throttled_uses_worker_thread(self, monkeypatch): + from api import portfolio + + calls = [] + + async def fake_to_thread(func, *args): + calls.append((func, args)) + return 321.0 + + monkeypatch.setattr(portfolio.asyncio, "to_thread", fake_to_thread) + + result = asyncio.run(portfolio._fetch_price_throttled("AAPL")) + + assert result == 321.0 + assert calls == [(portfolio._fetch_price, ("AAPL",))]