diff --git a/.env.example b/.env.example index 1328b838..47ac745d 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,23 @@ -# LLM Providers (set the one you use) +# LLM Provider API Keys (set the ones you use) OPENAI_API_KEY= GOOGLE_API_KEY= ANTHROPIC_API_KEY= XAI_API_KEY= OPENROUTER_API_KEY= + +# Data Provider API Keys +ALPHA_VANTAGE_API_KEY= + +# ── Configuration overrides ────────────────────────────────────────── +# Any setting in DEFAULT_CONFIG can be overridden with a +# TRADINGAGENTS_<KEY> environment variable. Unset or empty values +# are ignored (the hardcoded default is kept). +# +# Examples: +# TRADINGAGENTS_LLM_PROVIDER=openrouter +# TRADINGAGENTS_QUICK_THINK_LLM=deepseek/deepseek-chat-v3-0324 +# TRADINGAGENTS_DEEP_THINK_LLM=deepseek/deepseek-r1-0528 +# TRADINGAGENTS_BACKEND_URL=https://openrouter.ai/api/v1 +# TRADINGAGENTS_RESULTS_DIR=./my_results +# TRADINGAGENTS_MAX_DEBATE_ROUNDS=2 +# TRADINGAGENTS_VENDOR_SCANNER_DATA=alpha_vantage diff --git a/CLAUDE.md b/CLAUDE.md index c94ee85c..89b44310 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -75,9 +75,66 @@ OpenAI, Anthropic, Google, xAI, OpenRouter, Ollama - LLM tiers configuration - Vendor routing - Debate rounds settings +- All values overridable via `TRADINGAGENTS_` env vars (see `.env.example`) ## Patterns to Follow -- Agent creation: `tradingagents/agents/analysts/news_analyst.py` +- Agent creation (trading): `tradingagents/agents/analysts/news_analyst.py` +- Agent creation (scanner): `tradingagents/agents/scanners/geopolitical_scanner.py` - Tools: `tradingagents/agents/utils/news_data_tools.py` +- Scanner tools: `tradingagents/agents/utils/scanner_tools.py` -- Graph setup: `tradingagents/graph/setup.py` +- Graph setup (trading): `tradingagents/graph/setup.py` +- Graph setup (scanner): `tradingagents/graph/scanner_setup.py` +- Inline tool loop: `tradingagents/agents/utils/tool_runner.py` + +## Critical Patterns (from past mistakes — see MISTAKES.md) + +- **Tool execution**: Trading graph uses `ToolNode` 
in graph. Scanner agents use `run_tool_loop()` inline. If `bind_tools()` is used, there MUST be a tool execution path. +- **yfinance DataFrames**: `top_companies` has ticker as INDEX, not column. Always check `.index` and `.columns`. +- **yfinance Sector/Industry**: `Sector.overview` has NO performance data. Use ETF proxies for performance. +- **Vendor fallback**: Functions inside `route_to_vendor` must RAISE on failure, not embed errors in return values. Catch `(AlphaVantageError, ConnectionError, TimeoutError)`, not just `RateLimitError`. +- **LangGraph parallel writes**: Any state field written by parallel nodes MUST have a reducer (`Annotated[str, reducer_fn]`). +- **Ollama remote host**: Never hardcode `localhost:11434`. Use configured `base_url`. +- **.env loading**: `load_dotenv()` runs at module level in `default_config.py` — import-order-independent. Check actual env var values when debugging auth. +- **Rate limiter locks**: Never hold a lock during `sleep()` or IO. Release, sleep, re-acquire. +- **Config fallback keys**: `llm_provider` and `backend_url` must always exist at top level — `scanner_graph.py` and `trading_graph.py` use them as fallbacks. 
+ +## Project Tracking Files + +- `DECISIONS.md` — Architecture decision records (vendor strategy, LLM setup, tool execution, env overrides) +- `PROGRESS.md` — Feature progress, what works, TODOs +- `MISTAKES.md` — Past bugs and lessons learned (10 documented mistakes) + +## LLM Configuration + +Per-tier provider overrides in `tradingagents/default_config.py`: +- Each tier (`quick_think`, `mid_think`, `deep_think`) can have its own `_llm_provider` and `_backend_url` +- Falls back to top-level `llm_provider` and `backend_url` when per-tier values are None +- All config values overridable via `TRADINGAGENTS_` env vars +- Keys for LLM providers: `.env` file (e.g., `OPENROUTER_API_KEY`, `ALPHA_VANTAGE_API_KEY`) + +### Env Var Override Convention + +```env +# Pattern: TRADINGAGENTS_<KEY>=value +TRADINGAGENTS_LLM_PROVIDER=openrouter +TRADINGAGENTS_DEEP_THINK_LLM=deepseek/deepseek-r1-0528 +TRADINGAGENTS_MAX_DEBATE_ROUNDS=3 +TRADINGAGENTS_VENDOR_SCANNER_DATA=alpha_vantage +``` + +Empty or unset vars preserve the hardcoded default. `None`-default fields (like `mid_think_llm`) stay `None` when unset, preserving fallback semantics. + +## Running the Scanner + +```bash +conda activate tradingagents +python -m cli.main scan --date 2026-03-17 +``` + +## Running Tests + +```bash +conda activate tradingagents +pytest tests/ -v +``` diff --git a/DECISIONS.md b/DECISIONS.md new file mode 100644 index 00000000..13fcbbb7 --- /dev/null +++ b/DECISIONS.md @@ -0,0 +1,179 @@ +# Architecture Decisions Log + +## Decision 001: Hybrid LLM Setup (Ollama + OpenRouter) + +**Date**: 2026-03-17 +**Status**: Implemented ✅ + +**Context**: Need cost-effective LLM setup for scanner pipeline with different complexity tiers. 
+ +**Decision**: Use hybrid approach: +- **quick_think + mid_think**: `qwen3.5:27b` via Ollama at `http://192.168.50.76:11434` (local, free) +- **deep_think**: `deepseek/deepseek-r1-0528` via OpenRouter (cloud, paid) + +**Config location**: `tradingagents/default_config.py` — per-tier `_llm_provider` and `_backend_url` keys. + +**Consequence**: Removed top-level `llm_provider` and `backend_url` from config. Each tier must have its own `{tier}_llm_provider` set explicitly. + +--- + +## Decision 002: Data Vendor Fallback Strategy + +**Date**: 2026-03-17 +**Status**: Implemented ✅ + +**Context**: Alpha Vantage free/demo key doesn't support ETF symbols and has strict rate limits. Need reliable data for scanner. + +**Decision**: +- `route_to_vendor()` catches `AlphaVantageError` (base class) to trigger fallback, not just `RateLimitError`. +- AV scanner functions raise `AlphaVantageError` when ALL queries fail (not silently embedding errors in output strings). +- yfinance is the fallback vendor and uses SPDR ETF proxies for sector performance instead of broken `Sector.overview`. + +**Files changed**: +- `tradingagents/dataflows/interface.py` — broadened catch +- `tradingagents/dataflows/alpha_vantage_scanner.py` — raise on total failure +- `tradingagents/dataflows/yfinance_scanner.py` — ETF proxy approach + +--- + +## Decision 003: yfinance Sector Performance via ETF Proxies + +**Date**: 2026-03-17 +**Status**: Implemented ✅ + +**Context**: `yfinance.Sector("technology").overview` returns only metadata (companies_count, market_cap, etc.) — no performance data (oneDay, oneWeek, etc.). + +**Decision**: Use SPDR sector ETFs as proxies: +```python +sector_etfs = { + "Technology": "XLK", "Healthcare": "XLV", "Financials": "XLF", + "Energy": "XLE", "Consumer Discretionary": "XLY", ... +} +``` +Download 6 months of history via `yf.download()` and compute 1-day, 1-week, 1-month, YTD percentage changes from closing prices. 
+ +**File**: `tradingagents/dataflows/yfinance_scanner.py` + +--- + +## Decision 004: Inline Tool Execution Loop for Scanner Agents + +**Date**: 2026-03-17 +**Status**: Implemented ✅ + +**Context**: The existing trading graph uses separate `ToolNode` graph nodes for tool execution (agent → tool_node → agent routing loop). Scanner agents are simpler single-pass nodes — no ToolNode in the graph. When the LLM returned tool_calls, nobody executed them, resulting in empty reports. + +**Decision**: Created `tradingagents/agents/utils/tool_runner.py` with `run_tool_loop()` that runs an inline tool execution loop within each scanner agent node: +1. Invoke chain +2. If tool_calls present → execute tools → append ToolMessages → re-invoke +3. Repeat up to `MAX_TOOL_ROUNDS=5` until LLM produces text response + +**Alternative considered**: Adding ToolNode + conditional routing to scanner_setup.py (like trading graph). Rejected — too complex for the fan-out/fan-in pattern and would require 4 separate tool nodes with routing logic. + +**Files**: +- `tradingagents/agents/utils/tool_runner.py` (new) +- All scanner agents updated to use `run_tool_loop()` + +--- + +## Decision 005: LangGraph State Reducers for Parallel Fan-Out + +**Date**: 2026-03-17 +**Status**: Implemented ✅ + +**Context**: Phase 1 runs 3 scanners in parallel. All write to shared state fields (`sender`, etc.). LangGraph requires reducers for concurrent writes — otherwise raises `INVALID_CONCURRENT_GRAPH_UPDATE`. + +**Decision**: Added `_last_value` reducer to all `ScannerState` fields via `Annotated[str, _last_value]`. + +**File**: `tradingagents/agents/utils/scanner_states.py` + +--- + +## Decision 006: CLI --date Flag for Scanner + +**Date**: 2026-03-17 +**Status**: Implemented ✅ + +**Context**: `python -m cli.main scan` was interactive-only (prompts for date). Needed non-interactive invocation for testing/automation. + +**Decision**: Added `--date` / `-d` option to `scan` command. 
Falls back to interactive prompt if not provided. + +**File**: `cli/main.py` + +--- + +## Decision 007: .env Loading Strategy + +**Date**: 2026-03-17 +**Status**: Superseded by Decision 008 ⚠️ + +**Context**: `load_dotenv()` loads from CWD. When running from a git worktree, the worktree `.env` may have placeholder values while the main repo `.env` has real keys. + +**Decision**: `cli/main.py` calls `load_dotenv()` (CWD) then `load_dotenv(Path(__file__).parent.parent / ".env")` as fallback. The worktree `.env` was also updated with real API keys. + +**Note for future**: If `.env` issues recur, check which `.env` file is being picked up. The worktree and main repo each have their own `.env`. + +**Update**: Decision 008 moves `load_dotenv()` into `default_config.py` itself, making it import-order-independent. The CLI-level `load_dotenv()` in `main.py` is now defense-in-depth only. + +--- + +## Decision 008: Environment Variable Config Overrides + +**Date**: 2026-03-17 +**Status**: Implemented ✅ + +**Context**: `DEFAULT_CONFIG` hardcoded all values (LLM providers, models, vendor routing, debate rounds). Users had to edit `default_config.py` to change any setting. The `load_dotenv()` call in `cli/main.py` ran *after* `DEFAULT_CONFIG` was already evaluated at import time, so env vars like `TRADINGAGENTS_LLM_PROVIDER` had no effect. This also created a latent bug (Mistake #9): `llm_provider` and `backend_url` were removed from the config but `scanner_graph.py` still referenced them as fallbacks. + +**Decision**: +1. **Module-level `.env` loading**: `default_config.py` calls `load_dotenv()` at the top of the module, before `DEFAULT_CONFIG` is evaluated. Loads from CWD first, then falls back to project root (`Path(__file__).resolve().parent.parent / ".env"`). +2. **`_env()` / `_env_int()` helpers**: Read `TRADINGAGENTS_<KEY>` from environment. Return the hardcoded default when the env var is unset or empty (preserving `None` semantics for per-tier fallbacks). +3. 
**Restored top-level keys**: `llm_provider` (default: `"openai"`) and `backend_url` (default: `"https://api.openai.com/v1"`) restored as env-overridable keys. Resolves Mistake #9. +4. **All config keys overridable**: LLM models, providers, backend URLs, debate rounds, data vendor categories — all follow the `TRADINGAGENTS_` pattern. +5. **Explicit dependency**: Added `python-dotenv>=1.0.0` to `pyproject.toml` (was used but undeclared). + +**Naming convention**: `TRADINGAGENTS_` prefix + uppercase config key. Examples: +``` +TRADINGAGENTS_LLM_PROVIDER=openrouter +TRADINGAGENTS_DEEP_THINK_LLM=deepseek/deepseek-r1-0528 +TRADINGAGENTS_MAX_DEBATE_ROUNDS=3 +TRADINGAGENTS_VENDOR_SCANNER_DATA=alpha_vantage +``` + +**Files changed**: +- `tradingagents/default_config.py` — core implementation +- `main.py` — moved `load_dotenv()` before imports (defense-in-depth) +- `pyproject.toml` — added `python-dotenv>=1.0.0` +- `.env.example` — documented all overrides +- `tests/test_env_override.py` — 15 tests + +**Alternative considered**: YAML/TOML config file. Rejected — env vars are simpler, work with Docker/CI, and don't require a new config file format. + +--- + +## Decision 009: Thread-Safe Rate Limiter for Alpha Vantage + +**Date**: 2026-03-17 +**Status**: Implemented ✅ + +**Context**: The Alpha Vantage rate limiter in `alpha_vantage_common.py` initially slept *inside* the lock when re-checking the rate window. This blocked all other threads from making API requests during the sleep period, effectively serializing all AV calls. + +**Decision**: Two-phase rate limiting: +1. **First check**: Acquire lock, check timestamps, release lock, sleep if needed. +2. **Re-check loop**: Acquire lock, re-check timestamps. If still over limit, release lock *before* sleeping, then retry. Only append timestamp and break when under the limit. + +This ensures the lock is never held during `sleep()` calls. 
+ +**File**: `tradingagents/dataflows/alpha_vantage_common.py` + +--- + +## Decision 010: Broader Vendor Fallback Exception Handling + +**Date**: 2026-03-17 +**Status**: Implemented ✅ + +**Context**: `route_to_vendor()` only caught `AlphaVantageError` for fallback. But network issues (`ConnectionError`, `TimeoutError`) from the `requests` library wouldn't trigger fallback — they'd crash the pipeline instead. + +**Decision**: Broadened the catch in `route_to_vendor()` to `(AlphaVantageError, ConnectionError, TimeoutError)`. Similarly, `_make_api_request()` now catches `requests.exceptions.RequestException` as a general fallback and wraps `raise_for_status()` in a try/except to convert HTTP errors to `ThirdPartyError`. + +**Files**: `tradingagents/dataflows/interface.py`, `tradingagents/dataflows/alpha_vantage_common.py` diff --git a/MISTAKES.md b/MISTAKES.md new file mode 100644 index 00000000..b296dc0e --- /dev/null +++ b/MISTAKES.md @@ -0,0 +1,122 @@ +# Mistakes & Lessons Learned + +Documenting bugs and wrong assumptions to avoid repeating them. + +--- + +## Mistake 1: Scanner agents had no tool execution + +**What happened**: All 4 scanner agents (geopolitical, market movers, sector, industry) used `llm.bind_tools(tools)` but only checked `if len(result.tool_calls) == 0: report = result.content`. When the LLM chose to call tools (which it always does when tools are available), nobody executed them. Reports were always empty strings. + +**Root cause**: Copied the pattern from existing analysts (`news_analyst.py`) without realizing that the trading graph has separate `ToolNode` graph nodes that handle tool execution in a routing loop. The scanner graph has no such nodes. + +**Fix**: Created `tool_runner.py` with `run_tool_loop()` that executes tools inline within the agent node. + +**Lesson**: When an LLM has `bind_tools`, there MUST be a tool execution mechanism — either graph-level `ToolNode` routing or inline execution. 
Always verify the tool execution path exists. + +--- + +## Mistake 2: Assumed yfinance `Sector.overview` has performance data + +**What happened**: Wrote `get_sector_performance_yfinance` using `yf.Sector("technology").overview["oneDay"]` etc. This field doesn't exist — `overview` only returns metadata (companies_count, market_cap, industries_count). + +**Root cause**: Assumed the yfinance Sector API mirrors the Yahoo Finance website which shows performance data. It doesn't. + +**Fix**: Switched to SPDR ETF proxy approach — download ETF prices and compute percentage changes. + +**Lesson**: Always test data source APIs interactively before writing agent code. Run `python -c "import yfinance as yf; print(yf.Sector('technology').overview)"` to see actual data shape. + +--- + +## Mistake 3: yfinance `top_companies` — ticker is the index, not a column + +**What happened**: Used `row.get('symbol')` to get ticker from `top_companies` DataFrame. Always returned N/A. + +**Root cause**: The DataFrame has `index.name = 'symbol'` — tickers are the index, not a column. The actual columns are `['name', 'rating', 'market weight']`. + +**Fix**: Changed to `for symbol, row in top_companies.iterrows()`. + +**Lesson**: Always inspect DataFrame structure with `.head()`, `.columns`, and `.index` before writing access code. + +--- + +## Mistake 4: Hardcoded Ollama localhost URL + +**What happened**: `openai_client.py` had `base_url = "http://localhost:11434/v1"` hardcoded for Ollama provider, ignoring the `self.base_url` config. User's Ollama runs on `192.168.50.76:11434`. + +**Fix**: Changed to `host = self.base_url or "http://localhost:11434"` with `/v1` suffix appended. + +**Lesson**: Never hardcode URLs. Always use the configured value with a sensible default. + +--- + +## Mistake 5: Only caught `RateLimitError` in vendor fallback + +**What happened**: `route_to_vendor()` only caught `RateLimitError`. 
Alpha Vantage demo key returns "Information" responses (not rate limit errors) and other `AlphaVantageError` subtypes. Fallback to yfinance never triggered. + +**Fix**: Broadened catch to `AlphaVantageError` (base class). + +**Lesson**: Fallback mechanisms should catch the broadest reasonable error class, not just specific subtypes. + +--- + +## Mistake 6: AV scanner functions silently caught errors + +**What happened**: `get_sector_performance_alpha_vantage` and `get_industry_performance_alpha_vantage` caught exceptions internally and embedded error strings in the output (e.g., `"Error: ..."` in the result dict). `route_to_vendor` never saw an exception, so it never fell back to yfinance. + +**Fix**: Made both functions raise `AlphaVantageError` when ALL queries fail, while still tolerating partial failures. + +**Lesson**: Functions used inside `route_to_vendor` MUST raise exceptions on total failure — embedding errors in return values defeats the fallback mechanism. + +--- + +## Mistake 7: LangGraph concurrent write without reducer + +**What happened**: Phase 1 runs 3 scanners in parallel. All write to `sender` (and other shared fields). LangGraph raised `INVALID_CONCURRENT_GRAPH_UPDATE` because `ScannerState` had no reducer for concurrent writes. + +**Fix**: Added `_last_value` reducer via `Annotated[str, _last_value]` to all ScannerState fields. + +**Lesson**: Any LangGraph state field written by parallel nodes MUST have a reducer. Use `Annotated[type, reducer_fn]`. + +--- + +## Mistake 8: .env file had placeholder values in worktree + +**What happened**: Created `.env` in worktree with template values (`your_openrouter_key_here`). User's real keys were only in main repo's `.env`. `load_dotenv()` loaded the worktree placeholder, so OpenRouter returned 401. + +**Root cause**: Created `.env` template during setup without copying real keys. `load_dotenv()` with `override=False` (default) keeps the first value found. 
+ +**Fix**: Updated worktree `.env` with real keys. Also added fallback `load_dotenv()` call for project root. + +**Lesson**: When creating `.env` files, always verify they have real values, not placeholders. When debugging auth errors, first check `os.environ.get('KEY')` to see what value is actually loaded. + +--- + +## Mistake 9: Removed top-level `llm_provider` but code still references it + +**What happened**: Removed `llm_provider` from `default_config.py` (since we have per-tier providers). But `scanner_graph.py` line 78 does `self.config.get(f"{tier}_llm_provider") or self.config["llm_provider"]` — would crash if per-tier provider is ever None. + +**Status**: ✅ RESOLVED in PR #9. Top-level `llm_provider` (default: `"openai"`) and `backend_url` (default: `"https://api.openai.com/v1"`) restored as env-overridable config keys. Per-tier providers safely fall back to these when `None`. + +**Lesson**: Always preserve fallback keys that downstream code depends on. When refactoring config, grep for all references before removing keys. + +--- + +## Mistake 10: Rate limiter held lock during sleep + +**What happened**: The Alpha Vantage rate limiter's re-check path in `_rate_limited_request()` called `_time.sleep(extra_sleep)` while holding `_rate_lock`. This blocked all other threads from making API requests during the sleep period, effectively serializing all AV calls even though the pipeline runs parallel scanner agents. + +**Root cause**: Initial implementation only had one lock section. When the re-check-after-sleep pattern was added to prevent race conditions, the sleep was left inside the `with _rate_lock:` block. 
+ +**Fix**: Restructured the re-check as a `while True` loop that releases the lock before sleeping: +```python +while True: + with _rate_lock: + if len(_call_timestamps) < _RATE_LIMIT: + _call_timestamps.append(_time.time()) + break + extra_sleep = 60 - (now - _call_timestamps[0]) + 0.1 + _time.sleep(extra_sleep) # ← outside lock +``` + +**Lesson**: Never hold a lock during a sleep/IO operation. Always release the lock, perform the blocking operation, then re-acquire. diff --git a/PROGRESS.md b/PROGRESS.md new file mode 100644 index 00000000..229aaab0 --- /dev/null +++ b/PROGRESS.md @@ -0,0 +1,108 @@ +# Scanner Pipeline — Progress Tracker + +## Milestone: End-to-End Scanner ✅ COMPLETE + +The 3-phase scanner pipeline runs successfully from `python -m cli.main scan --date 2026-03-17`. + +### What Works + +| Component | Status | Notes | +|-----------|--------|-------| +| Phase 1: Geopolitical Scanner | ✅ | Ollama/qwen3.5:27b, uses `get_topic_news` | +| Phase 1: Market Movers Scanner | ✅ | Ollama/qwen3.5:27b, uses `get_market_movers` + `get_market_indices` | +| Phase 1: Sector Scanner | ✅ | Ollama/qwen3.5:27b, uses `get_sector_performance` (SPDR ETF proxies) | +| Phase 2: Industry Deep Dive | ✅ | Ollama/qwen3.5:27b, uses `get_industry_performance` + `get_topic_news` | +| Phase 3: Macro Synthesis | ✅ | OpenRouter/DeepSeek R1, pure LLM synthesis (no tools) | +| Parallel fan-out (Phase 1) | ✅ | LangGraph with `_last_value` reducers | +| Tool execution loop | ✅ | `run_tool_loop()` in `tool_runner.py` | +| Data vendor fallback | ✅ | AV → yfinance fallback on `AlphaVantageError`, `ConnectionError`, `TimeoutError` | +| CLI `--date` flag | ✅ | `python -m cli.main scan --date YYYY-MM-DD` | +| .env loading | ✅ | `load_dotenv()` at module level in `default_config.py` — import-order-independent | +| Env var config overrides | ✅ | All `DEFAULT_CONFIG` keys overridable via `TRADINGAGENTS_` env vars | +| Tests (38 total) | ✅ | 14 original + 9 scanner fallback + 15 env override tests 
| + +### Output Quality (Sample Run 2026-03-17) + +| Report | Size | Content | +|--------|------|---------| +| geopolitical_report | 6,295 chars | Iran conflict, energy risks, central bank signals | +| market_movers_report | 6,211 chars | Top gainers/losers, volume anomalies, index trends | +| sector_performance_report | 8,747 chars | Sector rotation analysis with ranked table | +| industry_deep_dive_report | — | Ran but was sparse (Phase 1 reports were the primary context) | +| macro_scan_summary | 10,309 chars | Full synthesis with stock picks and JSON structure | + +### Files Created/Modified + +**New files:** +- `tradingagents/agents/utils/tool_runner.py` — inline tool execution loop +- `tradingagents/agents/utils/scanner_states.py` — ScannerState with reducers +- `tradingagents/agents/utils/scanner_tools.py` — LangChain tool wrappers for scanner data +- `tradingagents/agents/scanners/` — all 5 scanner agent modules +- `tradingagents/graph/scanner_graph.py` — ScannerGraph orchestrator +- `tradingagents/graph/scanner_setup.py` — LangGraph workflow setup +- `tradingagents/dataflows/yfinance_scanner.py` — yfinance data for scanner +- `tradingagents/dataflows/alpha_vantage_scanner.py` — Alpha Vantage data for scanner +- `tradingagents/pipeline/macro_bridge.py` — scan → filter → per-ticker analysis bridge +- `tests/test_scanner_fallback.py` — 9 fallback tests +- `tests/test_env_override.py` — 15 env override tests + +**Modified files:** +- `tradingagents/default_config.py` — env var overrides via `_env()`/`_env_int()` helpers, `load_dotenv()` at module level, restored top-level `llm_provider` and `backend_url` keys +- `tradingagents/llm_clients/openai_client.py` — Ollama remote host support +- `tradingagents/dataflows/interface.py` — broadened fallback catch to `(AlphaVantageError, ConnectionError, TimeoutError)` +- `tradingagents/dataflows/alpha_vantage_common.py` — thread-safe rate limiter (sleep outside lock), broader `RequestException` catch, wrapped 
`raise_for_status` +- `tradingagents/graph/scanner_graph.py` — debug mode fix (stream for debug, invoke for result) +- `tradingagents/pipeline/macro_bridge.py` — `get_running_loop()` over deprecated `get_event_loop()` +- `cli/main.py` — `scan` command with `--date` flag, `try/except` in `run_pipeline`, `.env` loading fix +- `main.py` — `load_dotenv()` before tradingagents imports +- `pyproject.toml` — `python-dotenv>=1.0.0` dependency declared +- `.env.example` — documented all `TRADINGAGENTS_*` overrides and `ALPHA_VANTAGE_API_KEY` + +--- + +## Milestone: Env Var Config Overrides ✅ COMPLETE (PR #9) + +All `DEFAULT_CONFIG` values are now overridable via `TRADINGAGENTS_` environment variables without code changes. This resolves the latent bug from Mistake #9 (missing top-level `llm_provider`). + +### What Changed + +| Component | Detail | +|-----------|--------| +| `default_config.py` | `load_dotenv()` at module level + `_env()`/`_env_int()` helpers | +| Top-level fallback keys | Restored `llm_provider` and `backend_url` (defaults: `"openai"`, `"https://api.openai.com/v1"`) | +| Per-tier overrides | All `None` by default — fall back to top-level when not set via env | +| Integer config keys | `max_debate_rounds`, `max_risk_discuss_rounds`, `max_recur_limit` use `_env_int()` | +| Data vendor keys | `data_vendors.*` overridable via `TRADINGAGENTS_VENDOR_` | +| `.env.example` | Complete reference of all overridable settings | +| `python-dotenv` | Added to `pyproject.toml` as explicit dependency | +| Tests | 15 new tests in `tests/test_env_override.py` | + +--- + +## TODOs / Future Work + +### High Priority + +- [ ] **Industry Deep Dive quality**: Phase 2 report was sparse in test run. The LLM receives Phase 1 reports as context but may not call tools effectively. Consider: pre-fetching industry data and injecting it directly, or tuning the prompt to be more directive about which sectors to drill into. 
+ +- [ ] **Macro Synthesis JSON parsing**: The `macro_scan_summary` should be valid JSON but DeepSeek R1 sometimes wraps it in markdown code blocks or adds preamble text. The CLI tries `json.loads(summary)` to build a watchlist table — this may fail. Add robust JSON extraction (strip markdown fences, find first `{`). + +- [ ] **`pipeline` command**: `cli/main.py` has a `run_pipeline()` placeholder that chains scan → filter → per-ticker deep dive. Not yet implemented. + +### Medium Priority + +- [ ] **Scanner report persistence**: Reports are saved to `results/macro_scan/{date}/` as `.md` files. Verify this works and add JSON output option. + +- [ ] **Rate limiting for parallel tool calls**: Phase 1 runs 3 agents in parallel, each calling tools. If tools hit the same API (e.g., Google News), they may get rate-limited. Consider adding delays or a shared rate limiter. + +- [ ] **Ollama model validation**: Before running the pipeline, validate that the configured model exists on the Ollama server (call `/api/tags` endpoint). Currently a 404 error is only caught at first LLM call. + +- [ ] **Test coverage for scanner agents**: Current tests cover data layer (yfinance/AV fallback) but not the agent nodes themselves. Add integration tests that mock the LLM and verify tool loop behavior. + +### Low Priority + +- [ ] **Configurable MAX_TOOL_ROUNDS**: Currently hardcoded to 5 in `tool_runner.py`. Could be made configurable via `DEFAULT_CONFIG`. + +- [ ] **Streaming output**: Scanner currently runs with `Live(Spinner(...))` — no intermediate output. Could stream phase completions to the console. + +- [x] ~~**Remove top-level `llm_provider` references**~~: Resolved in PR #9 — `llm_provider` and `backend_url` restored as top-level keys with `"openai"` / `"https://api.openai.com/v1"` defaults. Per-tier providers fall back to these when `None`. 
diff --git a/cli/main.py b/cli/main.py index 0d78c9f3..d9a9d023 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,13 +1,16 @@ from typing import Optional import datetime +import json import typer from pathlib import Path from functools import wraps from rich.console import Console from dotenv import load_dotenv -# Load environment variables from .env file +# Load environment variables from .env file. +# Checks CWD first, then falls back to project root (relative to this script). load_dotenv() +load_dotenv(Path(__file__).resolve().parent.parent / ".env") from rich.panel import Panel from rich.spinner import Spinner from rich.live import Live @@ -27,13 +30,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG from cli.models import AnalystType from cli.utils import * -from tradingagents.agents.utils.scanner_tools import ( - get_market_movers, - get_market_indices, - get_sector_performance, - get_industry_performance, - get_topic_news, -) +from tradingagents.graph.scanner_graph import ScannerGraph from cli.announcements import fetch_announcements, display_announcements from cli.stats_handler import StatsCallbackHandler @@ -1178,67 +1175,161 @@ def run_analysis(): display_complete_report(final_state) -def _is_scanner_error(result: str) -> bool: - """Return True when *result* indicates an error or missing data from a scanner tool.""" - error_prefixes = ( - "Error", - "No data", - "No quotes", - "No movers", - "No news", - "No industry", - "Invalid", - "Alpha Vantage", - ) - return any(result.startswith(prefix) for prefix in error_prefixes) - - -def _invoke_and_save(tool, args: dict, save_dir: Path, filename: str, label: str) -> str: - """Invoke a scanner tool, print a preview, and save the result if it is valid.""" - result = tool.invoke(args) - if not _is_scanner_error(result): - (save_dir / filename).write_text(result) - console.print(result[:500] + "..." 
if len(result) > 500 else result) - return result - - -def run_scan(): +def run_scan(date: Optional[str] = None): + """Run the 3-phase LLM scanner pipeline via ScannerGraph.""" console.print(Panel("[bold green]Global Macro Scanner[/bold green]", border_style="green")) - default_date = datetime.datetime.now().strftime("%Y-%m-%d") - scan_date = typer.prompt("Scan date (YYYY-MM-DD)", default=default_date) - console.print(f"[cyan]Scanning market data for {scan_date}...[/cyan]") + if date: + scan_date = date + else: + default_date = datetime.datetime.now().strftime("%Y-%m-%d") + scan_date = typer.prompt("Scan date (YYYY-MM-DD)", default=default_date) # Prepare save directory save_dir = Path("results/macro_scan") / scan_date save_dir.mkdir(parents=True, exist_ok=True) - # Call scanner tools - console.print("[bold]1. Market Movers[/bold]") - _invoke_and_save(get_market_movers, {"category": "day_gainers"}, save_dir, "market_movers.txt", "Market Movers") + console.print(f"[cyan]Running 3-phase macro scanner for {scan_date}...[/cyan]") + console.print("[dim]Phase 1: Geopolitical + Market Movers + Sector scans (parallel)[/dim]") + console.print("[dim]Phase 2: Industry Deep Dive[/dim]") + console.print("[dim]Phase 3: Macro Synthesis → stocks to investigate[/dim]\n") - console.print("[bold]2. Market Indices[/bold]") - _invoke_and_save(get_market_indices, {}, save_dir, "market_indices.txt", "Market Indices") + try: + scanner = ScannerGraph(config=DEFAULT_CONFIG.copy()) + with Live(Spinner("dots", text="Scanning..."), console=console, transient=True): + result = scanner.scan(scan_date) + except Exception as e: + console.print(f"[red]Scanner failed: {e}[/red]") + raise typer.Exit(1) - console.print("[bold]3. 
Sector Performance[/bold]") - _invoke_and_save(get_sector_performance, {}, save_dir, "sector_performance.txt", "Sector Performance") + # Save reports + for key in ["geopolitical_report", "market_movers_report", "sector_performance_report", + "industry_deep_dive_report", "macro_scan_summary"]: + content = result.get(key, "") + if content: + (save_dir / f"{key}.md").write_text(content) - console.print("[bold]4. Industry Performance (Technology)[/bold]") - _invoke_and_save(get_industry_performance, {"sector_key": "technology"}, save_dir, "industry_performance.txt", "Industry Performance") + # Display the final watchlist + summary = result.get("macro_scan_summary", "") + if summary: + console.print(Panel("[bold]Macro Scan Summary[/bold]", border_style="green")) + console.print(Markdown(summary[:3000])) - console.print("[bold]5. Topic News (Market)[/bold]") - _invoke_and_save(get_topic_news, {"topic": "market", "limit": 10}, save_dir, "topic_news.txt", "Topic News") + # Try to parse and show watchlist table + try: + summary_data = json.loads(summary) + stocks = summary_data.get("stocks_to_investigate", []) + if stocks: + table = Table(title="Stocks to Investigate", box=box.ROUNDED) + table.add_column("Ticker", style="cyan bold") + table.add_column("Name") + table.add_column("Sector") + table.add_column("Conviction", style="green") + table.add_column("Thesis") + for s in stocks: + table.add_row( + s.get("ticker", ""), + s.get("name", ""), + s.get("sector", ""), + s.get("conviction", "").upper(), + s.get("thesis_angle", ""), + ) + console.print(table) + except (json.JSONDecodeError, KeyError): + pass # Summary wasn't valid JSON — already printed as markdown - console.print(f"[green]Results saved to {save_dir}[/green]") + + console.print(f"\n[green]Results saved to {save_dir}[/green]") + + +def run_pipeline(): + """Full pipeline: scan -> filter -> per-ticker deep dive.""" + import asyncio + from tradingagents.pipeline.macro_bridge import ( + parse_macro_output, + 
filter_candidates, + run_all_tickers, + save_results, + ) + + console.print(Panel("[bold green]Macro → TradingAgents Pipeline[/bold green]", border_style="green")) + + macro_output = typer.prompt("Path to macro scan JSON") + macro_path = Path(macro_output) + if not macro_path.exists(): + console.print(f"[red]File not found: {macro_path}[/red]") + raise typer.Exit(1) + + min_conviction = typer.prompt("Minimum conviction (high/medium/low)", default="medium") + tickers_input = typer.prompt("Specific tickers (comma-separated, or blank for all)", default="") + ticker_filter = [t.strip() for t in tickers_input.split(",") if t.strip()] or None + analysis_date = typer.prompt("Analysis date", default=datetime.datetime.now().strftime("%Y-%m-%d")) + dry_run = typer.confirm("Dry run (no API calls)?", default=False) + + # Parse macro output + macro_context, all_candidates = parse_macro_output(macro_path) + candidates = filter_candidates(all_candidates, min_conviction, ticker_filter) + + console.print(f"\n[cyan]Candidates: {len(candidates)} of {len(all_candidates)} stocks passed filter[/cyan]") + + table = Table(title="Selected Stocks", box=box.ROUNDED) + table.add_column("Ticker", style="cyan bold") + table.add_column("Conviction") + table.add_column("Sector") + table.add_column("Name") + for c in candidates: + table.add_row(c.ticker, c.conviction.upper(), c.sector, c.name) + console.print(table) + + if dry_run: + console.print("\n[yellow]Dry run — skipping TradingAgents analysis[/yellow]") + return + + if not candidates: + console.print("[yellow]No candidates passed the filter.[/yellow]") + return + + config = DEFAULT_CONFIG.copy() + output_dir = Path("results/macro_pipeline") + + console.print(f"\n[cyan]Running TradingAgents for {len(candidates)} tickers...[/cyan]") + try: + with Live(Spinner("dots", text="Analyzing..."), console=console, transient=True): + results = asyncio.run( + run_all_tickers(candidates, macro_context, config, analysis_date) + ) + except Exception as e: 
+ console.print(f"[red]Pipeline failed: {e}[/red]") + raise typer.Exit(1) + + save_results(results, macro_context, output_dir) + + successes = [r for r in results if not r.error] + failures = [r for r in results if r.error] + console.print(f"\n[green]Done: {len(successes)} succeeded, {len(failures)} failed[/green]") + console.print(f"Reports saved to: {output_dir.resolve()}") + if failures: + for r in failures: + console.print(f" [red]{r.ticker}: {r.error}[/red]") @app.command() def analyze(): + """Run per-ticker multi-agent analysis.""" run_analysis() @app.command() -def scan(): - run_scan() +def scan( + date: Optional[str] = typer.Option(None, "--date", "-d", help="Scan date in YYYY-MM-DD format (default: today)"), +): + """Run 3-phase macro scanner (geopolitical → sector → synthesis).""" + run_scan(date=date) + + +@app.command() +def pipeline(): + """Full pipeline: macro scan JSON → filter → per-ticker deep dive.""" + run_pipeline() if __name__ == "__main__": diff --git a/main.py b/main.py index 7e8b20e8..be020d0c 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,13 @@ -from tradingagents.graph.trading_graph import TradingAgentsGraph -from tradingagents.default_config import DEFAULT_CONFIG - from dotenv import load_dotenv -# Load environment variables from .env file +# Load environment variables from .env file BEFORE importing any +# tradingagents modules so TRADINGAGENTS_* vars are visible to +# DEFAULT_CONFIG at import time. 
load_dotenv() +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.default_config import DEFAULT_CONFIG + # Create a custom config config = DEFAULT_CONFIG.copy() config["deep_think_llm"] = "gpt-5-mini" # Use a different model diff --git a/pyproject.toml b/pyproject.toml index 9213d7f6..d361508b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "langgraph>=0.4.8", "pandas>=2.3.0", "parsel>=1.10.0", + "python-dotenv>=1.0.0", "pytz>=2025.2", "questionary>=2.1.0", "rank-bm25>=0.2.2", diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..b1bed2ce --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,31 @@ +"""Shared fixtures and markers for TradingAgents tests.""" + +import os +import pytest + + +def pytest_configure(config): + config.addinivalue_line("markers", "integration: tests that hit real external APIs") + config.addinivalue_line("markers", "slow: tests that take a long time to run") + + +@pytest.fixture +def av_api_key(): + """Return the Alpha Vantage API key or skip the test.""" + key = os.environ.get("ALPHA_VANTAGE_API_KEY") + if not key: + pytest.skip("ALPHA_VANTAGE_API_KEY not set") + return key + + +@pytest.fixture +def av_config(): + """Return a config dict with Alpha Vantage as the scanner data vendor.""" + from tradingagents.default_config import DEFAULT_CONFIG + + config = DEFAULT_CONFIG.copy() + config["data_vendors"] = { + **config["data_vendors"], + "scanner_data": "alpha_vantage", + } + return config diff --git a/tests/test_alpha_vantage_exceptions.py b/tests/test_alpha_vantage_exceptions.py new file mode 100644 index 00000000..2bf90a4d --- /dev/null +++ b/tests/test_alpha_vantage_exceptions.py @@ -0,0 +1,76 @@ +"""Integration tests for Alpha Vantage exception hierarchy.""" + +import os +import pytest +from unittest.mock import patch + +from tradingagents.dataflows.alpha_vantage_common import ( + AlphaVantageError, + APIKeyInvalidError, + RateLimitError, 
+ AlphaVantageRateLimitError, + ThirdPartyError, + ThirdPartyTimeoutError, + ThirdPartyParseError, + _make_api_request, +) + + +class TestExceptionHierarchy: + """Verify the exception class hierarchy is correct.""" + + def test_all_exceptions_inherit_from_base(self): + assert issubclass(APIKeyInvalidError, AlphaVantageError) + assert issubclass(RateLimitError, AlphaVantageError) + assert issubclass(ThirdPartyError, AlphaVantageError) + assert issubclass(ThirdPartyTimeoutError, AlphaVantageError) + assert issubclass(ThirdPartyParseError, AlphaVantageError) + + def test_rate_limit_alias(self): + """AlphaVantageRateLimitError is an alias for RateLimitError.""" + assert AlphaVantageRateLimitError is RateLimitError + + def test_exceptions_are_catchable_as_base(self): + with pytest.raises(AlphaVantageError): + raise APIKeyInvalidError("bad key") + with pytest.raises(AlphaVantageError): + raise RateLimitError("rate limited") + with pytest.raises(AlphaVantageError): + raise ThirdPartyError("server error") + + +@pytest.mark.integration +class TestMakeApiRequestErrors: + """Test _make_api_request error handling with real HTTP calls.""" + + def test_invalid_api_key(self): + """An invalid API key should raise APIKeyInvalidError or AlphaVantageError.""" + with patch.dict(os.environ, {"ALPHA_VANTAGE_API_KEY": "INVALID_KEY_12345"}): + # AV may return 200 with error in body, or may return a valid demo response + # Either way it should not silently succeed with bad data + try: + result = _make_api_request("TIME_SERIES_DAILY", {"symbol": "IBM"}) + # If it returns something, it should be valid data (demo key behavior) + assert result is not None + except AlphaVantageError: + pass # Expected — any AV error is acceptable here + + def test_timeout_raises_timeout_error(self): + """A timeout should raise ThirdPartyTimeoutError.""" + with patch.dict(os.environ, {"ALPHA_VANTAGE_API_KEY": "demo"}): + with pytest.raises(ThirdPartyTimeoutError): + # Use an impossibly short timeout + 
_make_api_request( + "TIME_SERIES_DAILY", + {"symbol": "IBM"}, + timeout=0.001, + ) + + def test_valid_request_succeeds(self, av_api_key): + """A valid request with a real key should return data.""" + result = _make_api_request( + "GLOBAL_QUOTE", + {"symbol": "IBM"}, + ) + assert result is not None + assert len(result) > 0 diff --git a/tests/test_alpha_vantage_scanner.py b/tests/test_alpha_vantage_scanner.py new file mode 100644 index 00000000..75bb053a --- /dev/null +++ b/tests/test_alpha_vantage_scanner.py @@ -0,0 +1,92 @@ +"""Integration tests for Alpha Vantage scanner data layer. + +All tests hit the real Alpha Vantage API — no mocks. +Requires ALPHA_VANTAGE_API_KEY environment variable. +""" + +import pytest + +from tradingagents.dataflows.alpha_vantage_scanner import ( + get_market_movers_alpha_vantage, + get_market_indices_alpha_vantage, + get_sector_performance_alpha_vantage, + get_industry_performance_alpha_vantage, + get_topic_news_alpha_vantage, +) + + +@pytest.mark.integration +class TestMarketMovers: + + def test_day_gainers(self, av_api_key): + result = get_market_movers_alpha_vantage("day_gainers") + assert isinstance(result, str) + assert "Market Movers" in result + assert "|" in result # markdown table + + def test_day_losers(self, av_api_key): + result = get_market_movers_alpha_vantage("day_losers") + assert isinstance(result, str) + assert "Market Movers" in result + + def test_most_actives(self, av_api_key): + result = get_market_movers_alpha_vantage("most_actives") + assert isinstance(result, str) + assert "Market Movers" in result + + def test_invalid_category_raises(self, av_api_key): + with pytest.raises(ValueError): + get_market_movers_alpha_vantage("invalid_category") + + +@pytest.mark.integration +class TestMarketIndices: + + def test_returns_markdown_table(self, av_api_key): + result = get_market_indices_alpha_vantage() + assert isinstance(result, str) + assert "Market Indices" in result + assert "|" in result + # Should contain at least 
some index proxies + assert any(name in result for name in ["S&P 500", "SPY", "Dow", "DIA", "NASDAQ", "QQQ"]) + + +@pytest.mark.integration +class TestSectorPerformance: + + def test_returns_all_sectors(self, av_api_key): + result = get_sector_performance_alpha_vantage() + assert isinstance(result, str) + assert "Sector" in result + assert "|" in result + # Should contain at least some sector names + assert any(s in result for s in ["Technology", "Healthcare", "Energy", "Financials"]) + + +@pytest.mark.integration +class TestIndustryPerformance: + + def test_technology_sector(self, av_api_key): + result = get_industry_performance_alpha_vantage("technology") + assert isinstance(result, str) + assert "|" in result + # Should contain some tech tickers + assert any(t in result for t in ["AAPL", "MSFT", "NVDA", "GOOGL"]) + + def test_invalid_sector_raises(self, av_api_key): + with pytest.raises(ValueError): + get_industry_performance_alpha_vantage("nonexistent_sector") + + +@pytest.mark.integration +class TestTopicNews: + + def test_market_news(self, av_api_key): + result = get_topic_news_alpha_vantage("market", limit=5) + assert isinstance(result, str) + assert "News" in result + + def test_technology_news(self, av_api_key): + result = get_topic_news_alpha_vantage("technology", limit=3) + assert isinstance(result, str) + assert len(result) > 50 # Should have some content diff --git a/tests/test_env_override.py b/tests/test_env_override.py new file mode 100644 index 00000000..1bf4e54b --- /dev/null +++ b/tests/test_env_override.py @@ -0,0 +1,108 @@ +"""Tests that TRADINGAGENTS_* environment variables override DEFAULT_CONFIG.""" + +import importlib +import os +from unittest.mock import patch + +import pytest + + +class TestEnvOverridesDefaults: + """Verify that setting TRADINGAGENTS_ env vars changes DEFAULT_CONFIG.""" + + def _reload_config(self): + """Force-reimport default_config so the module-level dict is rebuilt.""" + import tradingagents.default_config as mod + + 
importlib.reload(mod) + return mod.DEFAULT_CONFIG + + def test_llm_provider_override(self): + with patch.dict(os.environ, {"TRADINGAGENTS_LLM_PROVIDER": "openrouter"}): + cfg = self._reload_config() + assert cfg["llm_provider"] == "openrouter" + + def test_backend_url_override(self): + with patch.dict(os.environ, {"TRADINGAGENTS_BACKEND_URL": "http://localhost:1234"}): + cfg = self._reload_config() + assert cfg["backend_url"] == "http://localhost:1234" + + def test_deep_think_llm_override(self): + with patch.dict(os.environ, {"TRADINGAGENTS_DEEP_THINK_LLM": "deepseek/deepseek-r1"}): + cfg = self._reload_config() + assert cfg["deep_think_llm"] == "deepseek/deepseek-r1" + + def test_quick_think_llm_override(self): + with patch.dict(os.environ, {"TRADINGAGENTS_QUICK_THINK_LLM": "gpt-4o-mini"}): + cfg = self._reload_config() + assert cfg["quick_think_llm"] == "gpt-4o-mini" + + def test_mid_think_llm_none_by_default(self): + """mid_think_llm defaults to None (falls back to quick_think_llm).""" + with patch.dict(os.environ, {}, clear=False): + # Remove the env var if it happens to be set + os.environ.pop("TRADINGAGENTS_MID_THINK_LLM", None) + cfg = self._reload_config() + assert cfg["mid_think_llm"] is None + + def test_mid_think_llm_override(self): + with patch.dict(os.environ, {"TRADINGAGENTS_MID_THINK_LLM": "gpt-4o"}): + cfg = self._reload_config() + assert cfg["mid_think_llm"] == "gpt-4o" + + def test_empty_env_var_keeps_default(self): + """An empty string is treated the same as unset (keeps the default).""" + with patch.dict(os.environ, {"TRADINGAGENTS_LLM_PROVIDER": ""}): + cfg = self._reload_config() + assert cfg["llm_provider"] == "openai" + + def test_empty_env_var_keeps_none_default(self): + """An empty string for a None-default field stays None.""" + with patch.dict(os.environ, {"TRADINGAGENTS_DEEP_THINK_LLM_PROVIDER": ""}): + cfg = self._reload_config() + assert cfg["deep_think_llm_provider"] is None + + def test_per_tier_provider_override(self): + with 
patch.dict(os.environ, {"TRADINGAGENTS_DEEP_THINK_LLM_PROVIDER": "anthropic"}): + cfg = self._reload_config() + assert cfg["deep_think_llm_provider"] == "anthropic" + + def test_per_tier_backend_url_override(self): + with patch.dict(os.environ, {"TRADINGAGENTS_MID_THINK_BACKEND_URL": "http://my-ollama:11434"}): + cfg = self._reload_config() + assert cfg["mid_think_backend_url"] == "http://my-ollama:11434" + + def test_max_debate_rounds_int(self): + with patch.dict(os.environ, {"TRADINGAGENTS_MAX_DEBATE_ROUNDS": "3"}): + cfg = self._reload_config() + assert cfg["max_debate_rounds"] == 3 + + def test_max_debate_rounds_bad_value(self): + """Non-numeric string falls back to hardcoded default.""" + with patch.dict(os.environ, {"TRADINGAGENTS_MAX_DEBATE_ROUNDS": "abc"}): + cfg = self._reload_config() + assert cfg["max_debate_rounds"] == 1 + + def test_results_dir_override(self): + with patch.dict(os.environ, {"TRADINGAGENTS_RESULTS_DIR": "/tmp/my_results"}): + cfg = self._reload_config() + assert cfg["results_dir"] == "/tmp/my_results" + + def test_vendor_scanner_data_override(self): + with patch.dict(os.environ, {"TRADINGAGENTS_VENDOR_SCANNER_DATA": "alpha_vantage"}): + cfg = self._reload_config() + assert cfg["data_vendors"]["scanner_data"] == "alpha_vantage" + + def test_defaults_unchanged_when_no_env_set(self): + """Without any TRADINGAGENTS_* vars, defaults are the original hardcoded values.""" + # Clear all TRADINGAGENTS_ vars + env_clean = {k: v for k, v in os.environ.items() if not k.startswith("TRADINGAGENTS_")} + with patch.dict(os.environ, env_clean, clear=True): + cfg = self._reload_config() + assert cfg["llm_provider"] == "openai" + assert cfg["deep_think_llm"] == "gpt-5.2" + assert cfg["mid_think_llm"] is None + assert cfg["quick_think_llm"] == "gpt-5-mini" + assert cfg["backend_url"] == "https://api.openai.com/v1" + assert cfg["max_debate_rounds"] == 1 + assert cfg["data_vendors"]["scanner_data"] == "yfinance" diff --git a/tests/test_macro_bridge.py 
b/tests/test_macro_bridge.py new file mode 100644 index 00000000..0c8650e3 --- /dev/null +++ b/tests/test_macro_bridge.py @@ -0,0 +1,214 @@ +"""Tests for the macro bridge module — JSON parsing, filtering, and report rendering.""" + +import json +import tempfile +from pathlib import Path + +import pytest + + +EXAMPLE_MACRO_JSON = { + "timeframe": "1 month", + "region": "Global", + "executive_summary": "Test summary", + "macro_context": { + "economic_cycle": "Late expansion", + "central_bank_stance": "Fed on hold", + "geopolitical_risks": ["US-China tensions"], + "key_indicators": [ + {"name": "10Y UST", "status": "4.45%", "signal": "neutral"} + ], + }, + "key_themes": [ + { + "theme": "AI infrastructure", + "description": "Hyperscaler capex elevated", + "conviction": "high", + "timeframe": "3-6 months", + "supporting_factors": ["NVDA revenue"], + } + ], + "sector_opportunities": [], + "stocks_to_investigate": [ + { + "ticker": "NVDA", + "name": "NVIDIA Corporation", + "sector": "Technology — Semiconductors", + "rationale": "AI accelerator dominance", + "thesis_angle": "growth", + "conviction": "high", + "key_catalysts": ["Blackwell ramp"], + "risks": ["export controls"], + }, + { + "ticker": "LMT", + "name": "Lockheed Martin", + "sector": "Defense", + "rationale": "F-35 backlog", + "thesis_angle": "catalyst", + "conviction": "medium", + "key_catalysts": ["NATO orders"], + "risks": ["budget risk"], + }, + { + "ticker": "XYZ", + "name": "Low Conv Corp", + "sector": "Other", + "rationale": "Speculative", + "thesis_angle": "momentum", + "conviction": "low", + "key_catalysts": [], + "risks": [], + }, + ], + "risk_factors": ["Higher for longer"], +} + + +@pytest.fixture +def macro_json_file(tmp_path): + path = tmp_path / "macro_output.json" + path.write_text(json.dumps(EXAMPLE_MACRO_JSON)) + return path + + +class TestParseMacroOutput: + + def test_parses_context_and_candidates(self, macro_json_file): + from tradingagents.pipeline.macro_bridge import parse_macro_output + 
+ ctx, candidates = parse_macro_output(macro_json_file) + assert ctx.economic_cycle == "Late expansion" + assert ctx.executive_summary == "Test summary" + assert len(candidates) == 3 + assert candidates[0].ticker == "NVDA" + assert candidates[0].conviction == "high" + + def test_missing_fields_default_gracefully(self, tmp_path): + from tradingagents.pipeline.macro_bridge import parse_macro_output + + minimal = {"stocks_to_investigate": [{"ticker": "TEST"}]} + path = tmp_path / "minimal.json" + path.write_text(json.dumps(minimal)) + ctx, candidates = parse_macro_output(path) + assert len(candidates) == 1 + assert candidates[0].ticker == "TEST" + assert candidates[0].conviction == "medium" # default + + +class TestFilterCandidates: + + def test_filter_high_conviction(self, macro_json_file): + from tradingagents.pipeline.macro_bridge import ( + parse_macro_output, + filter_candidates, + ) + + _, candidates = parse_macro_output(macro_json_file) + filtered = filter_candidates(candidates, "high", None) + assert len(filtered) == 1 + assert filtered[0].ticker == "NVDA" + + def test_filter_medium_conviction(self, macro_json_file): + from tradingagents.pipeline.macro_bridge import ( + parse_macro_output, + filter_candidates, + ) + + _, candidates = parse_macro_output(macro_json_file) + filtered = filter_candidates(candidates, "medium", None) + assert len(filtered) == 2 + tickers = {c.ticker for c in filtered} + assert tickers == {"NVDA", "LMT"} + + def test_filter_by_ticker(self, macro_json_file): + from tradingagents.pipeline.macro_bridge import ( + parse_macro_output, + filter_candidates, + ) + + _, candidates = parse_macro_output(macro_json_file) + filtered = filter_candidates(candidates, "low", ["LMT"]) + assert len(filtered) == 1 + assert filtered[0].ticker == "LMT" + + def test_sorted_by_conviction_desc(self, macro_json_file): + from tradingagents.pipeline.macro_bridge import ( + parse_macro_output, + filter_candidates, + ) + + _, candidates = 
parse_macro_output(macro_json_file) + filtered = filter_candidates(candidates, "low", None) + assert filtered[0].conviction == "high" + assert filtered[-1].conviction == "low" + + +class TestReportRendering: + + def test_render_ticker_report(self, macro_json_file): + from tradingagents.pipeline.macro_bridge import ( + parse_macro_output, + TickerResult, + render_ticker_report, + ) + + ctx, candidates = parse_macro_output(macro_json_file) + result = TickerResult( + ticker="NVDA", + candidate=candidates[0], + macro_context=ctx, + analysis_date="2026-03-17", + final_trade_decision="BUY", + ) + report = render_ticker_report(result) + assert "NVDA" in report + assert "NVIDIA" in report + assert "BUY" in report + assert "Macro" in report + + def test_render_combined_summary(self, macro_json_file): + from tradingagents.pipeline.macro_bridge import ( + parse_macro_output, + TickerResult, + render_combined_summary, + ) + + ctx, candidates = parse_macro_output(macro_json_file) + results = [ + TickerResult( + ticker=c.ticker, + candidate=c, + macro_context=ctx, + analysis_date="2026-03-17", + final_trade_decision="HOLD", + ) + for c in candidates[:2] + ] + summary = render_combined_summary(results, ctx) + assert "NVDA" in summary + assert "LMT" in summary + assert "Summary" in summary + + def test_save_results(self, macro_json_file, tmp_path): + from tradingagents.pipeline.macro_bridge import ( + parse_macro_output, + TickerResult, + save_results, + ) + + ctx, candidates = parse_macro_output(macro_json_file) + results = [ + TickerResult( + ticker="NVDA", + candidate=candidates[0], + macro_context=ctx, + analysis_date="2026-03-17", + final_trade_decision="BUY", + ) + ] + output_dir = tmp_path / "output" + save_results(results, ctx, output_dir) + assert (output_dir / "summary.md").exists() + assert (output_dir / "results.json").exists() + assert (output_dir / "NVDA" / "2026-03-17_deep_dive.md").exists() diff --git a/tests/test_scanner_fallback.py 
b/tests/test_scanner_fallback.py new file mode 100644 index 00000000..134be897 --- /dev/null +++ b/tests/test_scanner_fallback.py @@ -0,0 +1,115 @@ +"""Tests for scanner data functions — yfinance fallback and AV error handling. + +These tests verify: +1. yfinance sector performance returns real data via ETF proxies +2. yfinance industry performance uses DataFrame index for ticker symbols +3. AV scanner functions raise AlphaVantageError when all data fails (enabling fallback) +4. route_to_vendor falls back from AV to yfinance on AlphaVantageError +""" + +import os +import pytest +from unittest.mock import patch + +from tradingagents.dataflows.yfinance_scanner import ( + get_sector_performance_yfinance, + get_industry_performance_yfinance, +) +from tradingagents.dataflows.alpha_vantage_common import AlphaVantageError +from tradingagents.dataflows.alpha_vantage_scanner import ( + get_sector_performance_alpha_vantage, + get_industry_performance_alpha_vantage, +) + + +class TestYfinanceSectorPerformance: + """Verify yfinance sector performance uses ETF proxies and returns real data.""" + + def test_returns_all_11_sectors(self): + result = get_sector_performance_yfinance() + assert "| Sector |" in result + # Check all 11 GICS sectors are present + for sector in [ + "Technology", "Healthcare", "Financials", "Energy", + "Consumer Discretionary", "Consumer Staples", "Industrials", + "Materials", "Real Estate", "Utilities", "Communication Services", + ]: + assert sector in result, f"Missing sector: {sector}" + + def test_returns_numeric_percentages(self): + result = get_sector_performance_yfinance() + lines = result.strip().split("\n") + # Skip header lines (first 4: title, date, column headers, separator) + data_lines = [l for l in lines if l.startswith("| ") and "Sector" not in l and "---" not in l] + assert len(data_lines) == 11, f"Expected 11 data rows, got {len(data_lines)}" + + for line in data_lines: + cols = [c.strip() for c in line.split("|")[1:-1]] + # cols: 
[sector_name, 1-day, 1-week, 1-month, ytd] + assert len(cols) == 5, f"Expected 5 columns, got {len(cols)} in: {line}" + # 1-day should be a percentage like "+1.45%" or "-0.31%" + day_pct = cols[1] + assert "%" in day_pct or day_pct == "N/A", f"Bad 1-day value: {day_pct}" + # Should NOT contain "Error:" + assert "Error:" not in day_pct, f"Error in 1-day for {cols[0]}: {day_pct}" + + +class TestYfinanceIndustryPerformance: + """Verify yfinance industry performance uses index for ticker symbols.""" + + def test_returns_real_symbols(self): + result = get_industry_performance_yfinance("technology") + assert "| Company |" in result or "| Company " in result + # Should contain actual tickers, not N/A + assert "NVDA" in result or "AAPL" in result or "MSFT" in result, \ + f"No real tickers found in result: {result[:300]}" + + def test_no_na_symbols(self): + result = get_industry_performance_yfinance("technology") + lines = result.strip().split("\n") + data_lines = [l for l in lines if l.startswith("| ") and "Company" not in l and "---" not in l] + for line in data_lines: + cols = [c.strip() for c in line.split("|")[1:-1]] + # Symbol column (index 1) should not be N/A + assert cols[1] != "N/A", f"Symbol is N/A in line: {line}" + + def test_healthcare_sector(self): + result = get_industry_performance_yfinance("healthcare") + assert "Industry Performance: Healthcare" in result + + +class TestAlphaVantageFailoverRaise: + """Verify AV scanner functions raise when all data fails (enabling fallback).""" + + def test_sector_perf_raises_on_total_failure(self): + """When every GLOBAL_QUOTE call fails, the function should raise.""" + with patch.dict(os.environ, {"ALPHA_VANTAGE_API_KEY": "demo"}): + with pytest.raises(AlphaVantageError, match="All .* sector queries failed"): + get_sector_performance_alpha_vantage() + + def test_industry_perf_raises_on_total_failure(self): + """When every ticker quote fails, the function should raise.""" + with patch.dict(os.environ, 
{"ALPHA_VANTAGE_API_KEY": "demo"}): + with pytest.raises(AlphaVantageError, match="All .* ticker queries failed"): + get_industry_performance_alpha_vantage("technology") + + +class TestRouteToVendorFallback: + """Verify route_to_vendor falls back from AV to yfinance.""" + + def test_sector_perf_falls_back_to_yfinance(self): + with patch.dict(os.environ, {"ALPHA_VANTAGE_API_KEY": "demo"}): + from tradingagents.dataflows.interface import route_to_vendor + result = route_to_vendor("get_sector_performance") + # Should get yfinance data (no "Alpha Vantage" in header) + assert "Sector Performance Overview" in result + # Should have actual percentage data, not all errors + assert "Error:" not in result or result.count("Error:") < 3 + + def test_industry_perf_falls_back_to_yfinance(self): + with patch.dict(os.environ, {"ALPHA_VANTAGE_API_KEY": "demo"}): + from tradingagents.dataflows.interface import route_to_vendor + result = route_to_vendor("get_industry_performance", "technology") + assert "Industry Performance" in result + # Should contain real ticker symbols + assert "N/A" not in result or result.count("N/A") < 5 diff --git a/tests/test_scanner_routing.py b/tests/test_scanner_routing.py new file mode 100644 index 00000000..4a4b1aec --- /dev/null +++ b/tests/test_scanner_routing.py @@ -0,0 +1,67 @@ +"""Integration tests for scanner vendor routing. + +Verifies that when config says scanner_data=alpha_vantage, +scanner tools route to Alpha Vantage implementations. 
+""" + +import pytest +from tradingagents.dataflows.interface import route_to_vendor, get_vendor +from tradingagents.dataflows.config import set_config + + +@pytest.mark.integration +class TestScannerRouting: + + def setup_method(self): + """Set config to use alpha_vantage for scanner_data.""" + from tradingagents.default_config import DEFAULT_CONFIG + + config = DEFAULT_CONFIG.copy() + config["data_vendors"]["scanner_data"] = "alpha_vantage" + set_config(config) + + def test_vendor_resolves_to_alpha_vantage(self): + vendor = get_vendor("scanner_data") + assert vendor == "alpha_vantage" + + def test_market_movers_routes_to_av(self, av_api_key): + result = route_to_vendor("get_market_movers", "day_gainers") + assert isinstance(result, str) + assert "Market Movers" in result + + def test_market_indices_routes_to_av(self, av_api_key): + result = route_to_vendor("get_market_indices") + assert isinstance(result, str) + assert "Market Indices" in result or "Index" in result + + def test_sector_performance_routes_to_av(self, av_api_key): + result = route_to_vendor("get_sector_performance") + assert isinstance(result, str) + assert "Sector" in result + + def test_industry_performance_routes_to_av(self, av_api_key): + result = route_to_vendor("get_industry_performance", "technology") + assert isinstance(result, str) + assert "|" in result + + def test_topic_news_routes_to_av(self, av_api_key): + result = route_to_vendor("get_topic_news", "market", limit=3) + assert isinstance(result, str) + assert "News" in result + + +class TestFallbackRouting: + + def setup_method(self): + """Set config to use yfinance as fallback.""" + from tradingagents.default_config import DEFAULT_CONFIG + + config = DEFAULT_CONFIG.copy() + config["data_vendors"]["scanner_data"] = "yfinance" + set_config(config) + + def test_yfinance_fallback_works(self): + """When configured for yfinance, scanner tools should use yfinance.""" + result = route_to_vendor("get_market_movers", "day_gainers") + assert 
isinstance(result, str) + assert "Market Movers" in result diff --git a/tradingagents/agents/scanners/__init__.py b/tradingagents/agents/scanners/__init__.py new file mode 100644 index 00000000..1279e61e --- /dev/null +++ b/tradingagents/agents/scanners/__init__.py @@ -0,0 +1,5 @@ +from .geopolitical_scanner import create_geopolitical_scanner +from .market_movers_scanner import create_market_movers_scanner +from .sector_scanner import create_sector_scanner +from .industry_deep_dive import create_industry_deep_dive +from .macro_synthesis import create_macro_synthesis diff --git a/tradingagents/agents/scanners/geopolitical_scanner.py b/tradingagents/agents/scanners/geopolitical_scanner.py new file mode 100644 index 00000000..afa5d3ce --- /dev/null +++ b/tradingagents/agents/scanners/geopolitical_scanner.py @@ -0,0 +1,53 @@ +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.utils.agent_utils import get_topic_news +from tradingagents.agents.utils.tool_runner import run_tool_loop + + +def create_geopolitical_scanner(llm): + def geopolitical_scanner_node(state): + scan_date = state["scan_date"] + + tools = [get_topic_news] + + system_message = ( + "You are a geopolitical analyst scanning global news for risks and opportunities affecting financial markets. " + "Use get_topic_news to search for news on: geopolitics, trade policy, sanctions, central bank decisions, " + "energy markets, and military conflicts. Analyze the results and write a concise report covering: " + "(1) Major geopolitical events and their market impact, " + "(2) Central bank policy signals, " + "(3) Trade/sanctions developments, " + "(4) Energy and commodity supply risks. " + "Include a risk assessment table at the end." + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." 
+ " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." + " You have access to the following tools: {tool_names}.\n{system_message}" + " For your reference, the current date is {current_date}.", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) + prompt = prompt.partial(current_date=scan_date) + + chain = prompt | llm.bind_tools(tools) + result = run_tool_loop(chain, state["messages"], tools) + + report = result.content or "" + + return { + "messages": [result], + "geopolitical_report": report, + "sender": "geopolitical_scanner", + } + + return geopolitical_scanner_node diff --git a/tradingagents/agents/scanners/industry_deep_dive.py b/tradingagents/agents/scanners/industry_deep_dive.py new file mode 100644 index 00000000..bfe84b6b --- /dev/null +++ b/tradingagents/agents/scanners/industry_deep_dive.py @@ -0,0 +1,68 @@ +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.utils.agent_utils import get_industry_performance, get_topic_news +from tradingagents.agents.utils.tool_runner import run_tool_loop + + +def create_industry_deep_dive(llm): + def industry_deep_dive_node(state): + scan_date = state["scan_date"] + + tools = [get_industry_performance, get_topic_news] + + # Inject Phase 1 context so the LLM can decide which sectors to drill into + phase1_context = f"""## Phase 1 Scanner Reports (for your reference) + +### Geopolitical Report: +{state.get("geopolitical_report", "Not available")} + +### Market Movers Report: +{state.get("market_movers_report", "Not available")} + +### Sector Performance Report: +{state.get("sector_performance_report", "Not available")} +""" + + system_message = ( + "You are a senior research analyst performing an industry deep dive. 
" + "You have received reports from three parallel scanners (geopolitical, market movers, sector performance). " + "Review these reports and identify the 2-3 most promising sectors/industries to investigate further. " + "Use get_industry_performance to drill into those sectors and get_topic_news for sector-specific news. " + "Write a detailed report covering: " + "(1) Why these industries were selected, " + "(2) Top companies within each industry and their recent performance, " + "(3) Industry-specific catalysts and risks, " + "(4) Cross-references between geopolitical events and sector opportunities." + f"\n\n{phase1_context}" + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." + " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." 
+ " You have access to the following tools: {tool_names}.\n{system_message}" + " For your reference, the current date is {current_date}.", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) + prompt = prompt.partial(current_date=scan_date) + + chain = prompt | llm.bind_tools(tools) + result = run_tool_loop(chain, state["messages"], tools) + + report = result.content or "" + + return { + "messages": [result], + "industry_deep_dive_report": report, + "sender": "industry_deep_dive", + } + + return industry_deep_dive_node diff --git a/tradingagents/agents/scanners/macro_synthesis.py b/tradingagents/agents/scanners/macro_synthesis.py new file mode 100644 index 00000000..9876a927 --- /dev/null +++ b/tradingagents/agents/scanners/macro_synthesis.py @@ -0,0 +1,74 @@ +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + + +def create_macro_synthesis(llm): + def macro_synthesis_node(state): + scan_date = state["scan_date"] + + # Inject all previous reports for synthesis — no tools, pure LLM reasoning + all_reports_context = f"""## All Scanner and Research Reports + +### Geopolitical Report: +{state.get("geopolitical_report", "Not available")} + +### Market Movers Report: +{state.get("market_movers_report", "Not available")} + +### Sector Performance Report: +{state.get("sector_performance_report", "Not available")} + +### Industry Deep Dive Report: +{state.get("industry_deep_dive_report", "Not available")} +""" + + system_message = ( + "You are a macro strategist synthesizing all scanner and research reports into a final investment thesis. " + "You have received: geopolitical analysis, market movers analysis, sector performance analysis, " + "and industry deep dive analysis. 
" + "Synthesize these into a structured output with: " + "(1) Executive summary of the macro environment, " + "(2) Top macro themes with conviction levels, " + "(3) A list of 8-10 specific stocks worth investigating with ticker, name, sector, rationale, " + "thesis_angle (growth/value/catalyst/turnaround/defensive/momentum), conviction (high/medium/low), " + "key_catalysts, and risks. " + "Output your response as valid JSON matching this schema:\n" + "{\n" + ' "timeframe": "1 month",\n' + ' "executive_summary": "...",\n' + ' "macro_context": { "economic_cycle": "...", "central_bank_stance": "...", "geopolitical_risks": [...] },\n' + ' "key_themes": [{ "theme": "...", "description": "...", "conviction": "high|medium|low", "timeframe": "..." }],\n' + ' "stocks_to_investigate": [{ "ticker": "...", "name": "...", "sector": "...", "rationale": "...", ' + '"thesis_angle": "...", "conviction": "high|medium|low", "key_catalysts": [...], "risks": [...] }],\n' + ' "risk_factors": ["..."]\n' + "}" + f"\n\n{all_reports_context}" + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." 
+ " You have access to the following tools: {tool_names}.\n{system_message}" + " For your reference, the current date is {current_date}.", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names="none") + prompt = prompt.partial(current_date=scan_date) + + chain = prompt | llm + result = chain.invoke(state["messages"]) + + report = result.content + + return { + "messages": [result], + "macro_scan_summary": report, + "sender": "macro_synthesis", + } + + return macro_synthesis_node diff --git a/tradingagents/agents/scanners/market_movers_scanner.py b/tradingagents/agents/scanners/market_movers_scanner.py new file mode 100644 index 00000000..219a5adf --- /dev/null +++ b/tradingagents/agents/scanners/market_movers_scanner.py @@ -0,0 +1,54 @@ +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.utils.agent_utils import get_market_movers, get_market_indices +from tradingagents.agents.utils.tool_runner import run_tool_loop + + +def create_market_movers_scanner(llm): + def market_movers_scanner_node(state): + scan_date = state["scan_date"] + + tools = [get_market_movers, get_market_indices] + + system_message = ( + "You are a market analyst scanning for unusual activity and momentum signals. " + "Use get_market_movers to fetch today's top gainers, losers, and most active stocks. " + "Use get_market_indices to check major index performance. " + "Analyze the results and write a report covering: " + "(1) Unusual movers and potential catalysts, " + "(2) Volume anomalies, " + "(3) Index trends and breadth, " + "(4) Sector concentration in movers. " + "Include a summary table of the most significant moves." + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." 
+ " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." + " You have access to the following tools: {tool_names}.\n{system_message}" + " For your reference, the current date is {current_date}.", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) + prompt = prompt.partial(current_date=scan_date) + + chain = prompt | llm.bind_tools(tools) + result = run_tool_loop(chain, state["messages"], tools) + + report = result.content or "" + + return { + "messages": [result], + "market_movers_report": report, + "sender": "market_movers_scanner", + } + + return market_movers_scanner_node diff --git a/tradingagents/agents/scanners/sector_scanner.py b/tradingagents/agents/scanners/sector_scanner.py new file mode 100644 index 00000000..f66782af --- /dev/null +++ b/tradingagents/agents/scanners/sector_scanner.py @@ -0,0 +1,53 @@ +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.utils.agent_utils import get_sector_performance +from tradingagents.agents.utils.tool_runner import run_tool_loop + + +def create_sector_scanner(llm): + def sector_scanner_node(state): + scan_date = state["scan_date"] + + tools = [get_sector_performance] + + system_message = ( + "You are a sector rotation analyst. " + "Use get_sector_performance to analyze all 11 GICS sectors. " + "Write a report covering: " + "(1) Sector momentum rankings (1-day, 1-week, 1-month, YTD), " + "(2) Sector rotation signals (money flowing from/to which sectors), " + "(3) Defensive vs cyclical positioning, " + "(4) Sectors showing acceleration or deceleration. " + "Include a ranked performance table." 
+ ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." + " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." + " You have access to the following tools: {tool_names}.\n{system_message}" + " For your reference, the current date is {current_date}.", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) + prompt = prompt.partial(current_date=scan_date) + + chain = prompt | llm.bind_tools(tools) + result = run_tool_loop(chain, state["messages"], tools) + + report = result.content or "" + + return { + "messages": [result], + "sector_performance_report": report, + "sender": "sector_scanner", + } + + return sector_scanner_node diff --git a/tradingagents/agents/utils/scanner_states.py b/tradingagents/agents/utils/scanner_states.py index 9d9e3c9c..07795a6a 100644 --- a/tradingagents/agents/utils/scanner_states.py +++ b/tradingagents/agents/utils/scanner_states.py @@ -1,47 +1,42 @@ """State definitions for the Global Macro Scanner graph.""" +import operator from typing import Annotated from langgraph.graph import MessagesState +def _last_value(existing: str, new: str) -> str: + """Reducer that keeps the last written value (for concurrent writes).""" + return new + + class ScannerState(MessagesState): """ State for the macro scanner workflow. 
- + The scanner discovers interesting stocks through multiple phases: - Phase 1: Parallel scanners (geopolitical, market movers, sectors) - Phase 2: Industry deep dive (cross-references phase 1 outputs) - Phase 3: Macro synthesis (produces final top-10 watchlist) + + Fields written by parallel nodes use _last_value reducer to allow + concurrent updates without LangGraph raising INVALID_CONCURRENT_GRAPH_UPDATE. + Each parallel node writes to its own dedicated field, so no data is lost. """ - + # Input - scan_date: Annotated[str, "Date of the scan in YYYY-MM-DD format"] - - # Phase 1: Parallel scanner outputs - geopolitical_report: Annotated[ - str, - "Report from Geopolitical Scanner analyzing global news, geopolitical events, and macro trends" - ] - market_movers_report: Annotated[ - str, - "Report from Market Movers Scanner analyzing top gainers, losers, most active stocks, and index performance" - ] - sector_performance_report: Annotated[ - str, - "Report from Sector Scanner analyzing all 11 GICS sectors performance and trends" - ] - + scan_date: str + + # Phase 1: Parallel scanner outputs — each written by exactly one node + geopolitical_report: Annotated[str, _last_value] + market_movers_report: Annotated[str, _last_value] + sector_performance_report: Annotated[str, _last_value] + # Phase 2: Deep dive output - industry_deep_dive_report: Annotated[ - str, - "Report from Industry Deep Dive agent analyzing specific industries within top performing sectors" - ] - + industry_deep_dive_report: Annotated[str, _last_value] + # Phase 3: Final output - macro_scan_summary: Annotated[ - str, - "Final macro scan summary with top-10 stock watchlist and market overview" - ] - - # Optional: Sender tracking (for debugging/logging) - sender: Annotated[str, "Agent that sent the current message"] = "" + macro_scan_summary: Annotated[str, _last_value] + + # Sender tracking — written by every node, needs reducer for parallel writes + sender: Annotated[str, _last_value] diff --git 
a/tradingagents/agents/utils/tool_runner.py b/tradingagents/agents/utils/tool_runner.py new file mode 100644 index 00000000..3c07d5a4 --- /dev/null +++ b/tradingagents/agents/utils/tool_runner.py @@ -0,0 +1,67 @@ +"""Utility for running an LLM tool-calling loop within a single graph node. + +The existing trading-graph agents rely on separate ToolNode graph nodes for +tool execution. Scanner agents are simpler — they run in a single node per +phase — so they need an inline tool-execution loop. +""" + +from __future__ import annotations + +from typing import Any, List + +from langchain_core.messages import AIMessage, ToolMessage + + +# Most LLM tool-calling patterns resolve within 2-3 rounds; +# 5 provides headroom for complex scenarios while preventing runaway loops. +MAX_TOOL_ROUNDS = 5 + + +def run_tool_loop( + chain, + messages: List[Any], + tools: List[Any], + max_rounds: int = MAX_TOOL_ROUNDS, +) -> AIMessage: + """Invoke *chain* in a loop, executing any tool calls until the LLM + produces a final text response (i.e. no more tool_calls). + + Args: + chain: A LangChain runnable (prompt | llm.bind_tools(tools)). + messages: The initial list of messages to send. + tools: List of LangChain tool objects (must match the tools bound to the LLM). + max_rounds: Maximum number of tool-calling rounds before forcing a stop. + + Returns: + The final AIMessage with a text ``content`` (report). 
+ """ + tool_map = {t.name: t for t in tools} + current_messages = list(messages) + + for _ in range(max_rounds): + result: AIMessage = chain.invoke(current_messages) + current_messages.append(result) + + if not result.tool_calls: + return result + + # Execute each requested tool call and append ToolMessages + for tc in result.tool_calls: + tool_name = tc["name"] + tool_args = tc["args"] + tool_fn = tool_map.get(tool_name) + if tool_fn is None: + tool_output = f"Error: unknown tool '{tool_name}'" + else: + try: + tool_output = tool_fn.invoke(tool_args) + except Exception as e: + tool_output = f"Error calling {tool_name}: {e}" + + current_messages.append( + ToolMessage(content=str(tool_output), tool_call_id=tc["id"]) + ) + + # If we exhausted max_rounds, return the last AIMessage + # (it may have tool_calls but we treat the content as the report) + return result diff --git a/tradingagents/dataflows/alpha_vantage_common.py b/tradingagents/dataflows/alpha_vantage_common.py index 409ff29e..5aaa27a5 100644 --- a/tradingagents/dataflows/alpha_vantage_common.py +++ b/tradingagents/dataflows/alpha_vantage_common.py @@ -2,6 +2,8 @@ import os import requests import pandas as pd import json +import threading +import time as _time from datetime import datetime from io import StringIO @@ -35,47 +37,170 @@ def format_datetime_for_api(date_input) -> str: else: raise ValueError(f"Date must be string or datetime object, got {type(date_input)}") -class AlphaVantageRateLimitError(Exception): - """Exception raised when Alpha Vantage API rate limit is exceeded.""" +# ─── Exception hierarchy ───────────────────────────────────────────────────── + +class AlphaVantageError(Exception): + """Base exception for all Alpha Vantage API errors.""" pass -def _make_api_request(function_name: str, params: dict) -> dict | str: - """Helper function to make API requests and handle responses. 
- + +class APIKeyInvalidError(AlphaVantageError): + """Raised when the API key is invalid or missing (401-equivalent).""" + pass + + +class RateLimitError(AlphaVantageError): + """Raised when the API rate limit is exceeded (429-equivalent).""" + pass + + +# Keep old name as alias so existing imports don't break +AlphaVantageRateLimitError = RateLimitError + + +class ThirdPartyError(AlphaVantageError): + """Raised on server-side errors (5xx status codes).""" + pass + + +class ThirdPartyTimeoutError(AlphaVantageError): + """Raised when the request times out.""" + pass + + +class ThirdPartyParseError(AlphaVantageError): + """Raised when the response cannot be parsed (malformed JSON/CSV).""" + pass + + +# ─── Rate-limited request helper ───────────────────────────────────────────── + + +_rate_lock = threading.Lock() +_call_timestamps: list[float] = [] +_RATE_LIMIT = 75 # calls per minute (Alpha Vantage premium) + + +def _rate_limited_request(function_name: str, params: dict, timeout: int = 30) -> dict | str: + """Make an API request with rate limiting (75 calls/min for premium key).""" + sleep_time = 0.0 + with _rate_lock: + now = _time.time() + # Remove timestamps older than 60 seconds + _call_timestamps[:] = [t for t in _call_timestamps if now - t < 60] + if len(_call_timestamps) >= _RATE_LIMIT: + sleep_time = 60 - (now - _call_timestamps[0]) + 0.1 + + # Sleep outside the lock to avoid blocking other threads + if sleep_time > 0: + _time.sleep(sleep_time) + + # Re-check and register under lock to avoid races where multiple + # threads calculate similar sleep times and then all fire at once. 
+    while True:
+        with _rate_lock:
+            now = _time.time()
+            _call_timestamps[:] = [t for t in _call_timestamps if now - t < 60]
+            if len(_call_timestamps) >= _RATE_LIMIT:
+                # Another thread filled the window while we slept — wait again
+                extra_sleep = 60 - (now - _call_timestamps[0]) + 0.1
+            else:
+                _call_timestamps.append(_time.time())
+                break
+        # Sleep outside the lock to avoid blocking other threads
+        _time.sleep(extra_sleep)
+
+
+    return _make_api_request(function_name, params, timeout=timeout)
+
+
+# ─── Core API request ────────────────────────────────────────────────────────
+
+def _make_api_request(function_name: str, params: dict, timeout: int = 30) -> str:
+    """Make an Alpha Vantage API request with proper error handling.
+
+    Returns the response text (JSON string or CSV).
+
+    Raises:
-        AlphaVantageRateLimitError: When API rate limit is exceeded
+        APIKeyInvalidError: Invalid or missing API key.
+        RateLimitError: Rate limit exceeded.
+        ThirdPartyError: Server-side error (5xx).
+        ThirdPartyTimeoutError: Request timed out.
+        ThirdPartyParseError: Response could not be parsed.
""" - # Create a copy of params to avoid modifying the original api_params = params.copy() api_params.update({ "function": function_name, "apikey": get_api_key(), "source": "trading_agents", }) - - # Handle entitlement parameter if present in params or global variable + + # Handle entitlement parameter current_entitlement = globals().get('_current_entitlement') entitlement = api_params.get("entitlement") or current_entitlement - if entitlement: api_params["entitlement"] = entitlement - elif "entitlement" in api_params: - # Remove entitlement if it's None or empty + else: api_params.pop("entitlement", None) - - response = requests.get(API_BASE_URL, params=api_params) - response.raise_for_status() + + try: + response = requests.get(API_BASE_URL, params=api_params, timeout=timeout) + except requests.exceptions.Timeout: + raise ThirdPartyTimeoutError( + f"Request timed out: function={function_name}, params={params}" + ) + except requests.exceptions.ConnectionError as exc: + raise ThirdPartyError(f"Connection error: function={function_name}, error={exc}") + except requests.exceptions.RequestException as exc: + raise ThirdPartyError(f"Request failed: function={function_name}, error={exc}") + + # HTTP-level errors + if response.status_code == 401: + raise APIKeyInvalidError( + f"Invalid API key: status={response.status_code}, body={response.text[:200]}" + ) + if response.status_code == 429: + raise RateLimitError( + f"Rate limit exceeded: status={response.status_code}, body={response.text[:200]}" + ) + if response.status_code >= 500: + raise ThirdPartyError( + f"Server error: status={response.status_code}, function={function_name}, " + f"body={response.text[:200]}" + ) + try: + response.raise_for_status() + except requests.exceptions.HTTPError as exc: + raise ThirdPartyError( + f"HTTP error: status={response.status_code}, function={function_name}, " + f"body={response.text[:200]}" + ) from exc response_text = response.text - - # Check if response is JSON (error responses 
are typically JSON) + + # Check for AV-specific error patterns in JSON body try: response_json = json.loads(response_text) - # Check for rate limit error + + if "Error Message" in response_json: + msg = response_json["Error Message"] + if "invalid" in msg.lower() and "apikey" in msg.lower(): + raise APIKeyInvalidError(f"Alpha Vantage: {msg}") + raise AlphaVantageError(f"Alpha Vantage API error: {msg}") + if "Information" in response_json: - info_message = response_json["Information"] - if "rate limit" in info_message.lower() or "api key" in info_message.lower(): - raise AlphaVantageRateLimitError(f"Alpha Vantage rate limit exceeded: {info_message}") + info = response_json["Information"] + info_lower = info.lower() + if "rate limit" in info_lower or "call frequency" in info_lower: + raise RateLimitError(f"Alpha Vantage rate limit: {info}") + if "invalid" in info_lower and "api" in info_lower: + raise APIKeyInvalidError(f"Alpha Vantage: {info}") + + if "Note" in response_json: + note = response_json["Note"] + if "api call frequency" in note.lower() or "rate limit" in note.lower(): + raise RateLimitError(f"Alpha Vantage rate limit: {note}") + except json.JSONDecodeError: # Response is not JSON (likely CSV data), which is normal pass diff --git a/tradingagents/dataflows/alpha_vantage_scanner.py b/tradingagents/dataflows/alpha_vantage_scanner.py index 032dc863..63933bb3 100644 --- a/tradingagents/dataflows/alpha_vantage_scanner.py +++ b/tradingagents/dataflows/alpha_vantage_scanner.py @@ -1,94 +1,614 @@ -"""Alpha Vantage-based scanner data fetching (fallback for market movers only).""" +"""Alpha Vantage-based scanner data fetching for market-wide analysis.""" -from typing import Annotated -from datetime import datetime import json -from .alpha_vantage_common import _make_api_request +from datetime import datetime, date +from typing import Annotated +from .alpha_vantage_common import ( + _rate_limited_request, + AlphaVantageError, + RateLimitError, + 
ThirdPartyParseError, +) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_CATEGORY_KEY_MAP = { + "day_gainers": "top_gainers", + "day_losers": "top_losers", + "most_actives": "most_actively_traded", +} + +# ETF proxies for the 11 GICS sectors +_SECTOR_ETFS: dict[str, str] = { + "Technology": "XLK", + "Healthcare": "XLV", + "Financials": "XLF", + "Energy": "XLE", + "Consumer Discretionary": "XLY", + "Consumer Staples": "XLP", + "Industrials": "XLI", + "Materials": "XLB", + "Real Estate": "XLRE", + "Utilities": "XLU", + "Communication Services": "XLC", +} + +# Representative large-cap tickers per sector (normalized keys: lowercase + dashes) +_SECTOR_TICKERS: dict[str, list[str]] = { + "technology": ["AAPL", "MSFT", "NVDA", "GOOGL", "META", "AVGO", "ADBE", "CRM", "AMD", "INTC"], + "healthcare": ["UNH", "JNJ", "LLY", "PFE", "ABT", "MRK", "TMO", "ABBV", "DHR", "AMGN"], + "financials": ["JPM", "BAC", "WFC", "GS", "MS", "BLK", "SCHW", "AXP", "C", "USB"], + "energy": ["XOM", "CVX", "COP", "SLB", "EOG", "MPC", "PSX", "VLO", "OXY", "HES"], + "consumer-discretionary": ["AMZN", "TSLA", "HD", "MCD", "NKE", "SBUX", "LOW", "TJX", "BKNG", "CMG"], + "consumer-staples": ["PG", "KO", "PEP", "COST", "WMT", "PM", "MDLZ", "CL", "KHC", "GIS"], + "industrials": ["CAT", "HON", "UNP", "UPS", "BA", "RTX", "DE", "LMT", "GE", "MMM"], + "materials": ["LIN", "APD", "SHW", "ECL", "FCX", "NEM", "NUE", "DOW", "DD", "PPG"], + "real-estate": ["PLD", "AMT", "CCI", "EQIX", "SPG", "PSA", "O", "WELL", "DLR", "AVB"], + "utilities": ["NEE", "DUK", "SO", "D", "AEP", "SRE", "EXC", "XEL", "WEC", "ED"], + "communication-services": ["META", "GOOGL", "NFLX", "DIS", "CMCSA", "T", "VZ", "CHTR", "TMUS", "EA"], +} + +_TOPIC_MAP: dict[str, str] = { + "market": "financial_markets", + "technology": "technology", + "tech": "technology", + "finance": "finance", + "financial": "finance", + 
"earnings": "earnings", + "ipo": "ipo", + "mergers": "mergers_and_acquisitions", + "m&a": "mergers_and_acquisitions", + "economy": "economy_macro", + "macro": "economy_macro", + "energy": "energy_transportation", + "real estate": "real_estate", + "realestate": "real_estate", + "healthcare": "life_sciences", + "pharma": "life_sciences", + "manufacturing": "manufacturing", + "crypto": "blockchain", + "blockchain": "blockchain", + "retail": "retail_wholesale", + "fiscal": "economy_fiscal", + "monetary": "economy_monetary", +} + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _parse_json(text: str, context: str) -> dict: + """Parse a JSON string, raising ThirdPartyParseError on failure. -def get_market_movers_alpha_vantage( - category: Annotated[str, "Category: 'day_gainers', 'day_losers', or 'most_actives'"] -) -> str: - """ - Get market movers using Alpha Vantage TOP_GAINERS_LOSERS endpoint (fallback). - Args: - category: One of 'day_gainers', 'day_losers', or 'most_actives' - + text: Raw response text from the API. + context: Human-readable label for error messages (e.g. function + symbol). + Returns: - Formatted string containing top market movers + Parsed JSON as a dict. + + Raises: + ThirdPartyParseError: When the text is not valid JSON. """ try: - # Alpha Vantage only supports top_gainers_losers endpoint - # It doesn't have 'most_actives' directly - if category not in ['day_gainers', 'day_losers', 'most_actives']: - return f"Invalid category '{category}'. Must be one of: day_gainers, day_losers, most_actives" - - if category == 'most_actives': - return "Alpha Vantage does not support 'most_actives' category. Please use yfinance instead." 
- - # Make API request for TOP_GAINERS_LOSERS endpoint - response = _make_api_request("TOP_GAINERS_LOSERS", {}) - if isinstance(response, dict): - data = response - else: - data = json.loads(response) - - if "Error Message" in data: - return f"Error from Alpha Vantage: {data['Error Message']}" - - if "Note" in data: - return f"Alpha Vantage API limit reached: {data['Note']}" - - # Map category to Alpha Vantage response key - if category == 'day_gainers': - key = 'top_gainers' - elif category == 'day_losers': - key = 'top_losers' - else: - return f"Unsupported category: {category}" - - if key not in data: - return f"No data found for {category}" - - movers = data[key] - - if not movers: - return f"No movers found for {category}" - - # Format the output - header = f"# Market Movers: {category.replace('_', ' ').title()} (Alpha Vantage)\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - - result_str = header - result_str += "| Symbol | Price | Change % | Volume |\n" - result_str += "|--------|-------|----------|--------|\n" - - for mover in movers[:15]: # Top 15 - symbol = mover.get('ticker', 'N/A') - price = mover.get('price', 'N/A') - change_pct = mover.get('change_percentage', 'N/A') - volume = mover.get('volume', 'N/A') - - # Format numbers - if isinstance(price, str): - try: - price = f"${float(price):.2f}" - except ValueError: - price = "N/A" - if isinstance(change_pct, str): - change_pct = change_pct.rstrip('%') # Remove % if present - if isinstance(change_pct, (int, float)): - change_pct = f"{float(change_pct):.2f}%" - if isinstance(volume, (int, str)): - try: - volume = f"{int(volume):,}" - except ValueError: - volume = "N/A" - - result_str += f"| {symbol} | {price} | {change_pct} | {volume} |\n" - - return result_str - - except Exception as e: - return f"Error fetching market movers from Alpha Vantage for {category}: {str(e)}" + return json.loads(text) + except json.JSONDecodeError as exc: + raise ThirdPartyParseError( 
+ f"Failed to parse JSON response for {context}: {exc}" + ) from exc + + +def _fetch_global_quote(symbol: str) -> dict: + """Fetch a single GLOBAL_QUOTE entry for a symbol. + + Args: + symbol: Ticker symbol (e.g. "SPY"). + + Returns: + The inner "Global Quote" dict from the API response. + + Raises: + AlphaVantageError: On API-level errors. + ThirdPartyParseError: On malformed JSON. + KeyError: When the expected "Global Quote" key is absent. + """ + text = _rate_limited_request("GLOBAL_QUOTE", {"symbol": symbol}) + data = _parse_json(text, f"GLOBAL_QUOTE/{symbol}") + if "Global Quote" not in data: + raise AlphaVantageError( + f"GLOBAL_QUOTE response for {symbol} missing 'Global Quote' key. " + f"Keys present: {list(data.keys())}" + ) + return data["Global Quote"] + + +def _fetch_daily_closes(symbol: str) -> list[tuple[date, float]]: + """Fetch up to 100 days of daily close prices for a symbol. + + Args: + symbol: Ticker symbol (e.g. "XLK"). + + Returns: + List of (date, close_price) tuples, sorted ascending by date. + + Raises: + AlphaVantageError: On API-level errors or missing data key. + ThirdPartyParseError: On malformed JSON. + """ + text = _rate_limited_request( + "TIME_SERIES_DAILY", + {"symbol": symbol, "outputsize": "compact"}, + ) + data = _parse_json(text, f"TIME_SERIES_DAILY/{symbol}") + + ts_key = "Time Series (Daily)" + if ts_key not in data: + raise AlphaVantageError( + f"TIME_SERIES_DAILY response for {symbol} missing '{ts_key}' key. " + f"Keys present: {list(data.keys())}" + ) + + entries: list[tuple[date, float]] = [] + for date_str, ohlcv in data[ts_key].items(): + try: + close = float(ohlcv["4. 
close"]) + day = datetime.strptime(date_str, "%Y-%m-%d").date() + entries.append((day, close)) + except (KeyError, ValueError): + # Skip malformed individual entries rather than failing entirely + continue + + entries.sort(key=lambda x: x[0]) # ascending + return entries + + +def _pct_change(closes: list[tuple[date, float]], days_back: int) -> float | None: + """Compute percentage change from `days_back` trading days ago to today. + + Args: + closes: Ascending list of (date, close) pairs. + days_back: How many entries back to use as the base. + + Returns: + Percentage change as a float, or None when there is insufficient data. + """ + if len(closes) < days_back + 1: + return None + base = closes[-(days_back + 1)][1] + current = closes[-1][1] + if base == 0: + return None + return (current - base) / base * 100 + + +def _ytd_pct_change(closes: list[tuple[date, float]]) -> float | None: + """Compute year-to-date percentage change. + + Args: + closes: Ascending list of (date, close) pairs. + + Returns: + YTD percentage change, or None when the prior year-end close is not + available in the provided data. + """ + if not closes: + return None + + current_year = closes[-1][0].year + # Find the last close from the prior calendar year + prior_year_closes = [c for c in closes if c[0].year < current_year] + if not prior_year_closes: + return None + + base = prior_year_closes[-1][1] + current = closes[-1][1] + if base == 0: + return None + return (current - base) / base * 100 + + +def _fmt_pct(value: float | None) -> str: + """Format an optional float as a percentage string. + + Args: + value: The percentage value, or None. + + Returns: + String like "+1.23%" or "N/A". 
+ """ + if value is None: + return "N/A" + return f"{value:+.2f}%" + + +def _now_str() -> str: + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +# --------------------------------------------------------------------------- +# Public scanner functions +# --------------------------------------------------------------------------- + +def get_market_movers_alpha_vantage( + category: Annotated[str, "Category: 'day_gainers', 'day_losers', or 'most_actives'"], +) -> str: + """Get market movers using the Alpha Vantage TOP_GAINERS_LOSERS endpoint. + + Args: + category: One of 'day_gainers', 'day_losers', or 'most_actives'. + + Returns: + Markdown table of the top 15 movers with Symbol, Price, Change %, Volume. + + Raises: + ValueError: When an unsupported category is requested. + AlphaVantageError: On API-level errors. + ThirdPartyParseError: On malformed JSON. + """ + if category not in _CATEGORY_KEY_MAP: + raise ValueError( + f"Invalid category '{category}'. " + f"Must be one of: {list(_CATEGORY_KEY_MAP.keys())}" + ) + + text = _rate_limited_request("TOP_GAINERS_LOSERS", {}) + data = _parse_json(text, "TOP_GAINERS_LOSERS") + + response_key = _CATEGORY_KEY_MAP[category] + if response_key not in data: + raise AlphaVantageError( + f"TOP_GAINERS_LOSERS response missing expected key '{response_key}'. " + f"Keys present: {list(data.keys())}" + ) + + movers: list[dict] = data[response_key] + # A 200 response with an empty list is a valid (genuinely empty) market state + # — we report it as-is rather than raising. 
+ + header = ( + f"# Market Movers: {category.replace('_', ' ').title()} (Alpha Vantage)\n" + f"# Data retrieved on: {_now_str()}\n\n" + ) + result = header + result += "| Symbol | Price | Change % | Volume |\n" + result += "|--------|-------|----------|--------|\n" + + for mover in movers[:15]: + symbol = mover.get("ticker", "N/A") + + raw_price = mover.get("price", "N/A") + try: + price = f"${float(raw_price):.2f}" + except (ValueError, TypeError): + price = str(raw_price) + + raw_change = mover.get("change_percentage", "N/A") + # AV returns values like "3.45%" — normalise to a consistent display + try: + change_pct = f"{float(str(raw_change).rstrip('%')):.2f}%" + except (ValueError, TypeError): + change_pct = str(raw_change) + + raw_volume = mover.get("volume", "N/A") + try: + volume = f"{int(raw_volume):,}" + except (ValueError, TypeError): + volume = str(raw_volume) + + result += f"| {symbol} | {price} | {change_pct} | {volume} |\n" + + return result + + +def get_market_indices_alpha_vantage() -> str: + """Get major market index levels via ETF proxies and the VIX index. + + Uses GLOBAL_QUOTE for each proxy: SPY (S&P 500), DIA (Dow Jones), + QQQ (NASDAQ), IWM (Russell 2000), and VIX (CBOE Volatility Index). + + Returns: + Markdown table with Index, Price, Change, Change %. + + Raises: + AlphaVantageError: On API-level errors. + ThirdPartyParseError: On malformed JSON. + """ + # ETF proxies — keyed by display name + proxies: list[tuple[str, str]] = [ + ("S&P 500 (SPY)", "SPY"), + ("Dow Jones (DIA)", "DIA"), + ("NASDAQ (QQQ)", "QQQ"), + ("Russell 2000 (IWM)", "IWM"), + ] + + header = ( + f"# Major Market Indices (Alpha Vantage)\n" + f"# Data retrieved on: {_now_str()}\n\n" + ) + result = header + result += "| Index | Price | Change | Change % |\n" + result += "|-------|-------|--------|----------|\n" + + for display_name, symbol in proxies: + try: + quote = _fetch_global_quote(symbol) + price = quote.get("05. price", "N/A") + change = quote.get("09. 
change", "N/A") + change_pct = quote.get("10. change percent", "N/A") + + try: + price = f"${float(price):.2f}" + except (ValueError, TypeError): + pass + + try: + change = f"{float(change):+.2f}" + except (ValueError, TypeError): + pass + + # AV returns "change percent" as "1.23%" — keep as-is if it has the sign, + # otherwise add a + prefix for positive values. + change_pct = str(change_pct).strip() + + result += f"| {display_name} | {price} | {change} | {change_pct} |\n" + + except (AlphaVantageError, ThirdPartyParseError, RateLimitError) as exc: + result += f"| {display_name} | Error | - | {exc!s:.40} |\n" + + # VIX — try "VIX" first, fall back to "^VIX" + vix_symbol = None + vix_quote: dict | None = None + for candidate in ("VIX", "^VIX"): + try: + vix_quote = _fetch_global_quote(candidate) + vix_symbol = candidate + break + except (AlphaVantageError, ThirdPartyParseError, RateLimitError): + continue + + if vix_quote is not None: + price = vix_quote.get("05. price", "N/A") + change = vix_quote.get("09. change", "N/A") + change_pct = vix_quote.get("10. change percent", "N/A") + try: + price = f"{float(price):.2f}" + except (ValueError, TypeError): + pass + try: + change = f"{float(change):+.2f}" + except (ValueError, TypeError): + pass + result += f"| VIX ({vix_symbol}) | {price} | {change} | {change_pct} |\n" + else: + result += "| VIX | Unavailable | - | - |\n" + + return result + + +def get_sector_performance_alpha_vantage() -> str: + """Get daily and multi-period performance for the 11 GICS sectors via SPDR ETFs. + + Makes one GLOBAL_QUOTE call and one TIME_SERIES_DAILY call per ETF (22+ total). + Uses _rate_limited_request throughout to stay within the 75 calls/min limit. + + Returns: + Markdown table with Sector, 1-Day %, 1-Week %, 1-Month %, YTD %. + + Raises: + AlphaVantageError: On API-level errors. + ThirdPartyParseError: On malformed JSON. 
+ """ + header = ( + f"# Sector Performance Overview (Alpha Vantage)\n" + f"# Data retrieved on: {_now_str()}\n\n" + ) + result = header + result += "| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % |\n" + result += "|--------|---------|----------|-----------|-------|\n" + + success_count = 0 + last_error = None + + for sector_name, etf in _SECTOR_ETFS.items(): + try: + # Daily change from GLOBAL_QUOTE (most recent data) + quote = _fetch_global_quote(etf) + raw_day_pct = quote.get("10. change percent", "N/A") + try: + # AV returns "1.23%" — strip % and reformat with sign + day_pct_str = f"{float(str(raw_day_pct).rstrip('%')):+.2f}%" + except (ValueError, TypeError): + day_pct_str = str(raw_day_pct) + + # Multi-period returns from daily close series + closes = _fetch_daily_closes(etf) + week_pct_str = _fmt_pct(_pct_change(closes, 5)) + month_pct_str = _fmt_pct(_pct_change(closes, 21)) + ytd_pct_str = _fmt_pct(_ytd_pct_change(closes)) + success_count += 1 + + except (AlphaVantageError, ThirdPartyParseError, RateLimitError) as exc: + last_error = exc + day_pct_str = week_pct_str = month_pct_str = ytd_pct_str = ( + f"Error: {exc!s:.30}" + ) + + result += ( + f"| {sector_name} | {day_pct_str} | {week_pct_str} | " + f"{month_pct_str} | {ytd_pct_str} |\n" + ) + + # If ALL sectors failed, raise so route_to_vendor can fall back + if success_count == 0 and last_error is not None: + raise AlphaVantageError( + f"All {len(_SECTOR_ETFS)} sector queries failed. Last error: {last_error}" + ) + + return result + + +def get_industry_performance_alpha_vantage( + sector_key: Annotated[str, "Sector key (e.g., 'technology', 'healthcare')"], +) -> str: + """Get price and daily change % for representative tickers in a sector. + + Args: + sector_key: Sector identifier — case-insensitive, spaces converted to dashes + (e.g., 'Technology', 'consumer-discretionary'). + + Returns: + Markdown table with Symbol, Price, Change %, sorted by Change % descending. 
+ + Raises: + ValueError: When the normalised sector_key is not recognised. + AlphaVantageError: On API-level errors. + ThirdPartyParseError: On malformed JSON. + """ + normalised = sector_key.lower().replace(" ", "-") + if normalised not in _SECTOR_TICKERS: + raise ValueError( + f"Unknown sector '{sector_key}'. " + f"Valid keys: {list(_SECTOR_TICKERS.keys())}" + ) + + tickers = _SECTOR_TICKERS[normalised] + + rows: list[tuple[str, str, float | None, str]] = [] # (symbol, price_str, raw_change_float, change_str) + errors: list[str] = [] + + for symbol in tickers: + try: + quote = _fetch_global_quote(symbol) + raw_price = quote.get("05. price", "N/A") + raw_change = quote.get("10. change percent", "N/A") + + try: + price_str = f"${float(raw_price):.2f}" + except (ValueError, TypeError): + price_str = str(raw_price) + + try: + raw_change_float = float(str(raw_change).rstrip("%")) + change_str = f"{raw_change_float:+.2f}%" + except (ValueError, TypeError): + raw_change_float = None + change_str = str(raw_change) + + rows.append((symbol, price_str, raw_change_float, change_str)) + + except (AlphaVantageError, ThirdPartyParseError, RateLimitError) as exc: + errors.append(f"{symbol}: {exc!s:.60}") + + # Sort by change % descending; put rows without a numeric value last + rows.sort(key=lambda r: r[2] if r[2] is not None else float("-inf"), reverse=True) + + sector_title = normalised.replace("-", " ").title() + header = ( + f"# Industry Performance: {sector_title} (Alpha Vantage)\n" + f"# Data retrieved on: {_now_str()}\n\n" + ) + result = header + result += "| Symbol | Price | Change % |\n" + result += "|--------|-------|----------|\n" + + for symbol, price_str, _, change_str in rows: + result += f"| {symbol} | {price_str} | {change_str} |\n" + + # If ALL tickers failed, raise so route_to_vendor can fall back + if not rows and errors: + raise AlphaVantageError( + f"All {len(tickers)} ticker queries failed for sector '{sector_key}'. 
" + f"Last error: {errors[-1]}" + ) + + if errors: + result += "\n**Fetch errors:**\n" + for err in errors: + result += f"- {err}\n" + + return result + + +def get_topic_news_alpha_vantage( + topic: Annotated[str, "News topic (e.g., 'earnings', 'technology', 'market')"], + limit: Annotated[int, "Maximum number of articles to return"] = 10, +) -> str: + """Fetch topic-based news and sentiment via Alpha Vantage NEWS_SENTIMENT. + + Args: + topic: A topic string. Known topics are mapped to AV topic values; + unknown topics are passed through as-is. + limit: Maximum number of articles to return (default 10). + + Returns: + Markdown list of articles with title, summary, source, link, and + overall sentiment score. + + Raises: + AlphaVantageError: On API-level errors. + ThirdPartyParseError: On malformed JSON. + """ + av_topic = _TOPIC_MAP.get(topic.lower(), topic) + + params = { + "topics": av_topic, + "limit": str(limit), + "sort": "LATEST", + } + + text = _rate_limited_request("NEWS_SENTIMENT", params) + data = _parse_json(text, f"NEWS_SENTIMENT/{topic}") + + if "feed" not in data: + raise AlphaVantageError( + f"NEWS_SENTIMENT response missing 'feed' key for topic '{topic}'. 
" + f"Keys present: {list(data.keys())}" + ) + + articles: list[dict] = data["feed"] + + header = ( + f"# News for Topic: {topic} (Alpha Vantage)\n" + f"# Data retrieved on: {_now_str()}\n\n" + ) + result = header + + if not articles: + result += "_No articles found for this topic._\n" + return result + + for article in articles[:limit]: + title = article.get("title", "No title") + summary = article.get("summary", "") + source = article.get("source", "Unknown") + url = article.get("url", "") + sentiment_score = article.get("overall_sentiment_score") + published = article.get("time_published", "") + + # Format publication timestamp: "20240315T130000" → "2024-03-15 13:00" + if published and len(published) >= 13: + try: + dt = datetime.strptime(published[:15], "%Y%m%dT%H%M%S") + published = dt.strftime("%Y-%m-%d %H:%M") + except ValueError: + pass # keep raw value if unparseable + + sentiment_str = ( + f"{sentiment_score:.4f}" if isinstance(sentiment_score, float) else "N/A" + ) + + result += f"### {title}\n" + result += f"**Source:** {source}" + if published: + result += f" | **Published:** {published}" + result += f" | **Sentiment:** {sentiment_str}\n" + if summary: + result += f"{summary}\n" + if url: + result += f"**Link:** {url}\n" + result += "\n" + + return result diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 22e57a6e..adddb290 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -1,4 +1,3 @@ -import logging from typing import Annotated # Import from vendor-specific modules @@ -30,8 +29,14 @@ from .alpha_vantage import ( get_news as get_alpha_vantage_news, get_global_news as get_alpha_vantage_global_news, ) -from .alpha_vantage_scanner import get_market_movers_alpha_vantage -from .alpha_vantage_common import AlphaVantageRateLimitError +from .alpha_vantage_scanner import ( + get_market_movers_alpha_vantage, + get_market_indices_alpha_vantage, + 
get_sector_performance_alpha_vantage, + get_industry_performance_alpha_vantage, + get_topic_news_alpha_vantage, +) +from .alpha_vantage_common import AlphaVantageError, AlphaVantageRateLimitError, RateLimitError # Configuration and routing logic from .config import get_config @@ -132,15 +137,19 @@ VENDOR_METHODS = { "alpha_vantage": get_market_movers_alpha_vantage, }, "get_market_indices": { + "alpha_vantage": get_market_indices_alpha_vantage, "yfinance": get_market_indices_yfinance, }, "get_sector_performance": { + "alpha_vantage": get_sector_performance_alpha_vantage, "yfinance": get_sector_performance_yfinance, }, "get_industry_performance": { + "alpha_vantage": get_industry_performance_alpha_vantage, "yfinance": get_industry_performance_yfinance, }, "get_topic_news": { + "alpha_vantage": get_topic_news_alpha_vantage, "yfinance": get_topic_news_yfinance, }, } @@ -192,8 +201,7 @@ def route_to_vendor(method: str, *args, **kwargs): try: return impl_func(*args, **kwargs) - except (AlphaVantageRateLimitError, ConnectionError, TimeoutError) as e: - logging.warning(f"Vendor '{vendor}' failed for '{method}': {e}, trying next...") - continue + except (AlphaVantageError, ConnectionError, TimeoutError): + continue # Any AV error or connection/timeout triggers fallback to next vendor raise RuntimeError(f"No available vendor for '{method}'") \ No newline at end of file diff --git a/tradingagents/dataflows/yfinance_scanner.py b/tradingagents/dataflows/yfinance_scanner.py index 34f54d41..d4649ab8 100644 --- a/tradingagents/dataflows/yfinance_scanner.py +++ b/tradingagents/dataflows/yfinance_scanner.py @@ -10,54 +10,52 @@ def get_market_movers_yfinance( ) -> str: """ Get market movers using yfinance Screener. 
- + Args: category: One of 'day_gainers', 'day_losers', or 'most_actives' - + Returns: Formatted string containing top market movers """ try: + # Map category to yfinance screener predefined screener screener_keys = { - "day_gainers": "day_gainers", - "day_losers": "day_losers", - "most_actives": "most_actives" + "day_gainers": "DAY_GAINERS", + "day_losers": "DAY_LOSERS", + "most_actives": "MOST_ACTIVES" } - + if category not in screener_keys: return f"Invalid category '{category}'. Must be one of: {list(screener_keys.keys())}" - - screener = yf.Screener() - data = screener.get_screeners([screener_keys[category]], count=25) - - if not data or screener_keys[category] not in data: + + # Use yfinance screener module's screen function + data = yf.screener.screen(screener_keys[category], count=25) + + if not data or not isinstance(data, dict) or 'quotes' not in data: return f"No data found for {category}" - - movers = data[screener_keys[category]] - - if not movers or 'quotes' not in movers: - return f"No movers found for {category}" - - quotes = movers['quotes'] - + + quotes = data['quotes'] + if not quotes: return f"No quotes found for {category}" - + + # Format the output header = f"# Market Movers: {category.replace('_', ' ').title()}\n" header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - + result_str = header result_str += "| Symbol | Name | Price | Change % | Volume | Market Cap |\n" result_str += "|--------|------|-------|----------|--------|------------|\n" - - for quote in quotes[:15]: + + for quote in quotes[:15]: # Top 15 symbol = quote.get('symbol', 'N/A') name = quote.get('shortName', quote.get('longName', 'N/A')) price = quote.get('regularMarketPrice', 'N/A') change_pct = quote.get('regularMarketChangePercent', 'N/A') volume = quote.get('regularMarketVolume', 'N/A') market_cap = quote.get('marketCap', 'N/A') - + + # Format numbers if isinstance(price, (int, float)): price = f"${price:.2f}" if isinstance(change_pct, (int, 
float)): @@ -66,11 +64,11 @@ def get_market_movers_yfinance( volume = f"{volume:,.0f}" if isinstance(market_cap, (int, float)): market_cap = f"${market_cap:,.0f}" - + result_str += f"| {symbol} | {name[:30]} | {price} | {change_pct} | {volume} | {market_cap} |\n" - + return result_str - + except Exception as e: return f"Error fetching market movers for {category}: {str(e)}" @@ -78,11 +76,12 @@ def get_market_movers_yfinance( def get_market_indices_yfinance() -> str: """ Get major market indices data. - + Returns: Formatted string containing index values and daily changes """ try: + # Major market indices indices = { "^GSPC": "S&P 500", "^DJI": "Dow Jones", @@ -90,120 +89,143 @@ def get_market_indices_yfinance() -> str: "^VIX": "VIX (Volatility Index)", "^RUT": "Russell 2000" } - - header = "# Major Market Indices\n" + + header = f"# Major Market Indices\n" header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - + result_str = header result_str += "| Index | Current Price | Change | Change % | 52W High | 52W Low |\n" result_str += "|-------|---------------|--------|----------|----------|----------|\n" - - # Batch download historical price data to avoid N+1 calls. - # yf.download() always returns multi-level columns when multiple symbols - # are requested (group_by="ticker"), so we access hist_batch[symbol]. + + # Batch-download 1-day history for all symbols in a single request symbols = list(indices.keys()) - hist_batch = yf.download( - symbols, - period="2d", - group_by="ticker", - progress=False, - auto_adjust=True, - ) + indices_history = yf.download(symbols, period="2d", auto_adjust=True, progress=False, threads=True) for symbol, name in indices.items(): try: ticker = yf.Ticker(symbol) - info = ticker.info + # fast_info is a lightweight cached property (no extra HTTP call) + fast = ticker.fast_info - # Extract per-symbol slice from the batched result. 
- # With multiple symbols and group_by="ticker", the columns are - # a MultiIndex keyed by symbol. + # Extract history for this symbol from the batch download try: - hist = hist_batch[symbol].dropna() + if len(symbols) > 1: + closes = indices_history["Close"][symbol].dropna() + else: + closes = indices_history["Close"].dropna() except KeyError: - hist = ticker.history(period="1d") + closes = None - if hist.empty: - result_str += f"| {name} | No data | - | - | - | - |\n" + if closes is None or len(closes) == 0: + result_str += f"| {name} | N/A | - | - | - | - |\n" continue - current_price = hist['Close'].iloc[-1] - prev_close = info.get('previousClose', current_price) + current_price = closes.iloc[-1] + prev_close = closes.iloc[-2] if len(closes) >= 2 else fast.previous_close + if prev_close is None or prev_close == 0: + prev_close = current_price + change = current_price - prev_close change_pct = (change / prev_close * 100) if prev_close else 0 - high_52w = info.get('fiftyTwoWeekHigh', 'N/A') - low_52w = info.get('fiftyTwoWeekLow', 'N/A') + high_52w = fast.year_high + low_52w = fast.year_low + # Format numbers current_str = f"{current_price:.2f}" change_str = f"{change:+.2f}" change_pct_str = f"{change_pct:+.2f}%" high_str = f"{high_52w:.2f}" if isinstance(high_52w, (int, float)) else str(high_52w) low_str = f"{low_52w:.2f}" if isinstance(low_52w, (int, float)) else str(low_52w) - + result_str += f"| {name} | {current_str} | {change_str} | {change_pct_str} | {high_str} | {low_str} |\n" - + except Exception as e: - result_str += f"| {name} | Error: {str(e)[:40]} | - | - | - | - |\n" - + result_str += f"| {name} | Error: {str(e)} | - | - | - | - |\n" + return result_str - + except Exception as e: return f"Error fetching market indices: {str(e)}" def get_sector_performance_yfinance() -> str: """ - Get sector-level performance overview using yfinance Sector data. + Get sector-level performance overview using SPDR sector ETFs. 
+ + yfinance Sector.overview lacks performance data, so we use + sector ETFs (XLK, XLV, etc.) with yf.download() to compute + 1-day, 1-week, 1-month, and YTD returns. Returns: Formatted string containing sector performance data """ - try: - sector_keys = [ - "communication-services", - "consumer-cyclical", - "consumer-defensive", - "energy", - "financial-services", - "healthcare", - "industrials", - "basic-materials", - "real-estate", - "technology", - "utilities" - ] + # Map GICS sectors to SPDR ETF tickers + sector_etfs = { + "Technology": "XLK", + "Healthcare": "XLV", + "Financials": "XLF", + "Energy": "XLE", + "Consumer Discretionary": "XLY", + "Consumer Staples": "XLP", + "Industrials": "XLI", + "Materials": "XLB", + "Real Estate": "XLRE", + "Utilities": "XLU", + "Communication Services": "XLC", + } - header = "# Sector Performance Overview\n" + try: + symbols = list(sector_etfs.values()) + # Download ~6 months of data to cover YTD, 1-month, 1-week + hist = yf.download(symbols, period="6mo", auto_adjust=True, progress=False, threads=True) + + header = f"# Sector Performance Overview\n" header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" result_str = header result_str += "| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % |\n" result_str += "|--------|---------|----------|-----------|-------|\n" - for sector_key in sector_keys: + for sector_name, etf in sector_etfs.items(): try: - sector = yf.Sector(sector_key) - overview = sector.overview + # Extract close prices for this ETF + if len(symbols) > 1: + closes = hist["Close"][etf].dropna() + else: + closes = hist["Close"].dropna() - if overview is None or overview.empty: + if closes.empty or len(closes) < 2: + result_str += f"| {sector_name} | N/A | N/A | N/A | N/A |\n" continue - sector_name = sector_key.replace("-", " ").title() - day_return = overview.get('oneDay', {}).get('percentChange', 'N/A') - week_return = overview.get('oneWeek', {}).get('percentChange', 'N/A') - 
month_return = overview.get('oneMonth', {}).get('percentChange', 'N/A') - ytd_return = overview.get('ytd', {}).get('percentChange', 'N/A') + current = closes.iloc[-1] + prev = closes.iloc[-2] - day_str = f"{day_return:.2f}%" if isinstance(day_return, (int, float)) else str(day_return) - week_str = f"{week_return:.2f}%" if isinstance(week_return, (int, float)) else str(week_return) - month_str = f"{month_return:.2f}%" if isinstance(month_return, (int, float)) else str(month_return) - ytd_str = f"{ytd_return:.2f}%" if isinstance(ytd_return, (int, float)) else str(ytd_return) + # 1-day + day_pct = (current - prev) / prev * 100 if prev else 0 + + # 1-week (~5 trading days) + week_pct = _safe_pct(closes, 5) + # 1-month (~21 trading days) + month_pct = _safe_pct(closes, 21) + # YTD: first close of current year vs now + current_year = closes.index[-1].year + year_closes = closes[closes.index.year == current_year] + if len(year_closes) > 0 and year_closes.iloc[0] != 0: + ytd_pct = (current - year_closes.iloc[0]) / year_closes.iloc[0] * 100 + else: + ytd_pct = None + + day_str = f"{day_pct:+.2f}%" + week_str = f"{week_pct:+.2f}%" if week_pct is not None else "N/A" + month_str = f"{month_pct:+.2f}%" if month_pct is not None else "N/A" + ytd_str = f"{ytd_pct:+.2f}%" if ytd_pct is not None else "N/A" result_str += f"| {sector_name} | {day_str} | {week_str} | {month_str} | {ytd_str} |\n" except Exception as e: - result_str += f"| {sector_key.replace('-', ' ').title()} | Error: {str(e)[:20]} | - | - | - |\n" + result_str += f"| {sector_name} | Error: {str(e)[:30]} | - | - | - |\n" return result_str @@ -211,53 +233,60 @@ def get_sector_performance_yfinance() -> str: return f"Error fetching sector performance: {str(e)}" +def _safe_pct(closes, days_back: int) -> float | None: + """Compute percentage change from days_back trading days ago.""" + if len(closes) < days_back + 1: + return None + base = closes.iloc[-(days_back + 1)] + current = closes.iloc[-1] + if base == 0: + return 
None + return (current - base) / base * 100 + + def get_industry_performance_yfinance( sector_key: Annotated[str, "Sector key (e.g., 'technology', 'healthcare')"] ) -> str: """ Get industry-level drill-down within a sector. - + Args: sector_key: Sector identifier (e.g., 'technology', 'healthcare') - + Returns: Formatted string containing industry performance data within the sector """ try: + # Normalize sector key to yfinance format sector_key = sector_key.lower().replace(" ", "-") - + sector = yf.Sector(sector_key) top_companies = sector.top_companies - + if top_companies is None or top_companies.empty: return f"No industry data found for sector '{sector_key}'" - + header = f"# Industry Performance: {sector_key.replace('-', ' ').title()}\n" header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - + result_str = header - result_str += "| Company | Symbol | Industry | Market Cap | Change % |\n" - result_str += "|---------|--------|----------|------------|----------|\n" - - for idx, row in top_companies.head(20).iterrows(): - symbol = row.get('symbol', 'N/A') + result_str += "| Company | Symbol | Rating | Market Weight |\n" + result_str += "|---------|--------|--------|---------------|\n" + + # top_companies has ticker as the DataFrame index (index.name == 'symbol') + # Columns: name, rating, market weight + for symbol, row in top_companies.head(20).iterrows(): name = row.get('name', 'N/A') - industry = row.get('industry', 'N/A') - market_cap = row.get('marketCap', 'N/A') - change_pct = row.get('regularMarketChangePercent', 'N/A') + rating = row.get('rating', 'N/A') + market_weight = row.get('market weight', None) - if isinstance(market_cap, (int, float)): - market_cap = f"${market_cap:,.0f}" - if isinstance(change_pct, (int, float)): - change_pct = f"{change_pct:.2f}%" - - name_short = name[:30] if isinstance(name, str) else name - industry_short = industry[:25] if isinstance(industry, str) else industry - - result_str += f"| 
{name_short} | {symbol} | {industry_short} | {market_cap} | {change_pct} |\n" + name_short = name[:30] if isinstance(name, str) else str(name) + weight_str = f"{market_weight:.2%}" if isinstance(market_weight, (int, float)) else "N/A" + result_str += f"| {name_short} | {symbol} | {rating} | {weight_str} |\n" + return result_str - + except Exception as e: return f"Error fetching industry performance for sector '{sector_key}': {str(e)}" @@ -268,11 +297,11 @@ def get_topic_news_yfinance( ) -> str: """ Search news by arbitrary topic using yfinance Search. - + Args: topic: Search query/topic limit: Maximum number of articles to return - + Returns: Formatted string containing news articles for the topic """ @@ -282,23 +311,25 @@ def get_topic_news_yfinance( news_count=limit, enable_fuzzy_query=True, ) - + if not search.news: return f"No news found for topic '{topic}'" - + header = f"# News for Topic: {topic}\n" header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - + result_str = header - + for article in search.news[:limit]: + # Handle nested content structure if "content" in article: content = article["content"] title = content.get("title", "No title") summary = content.get("summary", "") provider = content.get("provider", {}) publisher = provider.get("displayName", "Unknown") - + + # Get URL url_obj = content.get("canonicalUrl") or content.get("clickThroughUrl") or {} link = url_obj.get("url", "") else: @@ -306,16 +337,15 @@ def get_topic_news_yfinance( summary = article.get("summary", "") publisher = article.get("publisher", "Unknown") link = article.get("link", "") - + result_str += f"### {title} (source: {publisher})\n" if summary: result_str += f"{summary}\n" if link: result_str += f"Link: {link}\n" result_str += "\n" - + return result_str - + except Exception as e: return f"Error fetching news for topic '{topic}': {str(e)}" - diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 7e24e801..e42787b1 
100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -1,47 +1,83 @@ import os +from pathlib import Path + +from dotenv import load_dotenv + +# Load .env so that TRADINGAGENTS_* variables are available before +# DEFAULT_CONFIG is evaluated. CWD is checked first, then the project +# root (two levels up from this file). load_dotenv never overwrites +# variables that are already present in the environment. +load_dotenv() +load_dotenv(Path(__file__).resolve().parent.parent / ".env") + + +def _env(key: str, default=None): + """Read ``TRADINGAGENTS_`` from the environment. + + Returns *default* when the variable is unset **or** empty, so that + ``TRADINGAGENTS_MID_THINK_LLM=`` in a ``.env`` file is treated the + same as not setting it at all (preserving the ``None`` semantics for + "fall back to the parent setting"). + """ + val = os.getenv(f"TRADINGAGENTS_{key.upper()}") + if not val: # None or "" + return default + return val + + +def _env_int(key: str, default=None): + """Like :func:`_env` but coerces the value to ``int``.""" + val = _env(key) + if val is None: + return default + try: + return int(val) + except (ValueError, TypeError): + return default + DEFAULT_CONFIG = { "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), - "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), + "results_dir": _env("RESULTS_DIR", "./results"), "data_cache_dir": os.path.join( os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "dataflows/data_cache", ), - # LLM settings - "llm_provider": "openai", - "deep_think_llm": "gpt-5.2", - "mid_think_llm": None, # falls back to quick_think_llm when None - "quick_think_llm": "gpt-5-mini", - "backend_url": "https://api.openai.com/v1", + # LLM settings — all overridable via TRADINGAGENTS_ env vars + "llm_provider": _env("LLM_PROVIDER", "openai"), + "deep_think_llm": _env("DEEP_THINK_LLM", "gpt-5.2"), + "mid_think_llm": _env("MID_THINK_LLM"), # falls back to 
quick_think_llm when None + "quick_think_llm": _env("QUICK_THINK_LLM", "gpt-5-mini"), + "backend_url": _env("BACKEND_URL", "https://api.openai.com/v1"), # Per-role provider overrides (fall back to llm_provider / backend_url when None) - "deep_think_llm_provider": None, # e.g. "google", "anthropic", "openai" - "deep_think_backend_url": None, # override backend URL for deep-think model - "mid_think_llm_provider": None, # e.g. "ollama" - "mid_think_backend_url": None, # override backend URL for mid-think model - "quick_think_llm_provider": None, # e.g. "openai", "ollama" - "quick_think_backend_url": None, # override backend URL for quick-think model + "deep_think_llm_provider": _env("DEEP_THINK_LLM_PROVIDER"), # e.g. "google", "anthropic", "openrouter" + "deep_think_backend_url": _env("DEEP_THINK_BACKEND_URL"), # override backend URL for deep-think model + "mid_think_llm_provider": _env("MID_THINK_LLM_PROVIDER"), # e.g. "ollama" + "mid_think_backend_url": _env("MID_THINK_BACKEND_URL"), # override backend URL for mid-think model + "quick_think_llm_provider": _env("QUICK_THINK_LLM_PROVIDER"), # e.g. "openai", "ollama" + "quick_think_backend_url": _env("QUICK_THINK_BACKEND_URL"), # override backend URL for quick-think model # Provider-specific thinking configuration (applies to all roles unless overridden) - "google_thinking_level": None, # "high", "minimal", etc. - "openai_reasoning_effort": None, # "medium", "high", "low" + "google_thinking_level": _env("GOOGLE_THINKING_LEVEL"), # "high", "minimal", etc. 
+ "openai_reasoning_effort": _env("OPENAI_REASONING_EFFORT"), # "medium", "high", "low" # Per-role provider-specific thinking configuration - "deep_think_google_thinking_level": None, - "deep_think_openai_reasoning_effort": None, - "mid_think_google_thinking_level": None, - "mid_think_openai_reasoning_effort": None, - "quick_think_google_thinking_level": None, - "quick_think_openai_reasoning_effort": None, + "deep_think_google_thinking_level": _env("DEEP_THINK_GOOGLE_THINKING_LEVEL"), + "deep_think_openai_reasoning_effort": _env("DEEP_THINK_OPENAI_REASONING_EFFORT"), + "mid_think_google_thinking_level": _env("MID_THINK_GOOGLE_THINKING_LEVEL"), + "mid_think_openai_reasoning_effort": _env("MID_THINK_OPENAI_REASONING_EFFORT"), + "quick_think_google_thinking_level": _env("QUICK_THINK_GOOGLE_THINKING_LEVEL"), + "quick_think_openai_reasoning_effort": _env("QUICK_THINK_OPENAI_REASONING_EFFORT"), # Debate and discussion settings - "max_debate_rounds": 1, - "max_risk_discuss_rounds": 1, - "max_recur_limit": 100, + "max_debate_rounds": _env_int("MAX_DEBATE_ROUNDS", 1), + "max_risk_discuss_rounds": _env_int("MAX_RISK_DISCUSS_ROUNDS", 1), + "max_recur_limit": _env_int("MAX_RECUR_LIMIT", 100), # Data vendor configuration # Category-level configuration (default for all tools in category) "data_vendors": { - "core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance - "technical_indicators": "yfinance", # Options: alpha_vantage, yfinance - "fundamental_data": "yfinance", # Options: alpha_vantage, yfinance - "news_data": "yfinance", # Options: alpha_vantage, yfinance - "scanner_data": "yfinance", # Options: yfinance (primary), alpha_vantage (fallback for movers only) + "core_stock_apis": _env("VENDOR_CORE_STOCK_APIS", "yfinance"), + "technical_indicators": _env("VENDOR_TECHNICAL_INDICATORS", "yfinance"), + "fundamental_data": _env("VENDOR_FUNDAMENTAL_DATA", "yfinance"), + "news_data": _env("VENDOR_NEWS_DATA", "yfinance"), + "scanner_data": _env("VENDOR_SCANNER_DATA", 
"yfinance"), }, # Tool-level configuration (takes precedence over category-level) "tool_vendors": { diff --git a/tradingagents/graph/scanner_graph.py b/tradingagents/graph/scanner_graph.py index 2115c28e..9bccd0ff 100644 --- a/tradingagents/graph/scanner_graph.py +++ b/tradingagents/graph/scanner_graph.py @@ -1,62 +1,135 @@ -# tradingagents/graph/scanner_graph.py +"""Scanner graph — orchestrates the 3-phase macro scanner pipeline.""" -import datetime -from typing import Any, Dict, Optional +from typing import Any -from tradingagents.dataflows.config import set_config from tradingagents.default_config import DEFAULT_CONFIG - +from tradingagents.llm_clients import create_llm_client +from tradingagents.agents.scanners import ( + create_geopolitical_scanner, + create_market_movers_scanner, + create_sector_scanner, + create_industry_deep_dive, + create_macro_synthesis, +) from .scanner_setup import ScannerGraphSetup -class MacroScannerGraph: - """Orchestrates the Global Macro Scanner workflow. +class ScannerGraph: + """Orchestrates the 3-phase macro scanner pipeline. - The scanner runs three parallel data-collection phases followed by a - synthesis phase: - - Phase 1 (parallel): - - Geopolitical / macro news scanner - - Market movers + index performance scanner - - Sector performance scanner - - Phase 2 (sequential): - - Industry deep dive (technology sector by default) - - Phase 3 (sequential): - - Macro synthesis — combines all outputs into a single summary + Phase 1 (parallel): geopolitical_scanner, market_movers_scanner, sector_scanner + Phase 2: industry_deep_dive (fan-in from Phase 1) + Phase 3: macro_synthesis -> END """ - def __init__(self, config: Optional[Dict[str, Any]] = None): - """Initialise the scanner graph. + def __init__(self, config: dict[str, Any] | None = None, debug: bool = False) -> None: + """Initialize the scanner graph. Args: - config: Optional configuration dictionary. Defaults to - ``DEFAULT_CONFIG`` when not provided. 
+ config: Configuration dictionary. Falls back to DEFAULT_CONFIG when None. + debug: Whether to stream and print intermediate states. """ - self.config = config or DEFAULT_CONFIG - set_config(self.config) + self.config = config or DEFAULT_CONFIG.copy() + self.debug = debug - self.graph_setup = ScannerGraphSetup() - self.graph = self.graph_setup.setup_graph() + quick_llm = self._create_llm("quick_think") + mid_llm = self._create_llm("mid_think") + deep_llm = self._create_llm("deep_think") - def scan(self, scan_date: Optional[str] = None) -> Dict[str, Any]: - """Execute the macro scan and return the final state. + agents = { + "geopolitical_scanner": create_geopolitical_scanner(quick_llm), + "market_movers_scanner": create_market_movers_scanner(quick_llm), + "sector_scanner": create_sector_scanner(quick_llm), + "industry_deep_dive": create_industry_deep_dive(mid_llm), + "macro_synthesis": create_macro_synthesis(deep_llm), + } + + setup = ScannerGraphSetup(agents) + self.graph = setup.setup_graph() + + def _create_llm(self, tier: str) -> Any: + """Create an LLM instance for the given tier. + + Mirrors the provider/model/backend_url resolution logic from + TradingAgentsGraph, including mid_think fallback to quick_think. Args: - scan_date: Date string in ``YYYY-MM-DD`` format. Defaults to - today's date when not provided. + tier: One of "quick_think", "mid_think", or "deep_think". Returns: - Final LangGraph state dictionary containing all scan reports and - the ``macro_scan_summary`` field. + A LangChain-compatible chat model instance. 
""" - if scan_date is None: - scan_date = datetime.date.today().isoformat() + kwargs = self._get_provider_kwargs(tier) - initial_state = { - "messages": [], + if tier == "mid_think": + model = self.config.get("mid_think_llm") or self.config["quick_think_llm"] + provider = ( + self.config.get("mid_think_llm_provider") + or self.config.get("quick_think_llm_provider") + or self.config["llm_provider"] + ) + backend_url = ( + self.config.get("mid_think_backend_url") + or self.config.get("quick_think_backend_url") + or self.config.get("backend_url") + ) + else: + model = self.config[f"{tier}_llm"] + provider = self.config.get(f"{tier}_llm_provider") or self.config["llm_provider"] + backend_url = self.config.get(f"{tier}_backend_url") or self.config.get("backend_url") + + client = create_llm_client( + provider=provider, + model=model, + base_url=backend_url, + **kwargs, + ) + return client.get_llm() + + def _get_provider_kwargs(self, tier: str) -> dict[str, Any]: + """Resolve provider-specific kwargs (e.g. thinking_level, reasoning_effort). + + Args: + tier: One of "quick_think", "mid_think", or "deep_think". + + Returns: + Dict of extra kwargs to pass to the LLM client constructor. + """ + kwargs: dict[str, Any] = {} + prefix = f"{tier}_" + provider = ( + self.config.get(f"{prefix}llm_provider") or self.config.get("llm_provider", "") + ).lower() + + if provider == "google": + thinking_level = self.config.get(f"{prefix}google_thinking_level") or self.config.get( + "google_thinking_level" + ) + if thinking_level: + kwargs["thinking_level"] = thinking_level + + elif provider in ("openai", "xai", "openrouter", "ollama"): + reasoning_effort = self.config.get( + f"{prefix}openai_reasoning_effort" + ) or self.config.get("openai_reasoning_effort") + if reasoning_effort: + kwargs["reasoning_effort"] = reasoning_effort + + return kwargs + + def scan(self, scan_date: str) -> dict: + """Run the scanner pipeline and return the final state. 
+ + Args: + scan_date: Date string in YYYY-MM-DD format for the scan. + + Returns: + Final LangGraph state dict containing all scanner reports and + the macro_scan_summary produced by the synthesis phase. + """ + initial_state: dict[str, Any] = { "scan_date": scan_date, + "messages": [], "geopolitical_report": "", "market_movers_report": "", "sector_performance_report": "", @@ -65,9 +138,11 @@ class MacroScannerGraph: "sender": "", } - final_state = self.graph.invoke( - initial_state, - {"recursion_limit": self.config.get("max_recur_limit", 100)}, - ) + if self.debug: + # stream() yields partial state updates; use invoke() for the + # full accumulated state and print chunks for debugging only. + for chunk in self.graph.stream(initial_state): + print(f"[scanner debug] chunk keys: {list(chunk.keys())}") + # Fall through to invoke() for the correct accumulated result - return final_state + return self.graph.invoke(initial_state) diff --git a/tradingagents/graph/scanner_setup.py b/tradingagents/graph/scanner_setup.py index 68413b5c..c4f8302b 100644 --- a/tradingagents/graph/scanner_setup.py +++ b/tradingagents/graph/scanner_setup.py @@ -1,78 +1,52 @@ -# tradingagents/graph/scanner_setup.py +"""Setup for the scanner workflow graph.""" + from langgraph.graph import StateGraph, START, END from tradingagents.agents.utils.scanner_states import ScannerState -from tradingagents.dataflows.interface import route_to_vendor - - -def geopolitical_scanner_node(state: ScannerState) -> dict: - """Phase 1: Fetch geopolitical and macro news.""" - result = route_to_vendor("get_topic_news", "geopolitics global economy", 10) - return {"geopolitical_report": result} - - -def market_movers_scanner_node(state: ScannerState) -> dict: - """Phase 1: Fetch market movers and index performance.""" - movers = route_to_vendor("get_market_movers", "day_gainers") - indices = route_to_vendor("get_market_indices") - return {"market_movers_report": movers + "\n\n" + indices} - - -def 
sector_scanner_node(state: ScannerState) -> dict: - """Phase 1: Fetch sector performance overview.""" - result = route_to_vendor("get_sector_performance") - return {"sector_performance_report": result} - - -def industry_deep_dive_node(state: ScannerState) -> dict: - """Phase 2: Drill down into the technology sector as a representative example.""" - result = route_to_vendor("get_industry_performance", "technology") - return {"industry_deep_dive_report": result} - - -def macro_synthesis_node(state: ScannerState) -> dict: - """Phase 3: Combine all scanner outputs into a final summary.""" - parts = [ - state.get("geopolitical_report", ""), - state.get("market_movers_report", ""), - state.get("sector_performance_report", ""), - state.get("industry_deep_dive_report", ""), - ] - summary = "\n\n---\n\n".join(p for p in parts if p) - return {"macro_scan_summary": summary} class ScannerGraphSetup: - """Handles the setup and configuration of the scanner graph.""" + """Sets up the 3-phase scanner graph with LLM agent nodes. + + Phase 1: geopolitical_scanner, market_movers_scanner, sector_scanner (parallel fan-out) + Phase 2: industry_deep_dive (fan-in from all three Phase 1 nodes) + Phase 3: macro_synthesis -> END + """ + + def __init__(self, agents: dict) -> None: + """ + Args: + agents: Dict mapping node names to agent node functions: + - geopolitical_scanner + - market_movers_scanner + - sector_scanner + - industry_deep_dive + - macro_synthesis + """ + self.agents = agents def setup_graph(self): - """Set up and compile the scanner workflow graph.""" + """Build and compile the scanner workflow graph. + + Returns: + A compiled LangGraph graph ready to invoke. 
+ """ workflow = StateGraph(ScannerState) - # Phase 1: parallel scanners - workflow.add_node("geopolitical_scanner", geopolitical_scanner_node) - workflow.add_node("market_movers_scanner", market_movers_scanner_node) - workflow.add_node("sector_scanner", sector_scanner_node) + for name, node_fn in self.agents.items(): + workflow.add_node(name, node_fn) - # Phase 2: industry deep dive - workflow.add_node("industry_deep_dive", industry_deep_dive_node) - - # Phase 3: macro synthesis - workflow.add_node("macro_synthesis", macro_synthesis_node) - - # Fan-out from START to 3 parallel scanners + # Phase 1: parallel fan-out from START workflow.add_edge(START, "geopolitical_scanner") workflow.add_edge(START, "market_movers_scanner") workflow.add_edge(START, "sector_scanner") - # Fan-in: LangGraph's StateGraph guarantees that industry_deep_dive - # only executes after ALL three predecessor nodes have completed and - # their state updates have been merged. + # Fan-in: all three Phase 1 nodes must complete before Phase 2 workflow.add_edge("geopolitical_scanner", "industry_deep_dive") workflow.add_edge("market_movers_scanner", "industry_deep_dive") workflow.add_edge("sector_scanner", "industry_deep_dive") - # Sequential: deep dive → synthesis → end + # Phase 2 -> Phase 3 -> END workflow.add_edge("industry_deep_dive", "macro_synthesis") workflow.add_edge("macro_synthesis", END) diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 7011895f..1076dacf 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -56,7 +56,11 @@ class OpenAIClient(BaseLLMClient): if api_key: llm_kwargs["api_key"] = api_key elif self.provider == "ollama": - llm_kwargs["base_url"] = "http://localhost:11434/v1" + host = self.base_url or "http://localhost:11434" + # Ensure the URL ends with /v1 for OpenAI-compatible endpoint + if not host.rstrip("/").endswith("/v1"): + host = host.rstrip("/") + "/v1" + 
llm_kwargs["base_url"] = host llm_kwargs["api_key"] = "ollama" # Ollama doesn't require auth elif self.base_url: llm_kwargs["base_url"] = self.base_url diff --git a/tradingagents/pipeline/__init__.py b/tradingagents/pipeline/__init__.py new file mode 100644 index 00000000..902b97c4 --- /dev/null +++ b/tradingagents/pipeline/__init__.py @@ -0,0 +1 @@ +# Macro bridge pipeline — connects scanner output to per-ticker analysis diff --git a/tradingagents/pipeline/macro_bridge.py b/tradingagents/pipeline/macro_bridge.py new file mode 100644 index 00000000..53d7a1fa --- /dev/null +++ b/tradingagents/pipeline/macro_bridge.py @@ -0,0 +1,518 @@ +"""Bridge between macro scanner output and TradingAgents per-ticker analysis.""" + +from __future__ import annotations + +import asyncio +import json +import logging +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Literal + +logger = logging.getLogger(__name__) + +ConvictionLevel = Literal["high", "medium", "low"] + +CONVICTION_RANK: dict[str, int] = {"high": 3, "medium": 2, "low": 1} + + +@dataclass +class MacroContext: + """Macro-level context from scanner output.""" + + economic_cycle: str + central_bank_stance: str + geopolitical_risks: list[str] + key_themes: list[dict] # [{theme, description, conviction, timeframe}] + executive_summary: str + risk_factors: list[str] + timeframe: str = "1 month" + region: str = "Global" + + +@dataclass +class StockCandidate: + """A stock surfaced by the macro scanner.""" + + ticker: str + name: str + sector: str + rationale: str + thesis_angle: str # growth | value | catalyst | turnaround | defensive | momentum + conviction: ConvictionLevel + key_catalysts: list[str] + risks: list[str] + macro_theme: str = "" # which macro theme this stock is linked to + + +@dataclass +class TickerResult: + """TradingAgents output for one ticker, enriched with macro context.""" + + ticker: str + candidate: StockCandidate + macro_context: MacroContext 
+ analysis_date: str + + # TradingAgents reports (populated after propagate()) + market_report: str = "" + sentiment_report: str = "" + news_report: str = "" + fundamentals_report: str = "" + investment_debate: str = "" + trader_investment_plan: str = "" + risk_debate: str = "" + final_trade_decision: str = "" + + error: str | None = None + + +# ─── Parsing ────────────────────────────────────────────────────────────────── + + +def parse_macro_output(path: Path) -> tuple[MacroContext, list[StockCandidate]]: + """Parse the JSON output from the Macro Intelligence Agent. + + Args: + path: Path to the JSON file produced by the macro scanner. + + Returns: + Tuple of (MacroContext, list of StockCandidate). + """ + with path.open() as f: + data = json.load(f) + + ctx_raw = data.get("macro_context", {}) + macro_context = MacroContext( + economic_cycle=ctx_raw.get("economic_cycle", ""), + central_bank_stance=ctx_raw.get("central_bank_stance", ""), + geopolitical_risks=ctx_raw.get("geopolitical_risks", []), + key_themes=data.get("key_themes", []), + executive_summary=data.get("executive_summary", ""), + risk_factors=data.get("risk_factors", []), + timeframe=data.get("timeframe", "1 month"), + region=data.get("region", "Global"), + ) + + candidates: list[StockCandidate] = [] + for s in data.get("stocks_to_investigate", []): + theme = _match_theme(s.get("sector", ""), data.get("key_themes", [])) + candidates.append( + StockCandidate( + ticker=s["ticker"].upper(), + name=s.get("name", s["ticker"]), + sector=s.get("sector", ""), + rationale=s.get("rationale", ""), + thesis_angle=s.get("thesis_angle", ""), + conviction=s.get("conviction", "medium"), + key_catalysts=s.get("key_catalysts", []), + risks=s.get("risks", []), + macro_theme=theme, + ) + ) + + return macro_context, candidates + + +def _match_theme(sector: str, themes: list[dict]) -> str: + """Return the macro theme name most likely linked to this sector. + + Args: + sector: Sector name for a stock candidate. 
+ themes: List of macro theme dicts from the scanner output. + + Returns: + The matched theme name, or the first theme name, or empty string. + """ + sector_lower = sector.lower() + for t in themes: + desc = (t.get("description", "") + t.get("theme", "")).lower() + if sector_lower in desc or any(w in desc for w in sector_lower.split()): + return t.get("theme", "") + return themes[0].get("theme", "") if themes else "" + + +# ─── Core pipeline ──────────────────────────────────────────────────────────── + + +def filter_candidates( + candidates: list[StockCandidate], + min_conviction: ConvictionLevel, + ticker_filter: list[str] | None, +) -> list[StockCandidate]: + """Filter by conviction level and optional explicit ticker list. + + Args: + candidates: All stock candidates from the macro scanner. + min_conviction: Minimum conviction threshold ("high", "medium", or "low"). + ticker_filter: Optional list of tickers to restrict to. + + Returns: + Filtered and sorted list (high conviction first, then alphabetically). + """ + min_rank = CONVICTION_RANK[min_conviction] + filtered = [c for c in candidates if CONVICTION_RANK[c.conviction] >= min_rank] + if ticker_filter: + tickers_upper = {t.upper() for t in ticker_filter} + filtered = [c for c in filtered if c.ticker in tickers_upper] + filtered.sort(key=lambda c: (-CONVICTION_RANK[c.conviction], c.ticker)) + return filtered + + +def run_ticker_analysis( + candidate: StockCandidate, + macro_context: MacroContext, + config: dict, + analysis_date: str, +) -> TickerResult: + """Run the full TradingAgents pipeline for one ticker. + + NOTE: TradingAgentsGraph is synchronous — call this from a thread pool + when running multiple tickers concurrently (see run_all_tickers). + + Args: + candidate: The stock candidate to analyse. + macro_context: Macro context to embed in the result. + config: TradingAgents configuration dict. + analysis_date: Date string in YYYY-MM-DD format. 
+ + Returns: + TickerResult with all report fields populated, or error set on failure. + """ + result = TickerResult( + ticker=candidate.ticker, + candidate=candidate, + macro_context=macro_context, + analysis_date=analysis_date, + ) + + logger.info("Starting analysis for %s on %s", candidate.ticker, analysis_date) + + try: + from tradingagents.graph.trading_graph import TradingAgentsGraph + + ta = TradingAgentsGraph(debug=False, config=config) + final_state, decision = ta.propagate(candidate.ticker, analysis_date) + + result.market_report = final_state.get("market_report", "") + result.sentiment_report = final_state.get("sentiment_report", "") + result.news_report = final_state.get("news_report", "") + result.fundamentals_report = final_state.get("fundamentals_report", "") + result.investment_debate = str(final_state.get("investment_debate_state", "")) + result.trader_investment_plan = final_state.get("trader_investment_plan", "") + result.risk_debate = str(final_state.get("risk_debate_state", "")) + result.final_trade_decision = decision + + logger.info( + "Analysis complete for %s: %s", candidate.ticker, str(decision)[:120] + ) + + except Exception as exc: + logger.error("Analysis failed for %s: %s", candidate.ticker, exc, exc_info=True) + result.error = str(exc) + + return result + + +async def run_all_tickers( + candidates: list[StockCandidate], + macro_context: MacroContext, + config: dict, + analysis_date: str, + max_concurrent: int = 2, +) -> list[TickerResult]: + """Run TradingAgents for every candidate with controlled concurrency. + + max_concurrent=2 is conservative — each run makes many API calls. + Increase only if your data vendor plan supports higher rate limits. + + Args: + candidates: Filtered stock candidates to analyse. + macro_context: Macro context shared across all tickers. + config: TradingAgents configuration dict. + analysis_date: Date string in YYYY-MM-DD format. + max_concurrent: Maximum number of tickers to process in parallel. 
+
+    Returns:
+        List of TickerResult in input order (asyncio.gather preserves order).
+    """
+    semaphore = asyncio.Semaphore(max_concurrent)
+
+    async def _run_one(candidate: StockCandidate) -> TickerResult:
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            # TradingAgentsGraph is synchronous — run it in a thread pool
+            return await loop.run_in_executor(
+                None,
+                run_ticker_analysis,
+                candidate,
+                macro_context,
+                config,
+                analysis_date,
+            )
+
+    tasks = [_run_one(c) for c in candidates]
+    results = await asyncio.gather(*tasks)
+    return list(results)
+
+
+# ─── Reporting ──────────────────────────────────────────────────────────────
+
+
+def _macro_preamble(ctx: MacroContext) -> str:
+    """Render the macro context block shared across all reports."""
+    themes_text = "\n".join(
+        f" - **{t['theme']}** ({t.get('conviction', '?')} conviction): {t.get('description', '')}"
+        for t in ctx.key_themes[:5]
+    )
+    risks_text = "\n".join(f" - {r}" for r in ctx.risk_factors[:5])
+    return f"""## Macro context (from Macro Intelligence Agent)
+
+**Horizon:** {ctx.timeframe} | **Region:** {ctx.region}
+
+**Economic cycle:** {ctx.economic_cycle}
+
+**Central bank stance:** {ctx.central_bank_stance}
+
+**Key macro themes:**
+{themes_text}
+
+**Geopolitical risks:** {', '.join(ctx.geopolitical_risks)}
+
+**Macro risk factors:**
+{risks_text}
+
+**Executive summary:** {ctx.executive_summary}
+
+---
+"""
+
+
+def render_ticker_report(result: TickerResult) -> str:
+    """Render a single ticker's full Markdown report.
+
+    Args:
+        result: Completed TickerResult (may contain an error).
+
+    Returns:
+        Markdown string with the full analysis or failure notice.
+ """ + c = result.candidate + header = f"""# {c.ticker} — {c.name} +**Sector:** {c.sector} | **Thesis:** {c.thesis_angle} | **Conviction:** {c.conviction.upper()} +**Analysis date:** {result.analysis_date} + +### Macro rationale (why this stock was surfaced) +{c.rationale} + +**Macro theme alignment:** {c.macro_theme} +**Key catalysts:** {', '.join(c.key_catalysts)} +**Macro-level risks:** {', '.join(c.risks)} + +--- +""" + if result.error: + return header + f"## Analysis failed\n```\n{result.error}\n```\n" + + return ( + header + + _macro_preamble(result.macro_context) + + f"## Market analysis\n{result.market_report}\n\n" + + f"## Fundamentals analysis\n{result.fundamentals_report}\n\n" + + f"## News analysis\n{result.news_report}\n\n" + + f"## Sentiment analysis\n{result.sentiment_report}\n\n" + + f"## Research team debate (Bull vs Bear)\n{result.investment_debate}\n\n" + + f"## Trader investment plan\n{result.trader_investment_plan}\n\n" + + f"## Risk management assessment\n{result.risk_debate}\n\n" + + f"## Final trade decision\n{result.final_trade_decision}\n" + ) + + +def render_combined_summary( + results: list[TickerResult], + macro_context: MacroContext, +) -> str: + """Render a single summary Markdown combining all tickers. + + Args: + results: All completed TickerResults. + macro_context: Shared macro context for the preamble. + + Returns: + Markdown string with overview table and per-ticker decisions. 
+ """ + now = datetime.now().strftime("%Y-%m-%d %H:%M") + lines = [ + "# Macro-Driven Deep Dive Summary", + f"Generated: {now}\n", + _macro_preamble(macro_context), + "## Results overview\n", + "| Ticker | Name | Conviction | Sector | Decision |", + "|--------|------|-----------|--------|---------|", + ] + + for r in results: + decision_preview = ( + "ERROR" + if r.error + else str(r.final_trade_decision)[:60].replace("\n", " ") + ) + lines.append( + f"| {r.ticker} | {r.candidate.name} " + f"| {r.candidate.conviction.upper()} " + f"| {r.candidate.sector} " + f"| {decision_preview} |" + ) + + lines.append("\n---\n") + for r in results: + lines.append(f"## {r.ticker} — final decision\n") + if r.error: + lines.append(f"Analysis failed: {r.error}\n") + else: + lines.append(f"**Macro rationale:** {r.candidate.rationale}\n\n") + lines.append(r.final_trade_decision or "_No decision generated._") + lines.append("\n\n---\n") + + return "\n".join(lines) + + +def save_results( + results: list[TickerResult], + macro_context: MacroContext, + output_dir: Path, +) -> None: + """Save per-ticker Markdown reports, a combined summary, and a JSON index. + + Args: + results: All completed TickerResults. + macro_context: Shared macro context used in reports. + output_dir: Directory to write all output files into. 
+ """ + output_dir.mkdir(parents=True, exist_ok=True) + + for result in results: + ticker_dir = output_dir / result.ticker + ticker_dir.mkdir(exist_ok=True) + report_path = ticker_dir / f"{result.analysis_date}_deep_dive.md" + report_path.write_text(render_ticker_report(result)) + logger.info("Saved report: %s", report_path) + + summary_path = output_dir / "summary.md" + summary_path.write_text(render_combined_summary(results, macro_context)) + logger.info("Saved summary: %s", summary_path) + + # Machine-readable index for downstream tooling + json_path = output_dir / "results.json" + json_path.write_text( + json.dumps( + [ + { + "ticker": r.ticker, + "name": r.candidate.name, + "sector": r.candidate.sector, + "conviction": r.candidate.conviction, + "thesis_angle": r.candidate.thesis_angle, + "analysis_date": r.analysis_date, + "final_trade_decision": r.final_trade_decision, + "error": r.error, + } + for r in results + ], + indent=2, + ) + ) + logger.info("Saved JSON index: %s", json_path) + + +# ─── Facade ─────────────────────────────────────────────────────────────────── + + +class MacroBridge: + """Facade for the macro scanner → TradingAgents pipeline. + + Provides a single entry point for CLI and programmatic use without + exposing the individual pipeline functions. + """ + + def __init__(self, config: dict) -> None: + """ + Args: + config: TradingAgents configuration dict (built by the caller/CLI). + """ + self.config = config + + def load(self, path: Path) -> tuple[MacroContext, list[StockCandidate]]: + """Parse macro scanner JSON output. + + Args: + path: Path to the macro scanner JSON file. + + Returns: + Tuple of (MacroContext, all StockCandidates). + """ + return parse_macro_output(path) + + def filter( + self, + candidates: list[StockCandidate], + min_conviction: ConvictionLevel = "medium", + ticker_filter: list[str] | None = None, + ) -> list[StockCandidate]: + """Filter and sort stock candidates. + + Args: + candidates: All candidates from load(). 
+ min_conviction: Minimum conviction threshold. + ticker_filter: Optional explicit ticker whitelist. + + Returns: + Filtered and sorted candidate list. + """ + return filter_candidates(candidates, min_conviction, ticker_filter) + + def run( + self, + candidates: list[StockCandidate], + macro_context: MacroContext, + analysis_date: str, + max_concurrent: int = 2, + ) -> list[TickerResult]: + """Run the full TradingAgents pipeline for all candidates. + + Blocks until all tickers are complete. + + Args: + candidates: Filtered candidates to analyse. + macro_context: Macro context for enriching results. + analysis_date: Date string in YYYY-MM-DD format. + max_concurrent: Maximum parallel tickers. + + Returns: + List of TickerResult. + """ + return asyncio.run( + run_all_tickers( + candidates=candidates, + macro_context=macro_context, + config=self.config, + analysis_date=analysis_date, + max_concurrent=max_concurrent, + ) + ) + + def save( + self, + results: list[TickerResult], + macro_context: MacroContext, + output_dir: Path, + ) -> None: + """Save results to disk. + + Args: + results: Completed TickerResults. + macro_context: Shared macro context. + output_dir: Target directory for all output files. + """ + save_results(results, macro_context, output_dir)