Stabilize TradingAgents contracts so orchestration and dashboard can converge
This change set introduces a versioned result contract, shared config schema/loading, provider/data adapter seams, and a no-strategy application-service skeleton so the current research graph, orchestrator layer, and dashboard backend stop drifting further apart. It also keeps the earlier MiniMax compatibility and compact-prompt work aligned with the new contract shape and extends regression coverage so degradation, fallback, and service migration remain testable during the next phases.

Constraint: Must preserve existing FastAPI entrypoints and fallback behavior while introducing an application-service seam
Constraint: Must not turn the application service into a new strategy or learning layer
Rejected: Full backend rewrite to service-only execution now | too risky before contract and fallback paths stabilize
Rejected: Leave provider/data/config logic distributed across scripts and endpoints | continues boundary drift and weakens verification
Confidence: high
Scope-risk: broad
Directive: Keep future application-service changes orchestration-only; move any scoring, signal fusion, or learning logic to orchestrator or tradingagents instead
Tested: python -m compileall orchestrator tradingagents web_dashboard/backend
Tested: python -m pytest orchestrator/tests/test_signals.py orchestrator/tests/test_llm_runner.py orchestrator/tests/test_quant_runner.py orchestrator/tests/test_contract_v1alpha1.py orchestrator/tests/test_application_service.py orchestrator/tests/test_provider_adapter.py web_dashboard/backend/tests/test_main_api.py web_dashboard/backend/tests/test_portfolio_api.py web_dashboard/backend/tests/test_api_smoke.py web_dashboard/backend/tests/test_services_migration.py -q
Not-tested: live MiniMax/provider execution against external services
Not-tested: full dashboard/manual websocket flow against a running frontend
Not-tested: omx team runtime end-to-end in the primary workspace
This commit is contained in:
parent 5b2d631393
commit b6e57d01e3
@@ -4,13 +4,20 @@ on:
  push:
    branches: [main, feat/**, fix/**]
    paths:
      - 'orchestrator/**/*.py'
      - 'tradingagents/**/*.py'
      - 'orchestrator/tests/**/*.py'
      - 'web_dashboard/backend/**/*.py'
      - 'web_dashboard/frontend/**/*.js'
      - '.github/workflows/dashboard-tests.yml'
  pull_request:
    paths:
      - 'orchestrator/**/*.py'
      - 'tradingagents/**/*.py'
      - 'orchestrator/tests/**/*.py'
      - 'web_dashboard/backend/**/*.py'
      - 'web_dashboard/frontend/**/*.js'
      - '.github/workflows/dashboard-tests.yml'

jobs:
  test-backend:

@@ -29,6 +36,10 @@ jobs:
          pip install pytest pytest-asyncio httpx
          pip install -e . 2>/dev/null || true

      - name: Run orchestrator tests
        run: |
          python -m pytest orchestrator/tests/ -v --tb=short

      - name: Run backend tests
        working-directory: web_dashboard/backend
        run: |
@@ -0,0 +1,195 @@
# TradingAgents architecture convergence draft: application boundary

Status: draft
Audience: backend/dashboard/orchestrator maintainers
Scope: define the boundary between HTTP/WebSocket delivery, application service orchestration, and the quant+LLM merge kernel

## 1. Why this document exists

The current backend mixes three concerns inside `web_dashboard/backend/main.py`:

1. transport concerns: FastAPI routes, headers, WebSocket sessions, task persistence;
2. application orchestration: task lifecycle, stage progress, subprocess wiring, result projection;
3. domain execution: `TradingOrchestrator`, `LiveMode`, quant+LLM signal merge.

For architecture convergence, these concerns should be separated so that:

- the application service remains a no-strategy orchestration and contract layer;
- `orchestrator/` remains the quant+LLM merge kernel;
- transport adapters can migrate without re-embedding business rules.

## 2. Current evidence in repo

### 2.1 Merge kernel already exists

- `orchestrator/orchestrator.py` owns quant runner + LLM runner composition.
- `orchestrator/signals.py` owns `Signal`, `FinalSignal`, and merge math.
- `orchestrator/live_mode.py` owns batch live execution against the orchestrator.

This is the correct place for quant/LLM merge semantics.

### 2.2 Backend currently crosses the boundary

`web_dashboard/backend/main.py` currently also owns:

- analysis subprocess template creation;
- stage-to-progress mapping;
- task state persistence in `app.state.task_results` and `data/task_status/*.json`;
- conversion from `FinalSignal` to UI-oriented fields such as `decision`, `quant_signal`, `llm_signal`, `confidence`;
- report materialization into `results/<ticker>/<date>/complete_report.md`.

This makes the transport layer hard to replace and makes result contracts implicit.

## 3. Target boundary
### 3.1 Layer model

#### Transport adapters

Examples:

- FastAPI REST routes
- FastAPI WebSocket endpoints
- future CLI/Tauri/worker adapters

Responsibilities:

- request parsing and auth
- response serialization
- websocket connection management
- mapping application errors to HTTP/WebSocket status

Non-responsibilities:

- no strategy logic
- no quant/LLM weighting logic
- no task-stage business rules beyond rendering application events

#### Application service

Suggested responsibility set:

- accept typed command/query inputs from transport
- orchestrate analysis execution lifecycle
- map domain results into stable result contracts
- own task ids, progress events, persistence coordination, and rollback-safe migration switches
- decide which backend implementation to call during migration

Non-responsibilities:

- no rating-to-signal research logic
- no quant/LLM merge math
- no provider-specific data acquisition details

#### Domain kernel

Examples:

- `TradingOrchestrator`
- `SignalMerger`
- `QuantRunner`
- `LLMRunner`
- `TradingAgentsGraph`

Responsibilities:

- produce quant signal, LLM signal, merged signal
- expose domain-native dataclasses and metadata
- degrade gracefully when one lane fails

### 3.2 Canonical dependency direction

```text
transport adapter -> application service -> domain kernel
transport adapter -> application service -> persistence adapter
application service -> result contract mapper
```

Forbidden direction:

```text
transport adapter -> domain kernel + ad hoc mapping + ad hoc persistence
```
## 4. Proposed application-service interface

The application service should expose typed use cases instead of letting routes assemble logic inline.

### 4.1 Commands / queries

Suggested surface:

- `start_analysis(request) -> AnalysisTaskAccepted`
- `get_analysis_status(task_id) -> AnalysisTaskStatus`
- `cancel_analysis(task_id) -> AnalysisTaskStatus`
- `run_live_signals(request) -> LiveSignalBatch`
- `list_analysis_tasks() -> AnalysisTaskList`
- `get_report(ticker, date) -> HistoricalReport`
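The suggested surface can be sketched as a structural `Protocol` so transport adapters depend on the interface rather than a concrete backend. Everything below is illustrative: class names, the `task_id` format, and the in-memory stub are assumptions for the sketch, not the final design.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@dataclass(frozen=True)
class AnalysisTaskAccepted:
    task_id: str
    ticker: str
    date: str
    status: str


@runtime_checkable
class AnalysisService(Protocol):
    """Application-service surface; transport adapters depend only on this."""

    def start_analysis(self, ticker: str, date: str) -> AnalysisTaskAccepted: ...

    def get_analysis_status(self, task_id: str) -> dict: ...

    def cancel_analysis(self, task_id: str) -> dict: ...


class InMemoryAnalysisService:
    """Minimal stub satisfying the protocol, useful for contract-level tests."""

    def __init__(self) -> None:
        self._tasks: dict[str, dict] = {}

    def start_analysis(self, ticker: str, date: str) -> AnalysisTaskAccepted:
        # Hypothetical id scheme; the real backend uses a longer timestamped id.
        task_id = f"{ticker}_{date.replace('-', '')}"
        self._tasks[task_id] = {"status": "running"}
        return AnalysisTaskAccepted(task_id, ticker, date, "running")

    def get_analysis_status(self, task_id: str) -> dict:
        return self._tasks[task_id]

    def cancel_analysis(self, task_id: str) -> dict:
        self._tasks[task_id] = {"status": "cancelled"}
        return self._tasks[task_id]
```

Because the protocol is structural, the FastAPI routes never import the concrete implementation, which is what makes the Phase 1 import-switch rollback cheap.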
### 4.2 Domain input boundary

Inputs from transport should already be normalized into application DTOs:

- ticker
- trade date
- auth context
- provider/config selection
- execution mode

The application service may choose subprocess/backend/orchestrator execution strategy, but it must not redefine domain semantics.
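One way to normalize those inputs is a frozen DTO that validates at construction time, so routes can only hand the service well-formed requests. Field names and the date check are illustrative assumptions, not a fixed schema.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass(frozen=True)
class AnalysisRequest:
    """Normalized application DTO; field names here are illustrative."""

    ticker: str
    trade_date: str  # YYYY-MM-DD
    auth_subject: Optional[str] = None
    provider: Optional[str] = None
    config_overrides: dict[str, Any] = field(default_factory=dict)
    execution_mode: str = "subprocess"  # e.g. subprocess vs in-process orchestrator

    def __post_init__(self) -> None:
        # Validate at the boundary so downstream layers can trust the DTO.
        if not self.ticker.strip():
            raise ValueError("ticker must be non-empty")
        if len(self.trade_date) != 10 or self.trade_date[4] != "-":
            raise ValueError("trade_date must look like YYYY-MM-DD")
```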
## 5. Boundary rules for convergence work

### Rule A: result mapping happens once

Current code maps `FinalSignal` to dashboard fields inside the analysis subprocess template. That mapping should move behind a single application mapper so REST, WebSocket, export, and persisted task status share one contract.

### Rule B: stage model belongs to the application layer

Stage names such as `analysts`, `research`, `trading`, `risk`, `portfolio` are delivery/progress concepts, not merge-kernel concepts. Keep them outside `orchestrator/`.

### Rule C: orchestrator stays contract-light

`orchestrator/` should continue returning `Signal` / `FinalSignal` and domain metadata. It should not learn about HTTP status, WebSocket payloads, pagination, or UI labels beyond the domain rating semantics already present.

### Rule D: transport only renders contracts

Routes should call the application service and return the already-shaped DTO/contract. They should not reconstruct `decision`, `quant_signal`, `llm_signal`, or progress math themselves.
## 6. Suggested module split

One viable split:

```text
web_dashboard/backend/
  application/
    analysis_service.py
    live_signal_service.py
    report_service.py
    contracts.py
    mappers.py
  infra/
    task_store.py
    subprocess_runner.py
    report_store.py
  api/
    fastapi_routes remain thin
```

This keeps convergence local to backend/application without moving merge logic out of `orchestrator/`.
## 7. Non-goals

- Do not move signal merge math into the application service.
- Do not turn the application service into a strategy engine.
- Do not require frontend-specific field naming inside `orchestrator/`.
- Do not block migration on a full rewrite of existing routes.

## 8. Review checklist

A change respects this boundary if all are true:

- route handlers mainly validate/auth/call service/return contract;
- the application service owns task lifecycle and contract mapping;
- `orchestrator/` remains the only owner of merge semantics;
- domain dataclasses can still be tested without FastAPI or WebSocket context.
@@ -0,0 +1,244 @@
# TradingAgents result contract v1alpha1 draft

Status: draft
Audience: backend, desktop, frontend, verification
Format: JSON-oriented contract notes with examples

## 1. Goals

`result-contract-v1alpha1` defines the stable shapes exchanged across:

- analysis start/status APIs
- websocket progress events
- live orchestrator streaming
- persisted task state
- historical report projection

The contract should be application-facing, not raw domain dataclasses.

## 2. Design principles

- version every externally consumed payload
- keep transport-neutral field meanings
- allow partial/degraded results when the quant or LLM lane fails
- distinguish task lifecycle from signal outcome
- keep raw domain metadata nested, not smeared across top-level fields

## 3. Core enums
### 3.1 Task status

```json
["pending", "running", "completed", "failed", "cancelled"]
```

### 3.2 Stage name

```json
["analysts", "research", "trading", "risk", "portfolio"]
```

### 3.3 Decision rating

```json
["BUY", "OVERWEIGHT", "HOLD", "UNDERWEIGHT", "SELL"]
```
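On the Python side, these enum lists translate naturally into `str`-valued enums, which serialize to the exact JSON strings above without custom encoders. The class names are assumptions; only the value lists come from the contract.

```python
from enum import Enum


class TaskStatus(str, Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class StageName(str, Enum):
    ANALYSTS = "analysts"
    RESEARCH = "research"
    TRADING = "trading"
    RISK = "risk"
    PORTFOLIO = "portfolio"


class DecisionRating(str, Enum):
    BUY = "BUY"
    OVERWEIGHT = "OVERWEIGHT"
    HOLD = "HOLD"
    UNDERWEIGHT = "UNDERWEIGHT"
    SELL = "SELL"
```

Because each enum subclasses `str`, `json.dumps(TaskStatus.RUNNING)` emits `"running"` directly, and `TaskStatus("running")` round-trips an incoming payload value.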
## 4. Canonical envelope

All application-facing payloads should include:

```json
{
  "contract_version": "v1alpha1"
}
```

Optional transport-specific wrapper fields such as the WebSocket `type` may sit outside the contract body.
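A small pair of helpers can make the "version inside the body, `type` outside it" rule mechanical rather than per-route. The function names and the frame shape are sketches under that assumption.

```python
CONTRACT_VERSION = "v1alpha1"


def with_envelope(body: dict) -> dict:
    """Stamp an application payload with the contract version."""
    return {"contract_version": CONTRACT_VERSION, **body}


def ws_frame(event_type: str, body: dict) -> dict:
    """Transport wrapper: the WebSocket `type` stays outside the contract body."""
    return {"type": event_type, "data": with_envelope(body)}
```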
## 5. Analysis task contract

### 5.1 Accepted response

```json
{
  "contract_version": "v1alpha1",
  "task_id": "600519.SS_20260413_120000_ab12cd",
  "ticker": "600519.SS",
  "date": "2026-04-13",
  "status": "running"
}
```

### 5.2 Status / progress document

```json
{
  "contract_version": "v1alpha1",
  "task_id": "600519.SS_20260413_120000_ab12cd",
  "ticker": "600519.SS",
  "date": "2026-04-13",
  "status": "running",
  "progress": 40,
  "current_stage": "research",
  "created_at": "2026-04-13T12:00:00Z",
  "elapsed_seconds": 18,
  "stages": [
    {"name": "analysts", "status": "completed", "completed_at": "12:00:05"},
    {"name": "research", "status": "running", "completed_at": null},
    {"name": "trading", "status": "pending", "completed_at": null},
    {"name": "risk", "status": "pending", "completed_at": null},
    {"name": "portfolio", "status": "pending", "completed_at": null}
  ],
  "result": null,
  "error": null
}
```

Notes:

- `elapsed_seconds` is preferred over the current loosely typed `elapsed`.
- stage entries should carry an explicit `name`; the current positional arrays are fragile.
- `result` remains nullable until completion.
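The named-stage rule can be enforced with a small validator at the contract boundary, rejecting positional or reordered arrays before they reach consumers. This is a sketch under the assumption that v1alpha1 requires all five named stages in canonical order; the function name is hypothetical.

```python
STAGE_NAMES = ("analysts", "research", "trading", "risk", "portfolio")
STAGE_STATUSES = {"pending", "running", "completed", "failed", "cancelled"}


def validate_stages(stages: list[dict]) -> None:
    """Reject positional or misnamed stage arrays before they reach consumers."""
    names = [entry.get("name") for entry in stages]
    if names != list(STAGE_NAMES):
        raise ValueError(f"stages must be named and ordered {STAGE_NAMES}, got {names}")
    for entry in stages:
        if entry.get("status") not in STAGE_STATUSES:
            raise ValueError(f"invalid stage status: {entry.get('status')!r}")
```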
### 5.3 Completed result payload

```json
{
  "contract_version": "v1alpha1",
  "task_id": "600519.SS_20260413_120000_ab12cd",
  "ticker": "600519.SS",
  "date": "2026-04-13",
  "status": "completed",
  "progress": 100,
  "current_stage": "portfolio",
  "result": {
    "decision": "OVERWEIGHT",
    "confidence": 0.64,
    "signals": {
      "merged": {"direction": 1, "rating": "OVERWEIGHT"},
      "quant": {"direction": 1, "rating": "OVERWEIGHT", "available": true},
      "llm": {"direction": 1, "rating": "BUY", "available": true}
    },
    "degraded": false,
    "report": {
      "path": "results/600519.SS/2026-04-13/complete_report.md",
      "available": true
    }
  },
  "error": null
}
```

### 5.4 Failed result payload

```json
{
  "contract_version": "v1alpha1",
  "task_id": "600519.SS_20260413_120000_ab12cd",
  "ticker": "600519.SS",
  "date": "2026-04-13",
  "status": "failed",
  "progress": 60,
  "current_stage": "trading",
  "result": null,
  "error": {
    "code": "analysis_failed",
    "message": "both quant and llm signals are None",
    "retryable": false
  }
}
```
## 6. Live signal batch contract

This covers `/ws/orchestrator` style responses currently produced by `LiveMode`.

```json
{
  "contract_version": "v1alpha1",
  "signals": [
    {
      "ticker": "600519.SS",
      "date": "2026-04-13",
      "status": "completed",
      "result": {
        "direction": 1,
        "confidence": 0.64,
        "quant_direction": 1,
        "llm_direction": 1,
        "timestamp": "2026-04-13T12:00:11Z"
      },
      "error": null
    },
    {
      "ticker": "300750.SZ",
      "date": "2026-04-13",
      "status": "failed",
      "result": null,
      "error": {
        "code": "live_signal_failed",
        "message": "both quant and llm signals are None",
        "retryable": false
      }
    }
  ]
}
```
## 7. Historical report contract

```json
{
  "contract_version": "v1alpha1",
  "ticker": "600519.SS",
  "date": "2026-04-13",
  "decision": "OVERWEIGHT",
  "report": "# TradingAgents ...",
  "artifacts": {
    "complete_report": true,
    "stage_reports": {
      "analysts": true,
      "research": true,
      "trading": true,
      "risk": true,
      "portfolio": false
    }
  }
}
```
## 8. Mapping from current implementation

Current backend fields in `web_dashboard/backend/main.py` map roughly as follows:

- `decision` -> `result.decision`
- `quant_signal` -> `result.signals.quant.rating`
- `llm_signal` -> `result.signals.llm.rating`
- `confidence` -> `result.confidence`
- top-level `error` string -> structured `error`
- positional `stages[]` -> named `stages[]`
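That mapping table can be exercised as one projection function, which is also the "single application mapper" Rule A in the boundary doc asks for. The derivation of `degraded` from lane availability and the `analysis_failed` error code follow the tolerances and examples in this document; treat the function itself as a sketch, not the final mapper.

```python
def map_legacy_result(legacy: dict) -> dict:
    """Project the current flat backend fields into the nested v1alpha1 shape."""
    quant = legacy.get("quant_signal")
    llm = legacy.get("llm_signal")
    error = legacy.get("error")
    return {
        "contract_version": "v1alpha1",
        "result": None if error else {
            "decision": legacy["decision"],
            "confidence": legacy["confidence"],
            "signals": {
                "quant": {"rating": quant, "available": quant is not None},
                "llm": {"rating": llm, "available": llm is not None},
            },
            # Degraded when only one lane produced a usable signal (section 9).
            "degraded": quant is None or llm is None,
        },
        "error": (
            {"code": "analysis_failed", "message": error, "retryable": False}
            if error
            else None
        ),
    }
```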
## 9. Compatibility notes

### v1alpha1 tolerances

Consumers should tolerate:

- absent `result.signals.quant` when the quant path is unavailable
- absent `result.signals.llm` when the LLM path is unavailable
- `result.degraded = true` when only one lane produced a usable signal

### Fields to avoid freezing yet

Do not freeze these until the config-schema work lands:

- provider-specific configuration echo fields
- raw metadata blobs from quant/LLM internals
- report summary extraction fields

## 10. Open review questions

- Should `rating` remain duplicated with `direction`, or should one be derived client-side?
- Should task progress timestamps standardize on RFC 3339 instead of mixed clock-only strings?
- Should historical report APIs return an extracted summary separately from the full markdown?
@@ -0,0 +1,188 @@
# TradingAgents backend migration and rollback notes draft

Status: draft
Audience: backend/application maintainers
Scope: migrate toward the application-service boundary and result-contract-v1alpha1 with rollback safety

## 1. Migration objective

Move backend delivery code from route-local orchestration to an application-service layer without changing quant+LLM merge kernel behavior.

Target outcomes:

- stable result contract (`v1alpha1`)
- thin FastAPI transport
- application-owned task lifecycle and mapping
- rollback-safe migration using dual-read/dual-write where useful

## 2. Current coupling hotspots

Primary hotspot: `web_dashboard/backend/main.py`

It currently combines:

- route handlers
- task persistence
- subprocess creation and monitoring
- progress/stage state mutation
- result projection into API fields
- report export concerns

This file is the first migration target.
## 3. Recommended migration sequence

### Phase 0: contract freeze draft

Deliverables:

- agree on `docs/contracts/result-contract-v1alpha1.md`
- agree on the application boundary in `docs/architecture/application-boundary.md`

Rollback:

- none needed; documentation only

### Phase 1: introduce application service behind existing routes

Actions:

- add backend application modules for analysis status, live signals, and report reads
- keep existing route URLs unchanged
- move mapping logic out of route functions into service/mappers

Compatibility tactic:

- routes still return the current payload shape if the frontend depends on it
- the internal service also emits `v1alpha1` DTOs for verification comparison

Rollback:

- route handlers can call the old inline functions directly via a feature flag or import switch
### Phase 2: dual-read for task status

Why:

Task status currently lives in memory plus `data/task_status/*.json`. During migration, the new service storage and the old persisted shape may diverge.

Recommended strategy:

- read preference: new application store first
- fallback read: legacy JSON task status
- compare key fields during the shadow period: `status`, `progress`, `current_stage`, `decision`, `error`

Rollback:

- switch read preference back to legacy JSON only
- leave the new store populated for debugging, but non-authoritative
### Phase 3: dual-write for task results

Why:

To avoid breaking status pages and historical tooling during rollout.

Recommended strategy:

- authoritative write: new application store
- compatibility write: legacy `app.state.task_results` + `data/task_status/*.json`
- emit diff logs when new-vs-legacy projections disagree

Guardrails:

- dual-write only for application-layer payloads
- do not dual-write alternate domain semantics into `orchestrator/`

Rollback:

- disable new-store writes
- continue legacy writes only
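The diff-log requirement can be a by-product of the dual write itself: write both projections, then report the shared keys that disagree. Store shapes and the function name are assumptions for the sketch.

```python
import logging

logger = logging.getLogger("ta.migration")


def dual_write_result(
    task_id: str,
    new_payload: dict,
    legacy_payload: dict,
    new_store: dict,
    legacy_store: dict,
) -> list[str]:
    """Write both stores; return the shared keys whose projections disagree."""
    new_store[task_id] = new_payload
    legacy_store[task_id] = legacy_payload
    shared = set(new_payload) & set(legacy_payload)
    mismatched = sorted(k for k in shared if new_payload[k] != legacy_payload[k])
    if mismatched:
        logger.warning("new-vs-legacy projection mismatch for %s: %s", task_id, mismatched)
    return mismatched
```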
### Phase 4: websocket and live signal migration

Actions:

- make `/ws/analysis/{task_id}` and `/ws/orchestrator` render application contracts
- keep websocket wrapper fields stable while migrating the internal body shape

Suggested compatibility step:

- send the legacy event envelope with an embedded `contract_version`
- update frontend consumers before removing legacy-only fields

Rollback:

- restore the websocket serializer to the legacy shape
- keep the application service intact behind the adapter

### Phase 5: remove route-local orchestration

Actions:

- delete dead inline task mutation helpers from `main.py`
- keep routes as a thin adapter layer
- preserve report retrieval behavior

Rollback:

- only safe after shadow metrics show parity
- otherwise revert to Phase 3 dual-write mode, not direct deletion
## 4. Suggested feature flags

Environment-variable style examples:

- `TA_APP_SERVICE_ENABLED=1`
- `TA_RESULT_CONTRACT_VERSION=v1alpha1`
- `TA_TASKSTORE_DUAL_READ=1`
- `TA_TASKSTORE_DUAL_WRITE=1`
- `TA_WS_V1ALPHA1_ENABLED=0`

These names are placeholders; exact naming can be chosen during implementation.
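A single helper keeps flag parsing consistent across the cutover points; the accepted truthy spellings here are an assumption, and the flag names remain the placeholders listed above.

```python
import os

_TRUTHY = {"1", "true", "yes", "on"}


def migration_flag(name: str, default: bool = False) -> bool:
    """Read a TA_* migration flag from the environment, with a default for unset."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY
```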
## 5. Verification checkpoints per phase

For each migration phase, verify:

- the same task ids are returned for the same route behavior
- stage transitions remain monotonic
- completed tasks persist `decision`, `confidence`, and degraded-path outcomes
- the failure path still preserves actionable error text
- live websocket payloads preserve ticker/date ordering expectations
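The monotonic-stage checkpoint can be asserted mechanically over a recorded sequence of `current_stage` values, using the canonical stage order from the contract; the helper name is illustrative.

```python
STAGE_ORDER = ("analysts", "research", "trading", "risk", "portfolio")


def stages_monotonic(observed: list[str]) -> bool:
    """True when the observed current_stage values never move backwards."""
    indices = [STAGE_ORDER.index(name) for name in observed]
    return all(a <= b for a, b in zip(indices, indices[1:]))
```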
## 6. Rollback triggers

Rollback immediately if any of these happen:

- task status disappears after a backend restart
- WebSocket clients stop receiving progress updates
- a completed analysis loses `decision` or `confidence` fields
- degraded single-lane signals are reclassified incorrectly
- report export or historical report retrieval cannot find prior artifacts

## 7. Explicit non-goals during migration

- do not rewrite `orchestrator/signals.py` merge math as part of the boundary migration
- do not rework provider/model selection semantics in the same change set
- do not force a frontend redesign before contract shadowing proves parity
- do not implement a new strategy layer inside the application service

## 8. Minimal rollback playbook

If production or local verification fails after migration cutover:

1. disable the application-service read path
2. disable dual-write to the new store if it corrupts parity checks
3. restore legacy route-local serializers
4. keep generated comparison logs/artifacts for diff analysis
5. re-run backend tests and one end-to-end manual analysis flow

## 9. Review checklist

A migration plan is acceptable only if it:

- preserves orchestrator ownership of quant+LLM merge semantics
- introduces feature-flagged cutover points
- supports dual-read/dual-write only at the application/persistence boundary
- provides a one-step rollback path at each release phase
@@ -1,5 +1,7 @@
from dataclasses import dataclass, field

from orchestrator.contracts.config_loader import normalize_orchestrator_fields


@dataclass
class OrchestratorConfig:

@@ -12,3 +14,19 @@ class OrchestratorConfig:
    cache_dir: str = "orchestrator/cache"  # LLM signal cache directory
    llm_solo_penalty: float = 0.7  # confidence discount when only the LLM lane is available
    quant_solo_penalty: float = 0.8  # confidence discount when only the quant lane is available

    def __post_init__(self) -> None:
        normalized = normalize_orchestrator_fields(
            {
                "quant_backtest_path": self.quant_backtest_path,
                "trading_agents_config": self.trading_agents_config,
                "quant_weight_cap": self.quant_weight_cap,
                "llm_weight_cap": self.llm_weight_cap,
                "llm_batch_days": self.llm_batch_days,
                "cache_dir": self.cache_dir,
                "llm_solo_penalty": self.llm_solo_penalty,
                "quant_solo_penalty": self.quant_solo_penalty,
            }
        )
        for key, value in normalized.items():
            setattr(self, key, value)
@@ -0,0 +1,31 @@
from orchestrator.contracts.config_loader import (
    normalize_orchestrator_fields,
    normalize_trading_agents_config,
)
from orchestrator.contracts.config_schema import (
    CONTRACT_VERSION,
    OrchestratorConfigSchema,
    build_orchestrator_schema,
    build_trading_agents_config,
)
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.contracts.result_contract import (
    FinalSignal,
    Signal,
    build_error_signal,
    signal_reason_code,
)

__all__ = [
    "CONTRACT_VERSION",
    "FinalSignal",
    "OrchestratorConfigSchema",
    "ReasonCode",
    "Signal",
    "build_error_signal",
    "build_orchestrator_schema",
    "build_trading_agents_config",
    "normalize_orchestrator_fields",
    "normalize_trading_agents_config",
    "signal_reason_code",
]
@@ -0,0 +1,18 @@
from __future__ import annotations

from typing import Any, Mapping, Optional

from orchestrator.contracts.config_schema import (
    build_orchestrator_schema,
    build_trading_agents_config,
)


def normalize_trading_agents_config(
    config: Optional[Mapping[str, Any]],
) -> dict[str, Any]:
    return dict(build_trading_agents_config(config))


def normalize_orchestrator_fields(raw: Mapping[str, Any]) -> dict[str, Any]:
    return build_orchestrator_schema(raw).to_runtime_fields()
@@ -0,0 +1,168 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Mapping, Optional, TypedDict, cast

from tradingagents.default_config import get_default_config


CONTRACT_VERSION = "v1alpha1"


class TradingAgentsConfigPayload(TypedDict, total=False):
    project_dir: str
    results_dir: str
    data_cache_dir: str
    llm_provider: str
    deep_think_llm: str
    quick_think_llm: str
    backend_url: str
    google_thinking_level: Optional[str]
    openai_reasoning_effort: Optional[str]
    anthropic_effort: Optional[str]
    output_language: str
    max_debate_rounds: int
    max_risk_discuss_rounds: int
    max_recur_limit: int
    data_vendors: dict[str, str]
    tool_vendors: dict[str, str]
    selected_analysts: list[str]
    llm_timeout: float
    llm_max_retries: int
    timeout: float
    max_retries: int
    use_responses_api: bool


REQUIRED_TRADING_CONFIG_KEYS = (
    "project_dir",
    "results_dir",
    "data_cache_dir",
    "llm_provider",
    "deep_think_llm",
    "quick_think_llm",
)


def _validate_probability(name: str, value: Any) -> float:
    if not isinstance(value, (int, float)):
        raise TypeError(f"{name} must be a number")
    if not 0.0 <= float(value) <= 1.0:
        raise ValueError(f"{name} must be between 0.0 and 1.0")
    return float(value)


def _validate_positive_int(name: str, value: Any) -> int:
    if not isinstance(value, int):
        raise TypeError(f"{name} must be an int")
    if value <= 0:
        raise ValueError(f"{name} must be > 0")
    return value


def _validate_string_map(name: str, value: Any) -> dict[str, str]:
    if not isinstance(value, Mapping):
        raise TypeError(f"{name} must be a mapping")
    normalized = {}
    for key, item in value.items():
        if not isinstance(key, str) or not isinstance(item, str):
            raise TypeError(f"{name} keys and values must be strings")
        normalized[key] = item
    return normalized


def build_trading_agents_config(
    overrides: Optional[Mapping[str, Any]],
) -> TradingAgentsConfigPayload:
    merged: dict[str, Any] = get_default_config()

    if overrides:
        if not isinstance(overrides, Mapping):
            raise TypeError("trading_agents_config must be a mapping")
        for key, value in overrides.items():
            if (
                key in ("data_vendors", "tool_vendors")
                and value is not None
            ):
                merged[key] = _validate_string_map(key, value)
            elif key == "selected_analysts" and value is not None:
                if not isinstance(value, list) or any(
                    not isinstance(item, str) for item in value
                ):
                    raise TypeError("selected_analysts must be a list of strings")
                merged[key] = list(value)
            else:
                merged[key] = value

    for key in REQUIRED_TRADING_CONFIG_KEYS:
        value = merged.get(key)
        if not isinstance(value, str) or not value.strip():
            raise ValueError(f"trading_agents_config.{key} must be a non-empty string")

    merged["data_vendors"] = _validate_string_map("data_vendors", merged["data_vendors"])
    merged["tool_vendors"] = _validate_string_map("tool_vendors", merged["tool_vendors"])

    return cast(TradingAgentsConfigPayload, merged)


@dataclass(frozen=True)
class OrchestratorConfigSchema:
    quant_backtest_path: str = ""
    trading_agents_config: TradingAgentsConfigPayload = field(
        default_factory=lambda: build_trading_agents_config(None)
    )
    quant_weight_cap: float = 0.8
    llm_weight_cap: float = 0.9
    llm_batch_days: int = 7
    cache_dir: str = "orchestrator/cache"
    llm_solo_penalty: float = 0.7
    quant_solo_penalty: float = 0.8
    contract_version: str = CONTRACT_VERSION

    def to_runtime_fields(self) -> dict[str, Any]:
        return {
            "quant_backtest_path": self.quant_backtest_path,
            "trading_agents_config": dict(self.trading_agents_config),
            "quant_weight_cap": self.quant_weight_cap,
            "llm_weight_cap": self.llm_weight_cap,
            "llm_batch_days": self.llm_batch_days,
            "cache_dir": self.cache_dir,
            "llm_solo_penalty": self.llm_solo_penalty,
            "quant_solo_penalty": self.quant_solo_penalty,
        }


def build_orchestrator_schema(raw: Mapping[str, Any]) -> OrchestratorConfigSchema:
    if not isinstance(raw, Mapping):
        raise TypeError("orchestrator config must be a mapping")

    quant_backtest_path = raw.get("quant_backtest_path", "")
    if not isinstance(quant_backtest_path, str):
        raise TypeError("quant_backtest_path must be a string")

    cache_dir = raw.get("cache_dir", "orchestrator/cache")
    if not isinstance(cache_dir, str) or not cache_dir.strip():
        raise ValueError("cache_dir must be a non-empty string")

    return OrchestratorConfigSchema(
        quant_backtest_path=quant_backtest_path,
        trading_agents_config=build_trading_agents_config(
            cast(Optional[Mapping[str, Any]], raw.get("trading_agents_config"))
        ),
        quant_weight_cap=_validate_probability(
            "quant_weight_cap", raw.get("quant_weight_cap", 0.8)
        ),
        llm_weight_cap=_validate_probability(
            "llm_weight_cap", raw.get("llm_weight_cap", 0.9)
        ),
        llm_batch_days=_validate_positive_int(
            "llm_batch_days", raw.get("llm_batch_days", 7)
        ),
        cache_dir=cache_dir,
        llm_solo_penalty=_validate_probability(
            "llm_solo_penalty", raw.get("llm_solo_penalty", 0.7)
        ),
        quant_solo_penalty=_validate_probability(
            "quant_solo_penalty", raw.get("quant_solo_penalty", 0.8)
        ),
    )
@@ -0,0 +1,19 @@
from enum import Enum


class ReasonCode(str, Enum):
    CONFIG_INVALID = "config_invalid"
    QUANT_NOT_CONFIGURED = "quant_not_configured"
    QUANT_INIT_FAILED = "quant_init_failed"
    QUANT_SIGNAL_FAILED = "quant_signal_failed"
    QUANT_NO_DATA = "quant_no_data"
    LLM_INIT_FAILED = "llm_init_failed"
    LLM_SIGNAL_FAILED = "llm_signal_failed"
    LLM_UNKNOWN_RATING = "llm_unknown_rating"
    BOTH_SIGNALS_UNAVAILABLE = "both_signals_unavailable"


def reason_code_value(value: "ReasonCode | str") -> str:
    if isinstance(value, ReasonCode):
        return value.value
    return value
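Mixing `str` into the enum is what lets reason codes flow through metadata dicts and JSON payloads without a custom encoder. A quick standalone demonstration of the behavior the `reason_code_value` helper normalizes (enum trimmed to two members for brevity):

```python
import json
from enum import Enum


class ReasonCode(str, Enum):
    QUANT_NO_DATA = "quant_no_data"
    LLM_SIGNAL_FAILED = "llm_signal_failed"


def reason_code_value(value):
    # Accept either an enum member or an already-plain string.
    return value.value if isinstance(value, ReasonCode) else value


# str-backed members compare equal to their values, so degraded-signal
# metadata can carry either form and still serialize cleanly.
encoded = json.dumps({"reason_code": reason_code_value(ReasonCode.QUANT_NO_DATA)})
```

On Python 3.11+ the same shape is available as `enum.StrEnum`; the `str, Enum` mixin keeps the module compatible with older interpreters.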
@@ -0,0 +1,99 @@
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Optional

from orchestrator.contracts.config_schema import CONTRACT_VERSION
from orchestrator.contracts.error_taxonomy import reason_code_value


def _normalize_metadata(
    metadata: Optional[dict[str, Any]],
    *,
    reason_code: Optional[str] = None,
) -> dict[str, Any]:
    normalized = dict(metadata or {})
    normalized.setdefault("contract_version", CONTRACT_VERSION)
    if reason_code:
        normalized.setdefault("reason_code", reason_code)
    return normalized


@dataclass
class Signal:
    ticker: str
    direction: int
    confidence: float
    source: str
    timestamp: datetime
    metadata: dict[str, Any] = field(default_factory=dict)
    contract_version: str = CONTRACT_VERSION
    reason_code: Optional[str] = None

    def __post_init__(self) -> None:
        if self.reason_code is not None:
            self.reason_code = reason_code_value(self.reason_code)
        self.metadata = _normalize_metadata(self.metadata, reason_code=self.reason_code)
        self.reason_code = self.reason_code or self.metadata.get("reason_code")
        self.metadata.setdefault("source", self.source)

    @property
    def degraded(self) -> bool:
        return self.reason_code is not None or bool(self.metadata.get("error"))


@dataclass
class FinalSignal:
    ticker: str
    direction: int
    confidence: float
    quant_signal: Optional[Signal]
    llm_signal: Optional[Signal]
    timestamp: datetime
    degrade_reason_codes: tuple[str, ...] = ()
    metadata: dict[str, Any] = field(default_factory=dict)
    contract_version: str = CONTRACT_VERSION

    def __post_init__(self) -> None:
        self.degrade_reason_codes = tuple(
            dict.fromkeys(code for code in self.degrade_reason_codes if code)
        )
        self.metadata = _normalize_metadata(self.metadata)
        if self.degrade_reason_codes:
            self.metadata.setdefault(
                "degrade_reason_codes",
                list(self.degrade_reason_codes),
            )

    @property
    def degraded(self) -> bool:
        return bool(self.degrade_reason_codes)


def build_error_signal(
    *,
    ticker: str,
    source: str,
    reason_code: str,
    message: str,
    metadata: Optional[dict[str, Any]] = None,
    timestamp: Optional[datetime] = None,
) -> Signal:
    payload = dict(metadata or {})
    payload["error"] = message
    return Signal(
        ticker=ticker,
        direction=0,
        confidence=0.0,
        source=source,
        timestamp=timestamp or datetime.now(timezone.utc),
        metadata=payload,
        reason_code=reason_code,
    )


def signal_reason_code(signal: Optional[Signal]) -> Optional[str]:
    if signal is None:
        return None
    return signal.reason_code or signal.metadata.get("reason_code")
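The `dict.fromkeys` expression in `FinalSignal.__post_init__` is an order-preserving de-duplication: repeated and falsy reason codes are dropped while first-seen order survives. The same normalization, sketched outside the dataclass:

```python
def dedupe_codes(codes):
    # dict.fromkeys keeps first-seen insertion order (guaranteed since
    # Python 3.7) while the generator filters out empty/None entries.
    return tuple(dict.fromkeys(code for code in codes if code))


codes = dedupe_codes(["llm_init_failed", "", "quant_no_data", "llm_init_failed", None])
```

Compared with `tuple(set(...))`, this keeps the codes in the order the orchestrator appended them, which matters when the first code explains the root cause.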
@@ -4,7 +4,8 @@ import os
from datetime import datetime, timezone

from orchestrator.config import OrchestratorConfig
from orchestrator.signals import Signal
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.contracts.result_contract import Signal, build_error_signal

logger = logging.getLogger(__name__)

@@ -21,7 +22,10 @@ class LLMRunner:
        if self._graph is None:
            from tradingagents.graph.trading_graph import TradingAgentsGraph
            trading_cfg = self._config.trading_agents_config if self._config.trading_agents_config else None
            self._graph = TradingAgentsGraph(config=trading_cfg)
            graph_kwargs = {"config": trading_cfg}
            if trading_cfg and "selected_analysts" in trading_cfg:
                graph_kwargs["selected_analysts"] = trading_cfg["selected_analysts"]
            self._graph = TradingAgentsGraph(**graph_kwargs)
        return self._graph

    def get_signal(self, ticker: str, date: str) -> Signal:
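The membership test above is deliberate: forwarding `selected_analysts` only when the key is present means an explicit empty list reaches the graph instead of silently falling back to the default analyst roster (the behavior the `test_get_graph_preserves_explicit_empty_selected_analysts` test later pins down). A standalone sketch of that key-presence idiom, with an illustrative function name:

```python
from typing import Any, Optional


def build_graph_kwargs(config: Optional[dict]) -> dict[str, Any]:
    # `"selected_analysts" in config` distinguishes "caller chose []"
    # from "caller never set the key"; a truthiness check on the value
    # would collapse both cases into the default roster.
    kwargs: dict[str, Any] = {"config": config}
    if config and "selected_analysts" in config:
        kwargs["selected_analysts"] = config["selected_analysts"]
    return kwargs
```

The same pattern generalizes to any optional keyword whose falsy value is still meaningful.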
@@ -70,13 +74,11 @@ class LLMRunner:
            )
        except Exception as e:
            logger.error("LLMRunner: propagate failed for %s %s: %s", ticker, date, e)
            return Signal(
            return build_error_signal(
                ticker=ticker,
                direction=0,
                confidence=0.0,
                source="llm",
                timestamp=datetime.now(timezone.utc),
                metadata={"error": str(e)},
                reason_code=ReasonCode.LLM_SIGNAL_FAILED.value,
                message=str(e),
            )

    def _map_rating(self, rating: str) -> tuple[int, float]:

@@ -1,8 +1,9 @@
import logging
from datetime import datetime, timezone
from typing import Optional

from orchestrator.config import OrchestratorConfig
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.contracts.result_contract import FinalSignal, Signal, signal_reason_code
from orchestrator.signals import Signal, FinalSignal, SignalMerger
from orchestrator.quant_runner import QuantRunner
from orchestrator.llm_runner import LLMRunner

@@ -16,6 +17,8 @@ class TradingOrchestrator:
        self._merger = SignalMerger(config)
        self._quant: Optional[QuantRunner] = None
        self._llm: Optional[LLMRunner] = None
        self._quant_unavailable_reason: Optional[str] = None
        self._llm_unavailable_reason: Optional[str] = None

        # Initialize runners (quant requires quant_backtest_path)
        if config.quant_backtest_path:
@@ -23,8 +26,15 @@ class TradingOrchestrator:
                self._quant = QuantRunner(config)
            except Exception as e:
                logger.warning("TradingOrchestrator: QuantRunner init failed: %s", e)
                self._quant_unavailable_reason = ReasonCode.QUANT_INIT_FAILED.value
        else:
            self._quant_unavailable_reason = ReasonCode.QUANT_NOT_CONFIGURED.value

        self._llm = LLMRunner(config)
        try:
            self._llm = LLMRunner(config)
        except Exception as e:
            logger.warning("TradingOrchestrator: LLMRunner init failed: %s", e)
            self._llm_unavailable_reason = ReasonCode.LLM_INIT_FAILED.value

    def get_combined_signal(self, ticker: str, date: str) -> FinalSignal:
        """
@@ -36,28 +46,48 @@ class TradingOrchestrator:
        """
        quant_sig: Optional[Signal] = None
        llm_sig: Optional[Signal] = None
        degradation_reasons: list[str] = []

        if self._quant is None and self._quant_unavailable_reason:
            degradation_reasons.append(self._quant_unavailable_reason)
        if self._llm is None and self._llm_unavailable_reason:
            degradation_reasons.append(self._llm_unavailable_reason)

        # Get quant signal
        if self._quant is not None:
            try:
                quant_sig = self._quant.get_signal(ticker, date)
                # Treat error signals (confidence=0, direction=0 with error metadata) as None
                if quant_sig.metadata.get("error") or quant_sig.metadata.get("reason") == "no_data":
                if quant_sig.degraded:
                    degradation_reasons.append(
                        signal_reason_code(quant_sig) or ReasonCode.QUANT_SIGNAL_FAILED.value
                    )
                    logger.warning("TradingOrchestrator: quant signal degraded for %s %s", ticker, date)
                    quant_sig = None
            except Exception as e:
                logger.error("TradingOrchestrator: quant get_signal failed: %s", e)
                degradation_reasons.append(ReasonCode.QUANT_SIGNAL_FAILED.value)
                quant_sig = None

        # Get llm signal
        try:
            llm_sig = self._llm.get_signal(ticker, date)
            if llm_sig.metadata.get("error"):
                logger.warning("TradingOrchestrator: llm signal degraded for %s %s", ticker, date)
        if self._llm is not None:
            try:
                llm_sig = self._llm.get_signal(ticker, date)
                if llm_sig.degraded:
                    degradation_reasons.append(
                        signal_reason_code(llm_sig) or ReasonCode.LLM_SIGNAL_FAILED.value
                    )
                    logger.warning("TradingOrchestrator: llm signal degraded for %s %s", ticker, date)
                    llm_sig = None
            except Exception as e:
                logger.error("TradingOrchestrator: llm get_signal failed: %s", e)
                degradation_reasons.append(ReasonCode.LLM_SIGNAL_FAILED.value)
                llm_sig = None
        except Exception as e:
            logger.error("TradingOrchestrator: llm get_signal failed: %s", e)
            llm_sig = None

        # merge raises ValueError if both None
        return self._merger.merge(quant_sig, llm_sig)
        if quant_sig is None and llm_sig is None:
            degradation_reasons.append(ReasonCode.BOTH_SIGNALS_UNAVAILABLE.value)
        return self._merger.merge(
            quant_sig,
            llm_sig,
            degradation_reasons=degradation_reasons,
        )
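The guard-and-collect pattern above reduces each runner to "usable signal, or a reason code": a missing runner, a raised exception, and a degraded payload all contribute to `degradation_reasons` instead of aborting the whole request. A compact standalone sketch of that control flow with the runners stubbed out; the helper name `collect` and dict-shaped signals are illustrative, not the repository's types:

```python
from typing import Callable, Optional


def collect(source: str, fetch: Optional[Callable[[], dict]], reasons: list[str]) -> Optional[dict]:
    # Missing runner -> reason code, no signal.
    if fetch is None:
        reasons.append(f"{source}_not_configured")
        return None
    try:
        signal = fetch()
    except Exception:
        # Raised failures degrade instead of propagating.
        reasons.append(f"{source}_signal_failed")
        return None
    if signal.get("error"):
        # Error payloads are treated the same as hard failures.
        reasons.append(f"{source}_signal_failed")
        return None
    return signal


reasons: list[str] = []
quant = collect("quant", None, reasons)
llm = collect("llm", lambda: {"direction": -1, "confidence": 0.9}, reasons)
```

The merger then only has to handle three shapes (both, one, neither) plus an ordered list of why anything is missing.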
@@ -8,7 +8,8 @@ from typing import Any
import yfinance as yf

from orchestrator.config import OrchestratorConfig
from orchestrator.signals import Signal
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.contracts.result_contract import Signal, build_error_signal

logger = logging.getLogger(__name__)

@@ -41,13 +42,12 @@ class QuantRunner:
        df = yf.download(ticker, start=start_str, end=date, progress=False, auto_adjust=True)
        if df.empty:
            logger.warning("No price data for %s between %s and %s", ticker, start_str, date)
            return Signal(
            return build_error_signal(
                ticker=ticker,
                direction=0,
                confidence=0.0,
                source="quant",
                timestamp=datetime.now(timezone.utc),
                metadata={"reason": "no_data"},
                reason_code=ReasonCode.QUANT_NO_DATA.value,
                message=f"no price data between {start_str} and {date}",
                metadata={"start_date": start_str, "end_date": date},
            )

        # Normalize column names to lowercase

@@ -1,33 +1,13 @@
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional

from orchestrator.config import OrchestratorConfig
from orchestrator.contracts.result_contract import FinalSignal, Signal

logger = logging.getLogger(__name__)


@dataclass
class Signal:
    ticker: str
    direction: int  # +1 buy, -1 sell, 0 hold
    confidence: float  # 0.0 ~ 1.0
    source: str  # "quant" | "llm"
    timestamp: datetime
    metadata: dict = field(default_factory=dict)  # raw output, kept for debugging


@dataclass
class FinalSignal:
    ticker: str
    direction: int  # sign(quant_dir×quant_conf + llm_dir×llm_conf); sign(0) → 0 (HOLD)
    confidence: float  # abs(weighted_sum) / total_conf
    quant_signal: Optional[Signal]
    llm_signal: Optional[Signal]
    timestamp: datetime


def _sign(x: float) -> int:
    """Return +1, -1, or 0."""
    if x > 0:

@@ -41,8 +21,14 @@ class SignalMerger:
    def __init__(self, config: OrchestratorConfig) -> None:
        self._config = config

    def merge(self, quant: Optional[Signal], llm: Optional[Signal]) -> FinalSignal:
    def merge(
        self,
        quant: Optional[Signal],
        llm: Optional[Signal],
        degradation_reasons: Optional[list[str]] = None,
    ) -> FinalSignal:
        now = datetime.now(timezone.utc)
        reasons = tuple(dict.fromkeys(code for code in (degradation_reasons or []) if code))

        # Both sources failed
        if quant is None and llm is None:

@@ -60,6 +46,7 @@ class SignalMerger:
            quant_signal=None,
            llm_signal=llm,
            timestamp=now,
            degrade_reason_codes=reasons,
        )

        # LLM only succeeded; quant failed

@@ -72,6 +59,7 @@ class SignalMerger:
            quant_signal=quant,
            llm_signal=None,
            timestamp=now,
            degrade_reason_codes=reasons,
        )

        # Both present: weighted merge

@@ -98,4 +86,5 @@ class SignalMerger:
            quant_signal=quant,
            llm_signal=llm,
            timestamp=now,
            degrade_reason_codes=reasons,
        )
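Per the `FinalSignal` field comments, the merged direction is the sign of the confidence-weighted vote and the merged confidence is the normalized magnitude of that vote. A standalone sketch of just that arithmetic (the solo penalties and weight caps applied by `SignalMerger` are omitted here):

```python
def merge_directions(q_dir: int, q_conf: float, l_dir: int, l_conf: float) -> tuple[int, float]:
    # direction = sign(quant_dir*quant_conf + llm_dir*llm_conf); sign(0) -> 0 (HOLD)
    weighted = q_dir * q_conf + l_dir * l_conf
    total = q_conf + l_conf
    direction = (weighted > 0) - (weighted < 0)  # branchless sign()
    # confidence = abs(weighted_sum) / total_conf; disagreement shrinks it.
    confidence = abs(weighted) / total if total else 0.0
    return direction, confidence


# Quant says buy at 0.8, LLM says sell at 0.9: the net vote is weakly short.
direction, confidence = merge_directions(1, 0.8, -1, 0.9)
```

Note the disagreement case: opposing signals nearly cancel, so confidence collapses toward zero even though both inputs were individually confident.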
@@ -0,0 +1,113 @@
from datetime import datetime, timezone

import pytest

import orchestrator.orchestrator as orchestrator_module
from orchestrator.config import OrchestratorConfig
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.signals import Signal


def _signal(
    source: str,
    *,
    direction: int,
    confidence: float,
    metadata: dict | None = None,
    reason_code: str | None = None,
) -> Signal:
    return Signal(
        ticker="AAPL",
        direction=direction,
        confidence=confidence,
        source=source,
        timestamp=datetime.now(timezone.utc),
        metadata=metadata or {},
        reason_code=reason_code,
    )


def test_trading_orchestrator_degrades_to_llm_only_when_quant_has_error(monkeypatch):
    class FakeQuantRunner:
        def __init__(self, _config):
            pass

        def get_signal(self, _ticker, _date):
            return _signal("quant", direction=1, confidence=0.8, metadata={"error": "db unavailable"})

    class FakeLLMRunner:
        def __init__(self, _config):
            pass

        def get_signal(self, _ticker, _date):
            return _signal("llm", direction=-1, confidence=0.9)

    monkeypatch.setattr(orchestrator_module, "QuantRunner", FakeQuantRunner)
    monkeypatch.setattr(orchestrator_module, "LLMRunner", FakeLLMRunner)

    result = orchestrator_module.TradingOrchestrator(
        OrchestratorConfig(quant_backtest_path="/tmp/quant")
    ).get_combined_signal("AAPL", "2026-04-11")

    assert result.direction == -1
    assert result.quant_signal is None
    assert result.llm_signal is not None
    assert result.llm_signal.source == "llm"


def test_trading_orchestrator_degrades_to_quant_only_when_llm_has_error(monkeypatch):
    class FakeQuantRunner:
        def __init__(self, _config):
            pass

        def get_signal(self, _ticker, _date):
            return _signal("quant", direction=1, confidence=0.8)

    class FakeLLMRunner:
        def __init__(self, _config):
            pass

        def get_signal(self, _ticker, _date):
            return _signal("llm", direction=0, confidence=0.0, metadata={"error": "timeout"})

    monkeypatch.setattr(orchestrator_module, "QuantRunner", FakeQuantRunner)
    monkeypatch.setattr(orchestrator_module, "LLMRunner", FakeLLMRunner)

    result = orchestrator_module.TradingOrchestrator(
        OrchestratorConfig(quant_backtest_path="/tmp/quant")
    ).get_combined_signal("AAPL", "2026-04-11")

    assert result.direction == 1
    assert result.quant_signal is not None
    assert result.quant_signal.source == "quant"
    assert result.llm_signal is None


def test_trading_orchestrator_raises_when_both_sources_degrade(monkeypatch):
    class FakeQuantRunner:
        def __init__(self, _config):
            pass

        def get_signal(self, _ticker, _date):
            return _signal(
                "quant",
                direction=0,
                confidence=0.0,
                metadata={"error": "no data"},
                reason_code=ReasonCode.QUANT_NO_DATA.value,
            )

    class FakeLLMRunner:
        def __init__(self, _config):
            pass

        def get_signal(self, _ticker, _date):
            return _signal("llm", direction=0, confidence=0.0, metadata={"error": "timeout"})

    monkeypatch.setattr(orchestrator_module, "QuantRunner", FakeQuantRunner)
    monkeypatch.setattr(orchestrator_module, "LLMRunner", FakeLLMRunner)

    with pytest.raises(ValueError, match="both quant and llm signals are None"):
        orchestrator_module.TradingOrchestrator(
            OrchestratorConfig(quant_backtest_path="/tmp/quant")
        ).get_combined_signal("AAPL", "2026-04-11")
@@ -0,0 +1,52 @@
from datetime import datetime
from pathlib import Path

from orchestrator.config import OrchestratorConfig
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.llm_runner import LLMRunner


class _SuccessfulGraph:
    def propagate(self, ticker: str, date: str):
        return {"ticker": ticker, "date": date}, "BUY"


class _FailingGraph:
    def propagate(self, _ticker: str, _date: str):
        raise RuntimeError("graph offline")


def test_llm_runner_persists_result_contract_v1alpha1(monkeypatch, tmp_path):
    runner = LLMRunner(OrchestratorConfig(cache_dir=str(tmp_path)))
    monkeypatch.setattr(runner, "_get_graph", lambda: _SuccessfulGraph())

    signal = runner.get_signal("BRK/B", "2026-04-11")

    assert signal.ticker == "BRK/B"
    assert signal.direction == 1
    assert signal.confidence == 0.9
    assert signal.source == "llm"
    assert signal.metadata["rating"] == "BUY"
    assert signal.metadata["ticker"] == "BRK/B"
    assert signal.metadata["date"] == "2026-04-11"
    assert datetime.fromisoformat(signal.metadata["timestamp"])

    cache_path = Path(tmp_path) / "BRK_B_2026-04-11.json"
    assert cache_path.exists()


def test_llm_runner_returns_error_contract_when_graph_fails(monkeypatch, tmp_path):
    runner = LLMRunner(OrchestratorConfig(cache_dir=str(tmp_path)))
    monkeypatch.setattr(runner, "_get_graph", lambda: _FailingGraph())

    signal = runner.get_signal("AAPL", "2026-04-11")

    assert signal.ticker == "AAPL"
    assert signal.direction == 0
    assert signal.confidence == 0.0
    assert signal.source == "llm"
    assert signal.metadata["error"] == "graph offline"
    assert signal.metadata["reason_code"] == ReasonCode.LLM_SIGNAL_FAILED.value
    assert signal.metadata["contract_version"]
    assert signal.metadata["source"] == "llm"
    assert not (Path(tmp_path) / "AAPL_2026-04-11.json").exists()
@@ -0,0 +1,59 @@
import tradingagents.agents.analysts.fundamentals_analyst as fundamentals_module
from types import SimpleNamespace

import pytest


class _FakePrompt:
    def __init__(self):
        self.partials = {}

    def partial(self, **kwargs):
        self.partials.update(kwargs)
        return self

    def __or__(self, _other):
        return _FakeChain(self)


class _FakeChain:
    def __init__(self, prompt):
        self.prompt = prompt

    def invoke(self, _messages):
        return SimpleNamespace(tool_calls=[], content=self.prompt.partials["system_message"])


class _FakePromptTemplate:
    last_prompt = None

    @classmethod
    def from_messages(cls, _messages):
        cls.last_prompt = _FakePrompt()
        return cls.last_prompt


class _FakeLLM:
    def bind_tools(self, _tools):
        return self


@pytest.mark.parametrize("compact_mode", [True, False])
def test_fundamentals_system_message_is_string(monkeypatch, compact_mode):
    monkeypatch.setattr(fundamentals_module, "ChatPromptTemplate", _FakePromptTemplate)
    monkeypatch.setattr(fundamentals_module, "use_compact_analysis_prompt", lambda: compact_mode)
    monkeypatch.setattr(fundamentals_module, "get_language_instruction", lambda: "")

    node = fundamentals_module.create_fundamentals_analyst(_FakeLLM())
    result = node(
        {
            "trade_date": "2026-04-11",
            "company_of_interest": "600519.SS",
            "messages": [],
        }
    )

    system_message = _FakePromptTemplate.last_prompt.partials["system_message"]

    assert isinstance(system_message, str)
    assert result["fundamentals_report"] == system_message
@@ -1,8 +1,11 @@
"""Tests for LLMRunner._map_rating()."""
import tempfile
"""Tests for LLMRunner."""
import sys
from types import ModuleType

import pytest

from orchestrator.config import OrchestratorConfig
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.llm_runner import LLMRunner


@@ -39,3 +42,43 @@ def test_map_rating_lowercase(runner):
# Empty string → (0, 0.5)
def test_map_rating_empty_string(runner):
    assert runner._map_rating("") == (0, 0.5)


def test_get_graph_preserves_explicit_empty_selected_analysts(monkeypatch, tmp_path):
    captured_kwargs = {}

    class FakeTradingAgentsGraph:
        def __init__(self, **kwargs):
            captured_kwargs.update(kwargs)

    fake_module = ModuleType("tradingagents.graph.trading_graph")
    fake_module.TradingAgentsGraph = FakeTradingAgentsGraph
    monkeypatch.setitem(sys.modules, "tradingagents.graph.trading_graph", fake_module)

    cfg = OrchestratorConfig(
        cache_dir=str(tmp_path),
        trading_agents_config={"selected_analysts": [], "llm_provider": "anthropic"},
    )

    runner = LLMRunner(cfg)
    graph = runner._get_graph()

    assert isinstance(graph, FakeTradingAgentsGraph)
    assert captured_kwargs["config"] == cfg.trading_agents_config
    assert captured_kwargs["selected_analysts"] == []


def test_get_signal_returns_reason_code_on_propagate_failure(monkeypatch, tmp_path):
    class BrokenGraph:
        def propagate(self, ticker, date):
            raise RuntimeError("graph unavailable")

    cfg = OrchestratorConfig(cache_dir=str(tmp_path))
    runner = LLMRunner(cfg)
    monkeypatch.setattr(runner, "_get_graph", lambda: BrokenGraph())

    signal = runner.get_signal("AAPL", "2024-01-02")

    assert signal.degraded is True
    assert signal.reason_code == ReasonCode.LLM_SIGNAL_FAILED.value
    assert signal.metadata["error"] == "graph unavailable"

@@ -0,0 +1,95 @@
import importlib.util
import sys
from pathlib import Path
from types import ModuleType

import pytest


def _load_factory_module(monkeypatch):
    package_name = "_lane4_factory_testpkg"
    package = ModuleType(package_name)
    package.__path__ = []
    monkeypatch.setitem(sys.modules, package_name, package)

    base_module = ModuleType(f"{package_name}.base_client")

    class BaseLLMClient:
        pass

    base_module.BaseLLMClient = BaseLLMClient
    monkeypatch.setitem(sys.modules, f"{package_name}.base_client", base_module)

    calls = []

    def _register_client(module_suffix: str, class_name: str):
        module = ModuleType(f"{package_name}.{module_suffix}")

        class Client:
            def __init__(self, *args, **kwargs):
                self.args = args
                self.kwargs = kwargs
                calls.append((class_name, args, kwargs))

        setattr(module, class_name, Client)
        monkeypatch.setitem(sys.modules, module.__name__, module)

    _register_client("openai_client", "OpenAIClient")
    _register_client("anthropic_client", "AnthropicClient")
    _register_client("google_client", "GoogleClient")

    factory_path = (
        Path(__file__).resolve().parents[2]
        / "tradingagents"
        / "llm_clients"
        / "factory.py"
    )
    spec = importlib.util.spec_from_file_location(f"{package_name}.factory", factory_path)
    module = importlib.util.module_from_spec(spec)
    monkeypatch.setitem(sys.modules, spec.name, module)
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module, calls


@pytest.mark.parametrize(
    ("provider", "expected_class", "expected_provider"),
    [
        ("openai", "OpenAIClient", "openai"),
        ("OpenRouter", "OpenAIClient", "openrouter"),
        ("ollama", "OpenAIClient", "ollama"),
        ("xai", "OpenAIClient", "xai"),
        ("anthropic", "AnthropicClient", None),
        ("google", "GoogleClient", None),
    ],
)
def test_create_llm_client_routes_provider_to_expected_adapter(
    monkeypatch,
    provider,
    expected_class,
    expected_provider,
):
    factory_module, calls = _load_factory_module(monkeypatch)

    client = factory_module.create_llm_client(
        provider=provider,
        model="demo-model",
        base_url="https://example.test",
        timeout=30,
    )

    assert client is not None
    assert calls[-1][0] == expected_class
    assert calls[-1][1] == ("demo-model", "https://example.test")
    if expected_provider is None:
        assert "provider" not in calls[-1][2]
    else:
        assert calls[-1][2]["provider"] == expected_provider
    assert calls[-1][2]["timeout"] == 30


def test_create_llm_client_rejects_unsupported_provider(monkeypatch):
    factory_module, _calls = _load_factory_module(monkeypatch)

    with pytest.raises(ValueError, match="Unsupported LLM provider"):
        factory_module.create_llm_client("unknown", "demo-model")
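The parametrized cases above imply a factory that lowercases the provider name, funnels all OpenAI-compatible providers through one client class with a `provider` kwarg, and rejects anything else. A standalone sketch of that routing; the real `create_llm_client` lives in `tradingagents/llm_clients/factory.py`, so the function and return shape here are stand-ins for illustration:

```python
OPENAI_COMPATIBLE = {"openai", "openrouter", "ollama", "xai"}


def route_provider(provider: str) -> tuple[str, dict]:
    # Returns (client_kind, extra_kwargs), mirroring what the tests assert:
    # case-insensitive matching, one adapter per protocol family.
    normalized = provider.lower()
    if normalized in OPENAI_COMPATIBLE:
        return "OpenAIClient", {"provider": normalized}
    if normalized == "anthropic":
        return "AnthropicClient", {}
    if normalized == "google":
        return "GoogleClient", {}
    raise ValueError(f"Unsupported LLM provider: {provider}")


kind, extra = route_provider("OpenRouter")
```

Keeping the routing table in one place is what makes the adapter seam testable without touching any live provider.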
@@ -1,11 +1,10 @@
"""Tests for QuantRunner._calc_confidence()."""
import json
import sqlite3
import tempfile
import os
import pytest

from orchestrator.config import OrchestratorConfig
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.quant_runner import QuantRunner


@@ -63,3 +62,15 @@ def test_calc_confidence_clamped_above(runner):
def test_calc_confidence_clamped_below(runner):
    result = runner._calc_confidence(-1.0, 2.0)
    assert result == pytest.approx(0.0)


def test_get_signal_returns_reason_code_when_no_data(runner, monkeypatch):
    monkeypatch.setattr(
        "orchestrator.quant_runner.yf.download",
        lambda *args, **kwargs: type("EmptyFrame", (), {"empty": True})(),
    )

    signal = runner.get_signal("AAPL", "2024-01-02")

    assert signal.degraded is True
    assert signal.reason_code == ReasonCode.QUANT_NO_DATA.value

@@ -54,12 +54,13 @@ def test_merge_quant_only_capped(merger):
def test_merge_llm_only(merger):
    cfg = OrchestratorConfig()
    l = _make_signal(direction=-1, confidence=0.9, source="llm")
    result = merger.merge(None, l)
    result = merger.merge(None, l, degradation_reasons=["quant_signal_failed"])
    assert result.direction == -1
    expected_conf = min(0.9 * cfg.llm_solo_penalty, cfg.llm_weight_cap)
    assert math.isclose(result.confidence, expected_conf)
    assert result.llm_signal is l
    assert result.quant_signal is None
    assert result.degrade_reason_codes == ("quant_signal_failed",)


def test_merge_llm_only_capped(merger):

@@ -0,0 +1,29 @@
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import _merge_with_default_config


def test_merge_with_default_config_keeps_required_defaults():
    merged = _merge_with_default_config({
        "llm_provider": "anthropic",
        "backend_url": "https://example.com/api",
    })

    assert merged["llm_provider"] == "anthropic"
    assert merged["backend_url"] == "https://example.com/api"
    assert merged["project_dir"] == DEFAULT_CONFIG["project_dir"]
    assert merged["results_dir"] == DEFAULT_CONFIG["results_dir"]


def test_merge_with_default_config_merges_nested_vendor_settings():
    merged = _merge_with_default_config({
        "data_vendors": {
            "news_data": "alpha_vantage",
        },
        "tool_vendors": {
            "get_stock_data": "alpha_vantage",
        },
    })

    assert merged["data_vendors"]["news_data"] == "alpha_vantage"
    assert merged["data_vendors"]["core_stock_apis"] == DEFAULT_CONFIG["data_vendors"]["core_stock_apis"]
    assert merged["tool_vendors"]["get_stock_data"] == "alpha_vantage"
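These tests pin down the merge semantics: user keys override defaults at the top level, but nested vendor maps are merged key-by-key rather than replaced wholesale. A standalone sketch of a merge with that behavior; the real `_merge_with_default_config` may handle deeper nesting differently, so treat this as a one-level approximation:

```python
def merge_with_defaults(defaults: dict, overrides: dict) -> dict:
    # Top-level keys override; dict-valued keys merge one level deep so a
    # partial vendor override keeps the remaining default vendors intact.
    merged = dict(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = {**merged[key], **value}
        else:
            merged[key] = value
    return merged


defaults = {
    "llm_provider": "openai",
    "data_vendors": {"news_data": "yfinance", "core_stock_apis": "yfinance"},
}
overrides = {
    "llm_provider": "anthropic",
    "data_vendors": {"news_data": "alpha_vantage"},
}
merged = merge_with_defaults(defaults, overrides)
```

A naive `{**defaults, **overrides}` would instead drop `core_stock_apis` entirely, which is exactly the regression the nested-vendor test guards against.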
@@ -5,10 +5,9 @@ from tradingagents.agents.utils.agent_utils import (
     get_cashflow,
     get_fundamentals,
     get_income_statement,
     get_insider_transactions,
     get_language_instruction,
+    use_compact_analysis_prompt,
 )
 from tradingagents.dataflows.config import get_config
 
 
 def create_fundamentals_analyst(llm):
@@ -23,12 +22,18 @@ def create_fundamentals_analyst(llm):
         get_income_statement,
     ]
 
-    system_message = (
-        "You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
-        + " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."
-        + " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements."
-        + get_language_instruction(),
-    )
+    if use_compact_analysis_prompt():
+        system_message = (
+            "You are a fundamentals analyst. Use `get_fundamentals` first, then only call statement tools if needed. Summarize the company in under 220 words with: business quality, growth/profitability, balance-sheet risk, cash-flow quality, and a trading implication. End with a Markdown table."
+            + get_language_instruction()
+        )
+    else:
+        system_message = (
+            "You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+            + " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."
+            + " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements."
+            + get_language_instruction()
+        )
 
     prompt = ChatPromptTemplate.from_messages(
         [

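The compact/standard branch introduced above keys off one config value. A minimal sketch of the predicate from this change set, with a plain dict standing in for the merged runtime config that the real helper reads:

```python
# Sketch of the compact-prompt switch: "compact", "fast", and "minimax" all
# select the shorter analyst prompts; anything else keeps the standard ones.
def use_compact_analysis_prompt(config: dict) -> bool:
    mode = str(config.get("analysis_prompt_style", "standard")).strip().lower()
    return mode in {"compact", "fast", "minimax"}

print(use_compact_analysis_prompt({}))                                      # False
print(use_compact_analysis_prompt({"analysis_prompt_style": " MiniMax "}))  # True
```

Matching is case- and whitespace-insensitive, so provider presets such as `"MiniMax"` toggle compact mode without exact-string coupling.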
@@ -4,6 +4,7 @@ from tradingagents.agents.utils.agent_utils import (
     get_indicators,
     get_language_instruction,
     get_stock_data,
+    use_compact_analysis_prompt,
 )
 from tradingagents.dataflows.config import get_config
 
@@ -19,8 +20,23 @@ def create_market_analyst(llm):
         get_indicators,
     ]
 
-    system_message = (
-        """You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
+    if use_compact_analysis_prompt():
+        system_message = (
+            """You are a market analyst. First call `get_stock_data`, then call `get_indicators` with 4 to 6 complementary indicators chosen from: `close_10_ema`, `close_50_sma`, `close_200_sma`, `macd`, `macds`, `macdh`, `rsi`, `boll`, `boll_ub`, `boll_lb`, `atr`, `vwma`.
+
+Pick indicators that cover trend, momentum, volatility, and volume without redundancy. Then produce a concise report with:
+- market regime
+- momentum signal
+- support/resistance or volatility levels
+- trade implications
+- risk warnings
+
+Keep the report under 250 words and end with a Markdown table of the key signals."""
+            + get_language_instruction()
+        )
+    else:
+        system_message = (
+            """You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
 
 Moving Averages:
 - close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
@@ -45,9 +61,9 @@ Volume-Based Indicators:
 - vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
 
 - Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."""
-        + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
-        + get_language_instruction()
-    )
+            + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
+            + get_language_instruction()
+        )
 
     prompt = ChatPromptTemplate.from_messages(
         [

@@ -4,6 +4,7 @@ from tradingagents.agents.utils.agent_utils import (
     get_global_news,
     get_language_instruction,
     get_news,
+    use_compact_analysis_prompt,
 )
 from tradingagents.dataflows.config import get_config
 
@@ -18,11 +19,17 @@ def create_news_analyst(llm):
         get_global_news,
     ]
 
-    system_message = (
-        "You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
-        + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
-        + get_language_instruction()
-    )
+    if use_compact_analysis_prompt():
+        system_message = (
+            "You are a news analyst. Gather only the most relevant recent company and macro news. Summarize in under 180 words with: bullish catalysts, bearish catalysts, macro context, and likely near-term market impact. End with a Markdown table."
+            + get_language_instruction()
+        )
+    else:
+        system_message = (
+            "You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+            + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
+            + get_language_instruction()
+        )
 
     prompt = ChatPromptTemplate.from_messages(
         [

@@ -1,5 +1,10 @@
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
+from tradingagents.agents.utils.agent_utils import (
+    build_instrument_context,
+    get_language_instruction,
+    get_news,
+    use_compact_analysis_prompt,
+)
 from tradingagents.dataflows.config import get_config
 
 
@@ -12,11 +17,17 @@ def create_social_media_analyst(llm):
         get_news,
     ]
 
-    system_message = (
-        "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
-        + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
-        + get_language_instruction()
-    )
+    if use_compact_analysis_prompt():
+        system_message = (
+            "You are a sentiment analyst. Use `get_news` to infer recent company sentiment from news and public discussion. Summarize in under 180 words with: sentiment direction, what is driving it, whether sentiment confirms or contradicts price action, and the trading implication. End with a Markdown table."
+            + get_language_instruction()
+        )
+    else:
+        system_message = (
+            "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+            + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
+            + get_language_instruction()
+        )
 
     prompt = ChatPromptTemplate.from_messages(
         [

@@ -1,4 +1,9 @@
-from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction
+from tradingagents.agents.utils.agent_utils import (
+    build_instrument_context,
+    get_language_instruction,
+    truncate_prompt_text,
+    use_compact_analysis_prompt,
+)
 
 
 def create_portfolio_manager(llm, memory):
@@ -22,7 +27,24 @@ def create_portfolio_manager(llm, memory):
         for i, rec in enumerate(past_memories, 1):
             past_memory_str += rec["recommendation"] + "\n\n"
 
-        prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision.
+        if use_compact_analysis_prompt():
+            prompt = f"""As the Portfolio Manager, synthesize the risk debate and deliver the final rating.
+
+{instrument_context}
+
+Use exactly one rating: Buy / Overweight / Hold / Underweight / Sell.
+
+Return only:
+1. Rating
+2. Executive summary
+3. Key risks
+
+Research plan: {truncate_prompt_text(research_plan, 500)}
+Trader plan: {truncate_prompt_text(trader_plan, 500)}
+Past lessons: {truncate_prompt_text(past_memory_str, 400)}
+Risk debate: {truncate_prompt_text(history, 1400)}{get_language_instruction()}"""
+        else:
+            prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision.
 
 {instrument_context}
 

@@ -1,5 +1,8 @@
 
-from tradingagents.agents.utils.agent_utils import build_instrument_context
+from tradingagents.agents.utils.agent_utils import (
+    build_instrument_context,
+    truncate_prompt_text,
+    use_compact_analysis_prompt,
+)
 
 
 def create_research_manager(llm, memory):
@@ -20,7 +23,23 @@ def create_research_manager(llm, memory):
         for i, rec in enumerate(past_memories, 1):
             past_memory_str += rec["recommendation"] + "\n\n"
 
-        prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
+        if use_compact_analysis_prompt():
+            prompt = f"""You are the research manager. Decide Buy, Sell, or Hold based on the debate.
+
+Return a concise response with:
+1. Recommendation
+2. Top reasons
+3. Simple execution plan
+
+Past lessons:
+{truncate_prompt_text(past_memory_str, 400)}
+
+{instrument_context}
+
+Debate history:
+{truncate_prompt_text(history, 1200)}"""
+        else:
+            prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
 
 Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
 

@@ -1,4 +1,9 @@
 
+from tradingagents.agents.utils.agent_utils import (
+    truncate_prompt_text,
+    use_compact_analysis_prompt,
+)
+
 
 def create_bear_researcher(llm, memory):
     def bear_node(state) -> dict:
@@ -19,7 +24,21 @@ def create_bear_researcher(llm, memory):
         for i, rec in enumerate(past_memories, 1):
             past_memory_str += rec["recommendation"] + "\n\n"
 
-        prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
+        if use_compact_analysis_prompt():
+            prompt = f"""You are a Bear Analyst. Make the strongest concise short case against the stock.
+
+Use only the highest-signal evidence from the reports below. Address the latest bull point directly. Keep the answer under 220 words and end with a clear stance.
+
+Market report: {truncate_prompt_text(market_research_report, 800)}
+Sentiment report: {truncate_prompt_text(sentiment_report, 500)}
+News report: {truncate_prompt_text(news_report, 500)}
+Fundamentals report: {truncate_prompt_text(fundamentals_report, 700)}
+Debate history: {truncate_prompt_text(history, 600)}
+Last bull argument: {truncate_prompt_text(current_response, 400)}
+Past lessons: {truncate_prompt_text(past_memory_str, 400)}
+"""
+        else:
+            prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
 
 Key points to focus on:
 

@@ -1,4 +1,9 @@
 
+from tradingagents.agents.utils.agent_utils import (
+    truncate_prompt_text,
+    use_compact_analysis_prompt,
+)
+
 
 def create_bull_researcher(llm, memory):
     def bull_node(state) -> dict:
@@ -19,7 +24,21 @@ def create_bull_researcher(llm, memory):
         for i, rec in enumerate(past_memories, 1):
             past_memory_str += rec["recommendation"] + "\n\n"
 
-        prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
+        if use_compact_analysis_prompt():
+            prompt = f"""You are a Bull Analyst. Make the strongest concise long case for the stock.
+
+Use only the highest-signal evidence from the reports below. Address the latest bear point directly. Keep the answer under 220 words and end with a clear stance.
+
+Market report: {truncate_prompt_text(market_research_report, 800)}
+Sentiment report: {truncate_prompt_text(sentiment_report, 500)}
+News report: {truncate_prompt_text(news_report, 500)}
+Fundamentals report: {truncate_prompt_text(fundamentals_report, 700)}
+Debate history: {truncate_prompt_text(history, 600)}
+Last bear argument: {truncate_prompt_text(current_response, 400)}
+Past lessons: {truncate_prompt_text(past_memory_str, 400)}
+"""
+        else:
+            prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
 
 Key points to focus on:
 - Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.

@@ -1,4 +1,9 @@
 
+from tradingagents.agents.utils.agent_utils import (
+    truncate_prompt_text,
+    use_compact_analysis_prompt,
+)
+
 
 def create_aggressive_debator(llm):
     def aggressive_node(state) -> dict:
@@ -16,7 +21,21 @@ def create_aggressive_debator(llm):
 
         trader_decision = state["trader_investment_plan"]
 
-        prompt = f"""As the Aggressive Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
+        if use_compact_analysis_prompt():
+            prompt = f"""You are the Aggressive Risk Analyst. Defend upside and attack excessive caution.
+
+Trader decision: {truncate_prompt_text(trader_decision, 500)}
+Market report: {truncate_prompt_text(market_research_report, 500)}
+Sentiment report: {truncate_prompt_text(sentiment_report, 350)}
+News report: {truncate_prompt_text(news_report, 350)}
+Fundamentals report: {truncate_prompt_text(fundamentals_report, 450)}
+Debate history: {truncate_prompt_text(history, 500)}
+Last conservative: {truncate_prompt_text(current_conservative_response, 300)}
+Last neutral: {truncate_prompt_text(current_neutral_response, 300)}
+
+Keep it under 180 words and focus on 2-3 high-upside arguments."""
+        else:
+            prompt = f"""As the Aggressive Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
 
 {trader_decision}
 

@@ -1,4 +1,9 @@
 
+from tradingagents.agents.utils.agent_utils import (
+    truncate_prompt_text,
+    use_compact_analysis_prompt,
+)
+
 
 def create_conservative_debator(llm):
     def conservative_node(state) -> dict:
@@ -16,7 +21,21 @@ def create_conservative_debator(llm):
 
         trader_decision = state["trader_investment_plan"]
 
-        prompt = f"""As the Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
+        if use_compact_analysis_prompt():
+            prompt = f"""You are the Conservative Risk Analyst. Focus on downside protection and capital preservation.
+
+Trader decision: {truncate_prompt_text(trader_decision, 500)}
+Market report: {truncate_prompt_text(market_research_report, 500)}
+Sentiment report: {truncate_prompt_text(sentiment_report, 350)}
+News report: {truncate_prompt_text(news_report, 350)}
+Fundamentals report: {truncate_prompt_text(fundamentals_report, 450)}
+Debate history: {truncate_prompt_text(history, 500)}
+Last aggressive: {truncate_prompt_text(current_aggressive_response, 300)}
+Last neutral: {truncate_prompt_text(current_neutral_response, 300)}
+
+Keep it under 180 words and focus on 2-3 main risks."""
+        else:
+            prompt = f"""As the Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
 
 {trader_decision}
 

@@ -1,4 +1,9 @@
 
+from tradingagents.agents.utils.agent_utils import (
+    truncate_prompt_text,
+    use_compact_analysis_prompt,
+)
+
 
 def create_neutral_debator(llm):
     def neutral_node(state) -> dict:
@@ -16,7 +21,21 @@ def create_neutral_debator(llm):
 
         trader_decision = state["trader_investment_plan"]
 
-        prompt = f"""As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
+        if use_compact_analysis_prompt():
+            prompt = f"""You are the Neutral Risk Analyst. Balance upside and downside and prefer robust execution.
+
+Trader decision: {truncate_prompt_text(trader_decision, 500)}
+Market report: {truncate_prompt_text(market_research_report, 500)}
+Sentiment report: {truncate_prompt_text(sentiment_report, 350)}
+News report: {truncate_prompt_text(news_report, 350)}
+Fundamentals report: {truncate_prompt_text(fundamentals_report, 450)}
+Debate history: {truncate_prompt_text(history, 500)}
+Last aggressive: {truncate_prompt_text(current_aggressive_response, 300)}
+Last conservative: {truncate_prompt_text(current_conservative_response, 300)}
+
+Keep it under 180 words and argue for the most balanced path."""
+        else:
+            prompt = f"""As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
 
 {trader_decision}
 

@@ -34,6 +34,27 @@ def get_language_instruction() -> str:
     return f" Write your entire response in {lang}."
 
 
+def use_compact_analysis_prompt() -> bool:
+    """Return whether analysts should use shorter prompts/reports.
+
+    This is helpful for OpenAI-compatible or Anthropic-compatible backends
+    that support the API surface but struggle with the repository's original,
+    very verbose analyst instructions.
+    """
+    from tradingagents.dataflows.config import get_config
+
+    mode = str(get_config().get("analysis_prompt_style", "standard")).strip().lower()
+    return mode in {"compact", "fast", "minimax"}
+
+
+def truncate_prompt_text(text: str, max_chars: int = 1200) -> str:
+    """Trim long reports/history before feeding them into compact prompts."""
+    text = (text or "").strip()
+    if len(text) <= max_chars:
+        return text
+    return text[:max_chars].rstrip() + "\n...[truncated]..."
+
+
 def build_instrument_context(ticker: str) -> str:
     """Describe the exact instrument so agents preserve exchange-qualified tickers."""
     return (

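The truncation helper added above is small enough to check in isolation. Reproduced here verbatim so its edge cases (None input, whitespace, over-length text) are easy to exercise outside the package:

```python
# Same logic as the patch's truncate_prompt_text: strip, pass short text
# through unchanged, and cap long text with an explicit truncation marker.
def truncate_prompt_text(text: str, max_chars: int = 1200) -> str:
    """Trim long reports/history before feeding them into compact prompts."""
    text = (text or "").strip()
    if len(text) <= max_chars:
        return text
    return text[:max_chars].rstrip() + "\n...[truncated]..."

print(truncate_prompt_text("  short report  "))  # "short report"
print(truncate_prompt_text(None))                # ""
print(truncate_prompt_text("x" * 500, max_chars=100).endswith("...[truncated]..."))  # True
```

The `(text or "")` guard matters because debate state fields can be None before the first round; the marker makes silent truncation visible in downstream prompts.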
@@ -0,0 +1,8 @@
+from .interface import DEFAULT_DATAFLOW_ADAPTER, DataflowAdapter, VendorSelection, route_to_vendor
+
+__all__ = [
+    "DEFAULT_DATAFLOW_ADAPTER",
+    "DataflowAdapter",
+    "VendorSelection",
+    "route_to_vendor",
+]

@@ -9,15 +9,29 @@ def initialize_config():
     """Initialize the configuration with default values."""
     global _config
     if _config is None:
-        _config = default_config.DEFAULT_CONFIG.copy()
+        _config = default_config.get_default_config()
 
 
+def _merge_config(base: Dict, overrides: Dict) -> Dict:
+    merged = dict(base)
+    for key, value in overrides.items():
+        if (
+            key in ("data_vendors", "tool_vendors")
+            and isinstance(value, dict)
+            and isinstance(merged.get(key), dict)
+        ):
+            merged[key] = {**merged[key], **value}
+        else:
+            merged[key] = value
+    return merged
+
+
 def set_config(config: Dict):
     """Update the configuration with custom values."""
     global _config
     if _config is None:
-        _config = default_config.DEFAULT_CONFIG.copy()
-    _config.update(config)
+        _config = default_config.get_default_config()
+    _config = _merge_config(_config, config)
 
 
 def get_config() -> Dict:

@@ -1,4 +1,5 @@
-from typing import Annotated
+from dataclasses import dataclass
+from typing import Annotated, Any
 
 # Import from vendor-specific modules
 from .y_finance import (
@@ -183,32 +184,62 @@ def get_vendor(category: str, method: str = None) -> str:
     return config.get("data_vendors", {}).get(category, "default")
 
 
+@dataclass(frozen=True)
+class VendorSelection:
+    """Resolved vendor routing metadata for one dataflow method call."""
+
+    method: str
+    category: str
+    configured_vendors: tuple[str, ...]
+    fallback_chain: tuple[str, ...]
+
+
+class DataflowAdapter:
+    """Thin adapter boundary over legacy vendor routing logic."""
+
+    def resolve(self, method: str) -> VendorSelection:
+        category = get_category_for_method(method)
+        vendor_config = get_vendor(category, method)
+        configured_vendors = tuple(v.strip() for v in vendor_config.split(",") if v.strip())
+
+        if method not in VENDOR_METHODS:
+            raise ValueError(f"Method '{method}' not supported")
+
+        all_available_vendors = list(VENDOR_METHODS[method].keys())
+        fallback_chain = list(configured_vendors)
+        for vendor in all_available_vendors:
+            if vendor not in fallback_chain:
+                fallback_chain.append(vendor)
+
+        return VendorSelection(
+            method=method,
+            category=category,
+            configured_vendors=configured_vendors,
+            fallback_chain=tuple(fallback_chain),
+        )
+
+    def execute(self, method: str, *args: Any, **kwargs: Any):
+        """Route the call through the configured vendor chain with legacy fallback behavior."""
+        selection = self.resolve(method)
+
+        for vendor in selection.fallback_chain:
+            if vendor not in VENDOR_METHODS[method]:
+                continue
+
+            vendor_impl = VENDOR_METHODS[method][vendor]
+            impl_func = vendor_impl[0] if isinstance(vendor_impl, list) else vendor_impl
+
+            try:
+                return impl_func(*args, **kwargs)
+            except AlphaVantageRateLimitError:
+                continue  # Only rate limits trigger fallback
+
+        raise RuntimeError(f"No available vendor for '{method}'")
+
+
+DEFAULT_DATAFLOW_ADAPTER = DataflowAdapter()
+
+
 def route_to_vendor(method: str, *args, **kwargs):
     """Route method calls to appropriate vendor implementation with fallback support."""
-    category = get_category_for_method(method)
-    vendor_config = get_vendor(category, method)
-    primary_vendors = [v.strip() for v in vendor_config.split(",")]
-
-    if method not in VENDOR_METHODS:
-        raise ValueError(f"Method '{method}' not supported")
-
-    # Build fallback chain: primary vendors first, then remaining available vendors
-    all_available_vendors = list(VENDOR_METHODS[method].keys())
-    fallback_vendors = primary_vendors.copy()
-    for vendor in all_available_vendors:
-        if vendor not in fallback_vendors:
-            fallback_vendors.append(vendor)
-
-    for vendor in fallback_vendors:
-        if vendor not in VENDOR_METHODS[method]:
-            continue
-
-        vendor_impl = VENDOR_METHODS[method][vendor]
-        impl_func = vendor_impl[0] if isinstance(vendor_impl, list) else vendor_impl
-
-        try:
-            return impl_func(*args, **kwargs)
-        except AlphaVantageRateLimitError:
-            continue  # Only rate limits trigger fallback
-
-    raise RuntimeError(f"No available vendor for '{method}'")
+    return DEFAULT_DATAFLOW_ADAPTER.execute(method, *args, **kwargs)

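Both the legacy `route_to_vendor` and the new `DataflowAdapter.resolve` build the same chain: configured vendors in order, then every other vendor registered for the method appended once. A sketch of just that ordering logic (vendor names are illustrative):

```python
# Fallback-chain ordering as used by DataflowAdapter.resolve: keep configured
# vendors first and in order, then append remaining registered vendors once.
def build_fallback_chain(vendor_config: str, available: list) -> tuple:
    chain = [v.strip() for v in vendor_config.split(",") if v.strip()]
    for vendor in available:
        if vendor not in chain:
            chain.append(vendor)
    return tuple(chain)

chain = build_fallback_chain("alpha_vantage, yfinance", ["yfinance", "local", "alpha_vantage"])
print(chain)  # ('alpha_vantage', 'yfinance', 'local')
```

Keeping this deterministic is what makes the degradation path testable: a rate-limited primary vendor falls through to the next entry rather than failing the call outright.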
@@ -12,20 +12,37 @@ from .config import get_config
 logger = logging.getLogger(__name__)
 
 
+def _is_transient_yfinance_error(exc: Exception) -> bool:
+    """Heuristic for flaky yfinance transport/parser failures."""
+    if isinstance(exc, YFRateLimitError):
+        return True
+    message = str(exc)
+    return isinstance(exc, TypeError) and "'NoneType' object is not subscriptable" in message
+
+
 def yf_retry(func, max_retries=3, base_delay=2.0):
     """Execute a yfinance call with exponential backoff on rate limits.
 
     yfinance raises YFRateLimitError on HTTP 429 responses but does not
     retry them internally. This wrapper adds retry logic specifically
-    for rate limits. Other exceptions propagate immediately.
+    for rate limits and observed transient parser failures. Other
+    exceptions propagate immediately.
     """
     for attempt in range(max_retries + 1):
         try:
             return func()
-        except YFRateLimitError:
+        except Exception as exc:
+            if not _is_transient_yfinance_error(exc):
+                raise
             if attempt < max_retries:
                 delay = base_delay * (2 ** attempt)
-                logger.warning(f"Yahoo Finance rate limited, retrying in {delay:.0f}s (attempt {attempt + 1}/{max_retries})")
+                logger.warning(
+                    "Yahoo Finance transient failure (%s), retrying in %.0fs (attempt %s/%s)",
+                    exc,
+                    delay,
+                    attempt + 1,
+                    max_retries,
+                )
                 time.sleep(delay)
             else:
                 raise
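The retry loop above can be exercised on its own (a minimal sketch: `RateLimitError` is a stand-in for yfinance's `YFRateLimitError`, and `base_delay` is shortened so the demo runs instantly):

```python
import time


class RateLimitError(Exception):
    """Stand-in for yfinance's YFRateLimitError."""


def retry_with_backoff(func, max_retries=3, base_delay=0.01):
    # Try the call up to max_retries + 1 times; the delay doubles per retry.
    for attempt in range(max_retries + 1):
        try:
            return func()
        except RateLimitError:
            if attempt < max_retries:
                time.sleep(base_delay * (2 ** attempt))
            else:
                raise


calls = {"n": 0}


def flaky():
    # Fails twice with a rate limit, then succeeds on the third attempt.
    calls["n"] += 1
    if calls["n"] < 3:
        raise RateLimitError("429")
    return "ok"


result = retry_with_backoff(flaky)
```

With `max_retries=3` the wrapper absorbs the two simulated 429s and returns on the third attempt; a fourth consecutive failure would re-raise to the caller.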
@@ -1,3 +1,4 @@
+import copy
 import os
 
 DEFAULT_CONFIG = {
@@ -36,3 +37,7 @@ DEFAULT_CONFIG = {
         # Example: "get_stock_data": "alpha_vantage",  # Override category default
     },
 }
+
+
+def get_default_config():
+    return copy.deepcopy(DEFAULT_CONFIG)
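The deep copy in `get_default_config` matters because the default carries nested dicts: callers that mutate vendor maps must not corrupt the shared module-level default. A minimal sketch (the `DEFAULT_CONFIG` contents here are illustrative stand-ins, not the real defaults):

```python
import copy

# Hypothetical stand-in for the module-level default config.
DEFAULT_CONFIG = {"data_vendors": {"core_stock_apis": "yfinance"}}


def get_default_config():
    # deepcopy so callers can mutate nested dicts without touching the shared default
    return copy.deepcopy(DEFAULT_CONFIG)


cfg = get_default_config()
cfg["data_vendors"]["core_stock_apis"] = "alpha_vantage"  # local override only
```

A plain `dict(DEFAULT_CONFIG)` or `DEFAULT_CONFIG.copy()` would share the inner `data_vendors` dict, so the override above would leak into every later caller.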
@@ -1,5 +1,6 @@
 # TradingAgents/graph/trading_graph.py
 
+import copy
 import os
 from pathlib import Path
 import json
@@ -40,6 +41,30 @@ from .reflection import Reflector
 from .signal_processing import SignalProcessor
 
 
+def _merge_with_default_config(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+    """Merge a partial user config onto DEFAULT_CONFIG.
+
+    Orchestrator callers often override only a few LLM/vendor fields. Without a
+    merge step, required defaults such as ``project_dir`` disappear and the
+    graph fails during initialization.
+    """
+    merged = copy.deepcopy(DEFAULT_CONFIG)
+    if not config:
+        return merged
+
+    for key, value in config.items():
+        if (
+            key in ("data_vendors", "tool_vendors")
+            and isinstance(value, dict)
+            and isinstance(merged.get(key), dict)
+        ):
+            merged[key].update(value)
+        else:
+            merged[key] = value
+
+    return merged
+
+
 class TradingAgentsGraph:
     """Main class that orchestrates the trading agents framework."""
 
@@ -59,7 +84,7 @@ class TradingAgentsGraph:
             callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
         """
         self.debug = debug
-        self.config = config or DEFAULT_CONFIG
+        self.config = _merge_with_default_config(config)
         self.callbacks = callbacks or []
 
         # Update the interface's config
@@ -138,6 +163,17 @@ class TradingAgentsGraph:
         kwargs = {}
         provider = self.config.get("llm_provider", "").lower()
 
+        common_passthrough = {
+            "timeout": ("llm_timeout", "timeout"),
+            "max_retries": ("llm_max_retries", "max_retries"),
+        }
+        for out_key, config_keys in common_passthrough.items():
+            for config_key in config_keys:
+                value = self.config.get(config_key)
+                if value is not None:
+                    kwargs[out_key] = value
+                    break
+
         if provider == "google":
             thinking_level = self.config.get("google_thinking_level")
             if thinking_level:
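The merge semantics introduced here, deep-copy the defaults, shallow-merge the vendor maps, replace everything else, can be exercised standalone. A minimal sketch (the `DEFAULT_CONFIG` values are illustrative stand-ins, not the project's real defaults):

```python
import copy

# Hypothetical defaults; only the shape mirrors the real config.
DEFAULT_CONFIG = {
    "project_dir": "/opt/tradingagents",  # required default that must survive a partial override
    "llm_provider": "openai",
    "data_vendors": {"core_stock_apis": "yfinance", "news": "google"},
}


def merge_with_default_config(config):
    merged = copy.deepcopy(DEFAULT_CONFIG)
    if not config:
        return merged
    for key, value in config.items():
        if (
            key in ("data_vendors", "tool_vendors")
            and isinstance(value, dict)
            and isinstance(merged.get(key), dict)
        ):
            merged[key].update(value)  # shallow-merge vendor maps instead of replacing them
        else:
            merged[key] = value
    return merged


merged = merge_with_default_config(
    {"llm_provider": "anthropic", "data_vendors": {"news": "finnhub"}}
)
```

Note the asymmetry: a partial `data_vendors` override keeps the untouched vendor entries, while any other key (like `llm_provider`) is replaced wholesale.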
@@ -1,4 +1,10 @@
 from .base_client import BaseLLMClient
-from .factory import create_llm_client
+from .factory import ProviderSpec, create_llm_client, get_provider_spec, get_supported_providers
 
-__all__ = ["BaseLLMClient", "create_llm_client"]
+__all__ = [
+    "BaseLLMClient",
+    "ProviderSpec",
+    "create_llm_client",
+    "get_provider_spec",
+    "get_supported_providers",
+]
@@ -1,4 +1,5 @@
-from typing import Optional
+from dataclasses import dataclass
+from typing import Callable, Optional
 
 from .base_client import BaseLLMClient
 from .openai_client import OpenAIClient
@@ -6,6 +7,63 @@ from .anthropic_client import AnthropicClient
 from .google_client import GoogleClient
 
 
+@dataclass(frozen=True)
+class ProviderSpec:
+    """Provider registry entry for LLM client creation."""
+
+    canonical_name: str
+    aliases: tuple[str, ...]
+    builder: Callable[..., BaseLLMClient]
+
+
+_PROVIDER_SPECS: tuple[ProviderSpec, ...] = (
+    ProviderSpec(
+        canonical_name="openai",
+        aliases=("openai", "ollama", "openrouter"),
+        builder=lambda model, base_url=None, **kwargs: OpenAIClient(
+            model,
+            base_url,
+            provider=kwargs.pop("provider", "openai"),
+            **kwargs,
+        ),
+    ),
+    ProviderSpec(
+        canonical_name="xai",
+        aliases=("xai",),
+        builder=lambda model, base_url=None, **kwargs: OpenAIClient(
+            model,
+            base_url,
+            provider="xai",
+            **kwargs,
+        ),
+    ),
+    ProviderSpec(
+        canonical_name="anthropic",
+        aliases=("anthropic",),
+        builder=lambda model, base_url=None, **kwargs: AnthropicClient(model, base_url, **kwargs),
+    ),
+    ProviderSpec(
+        canonical_name="google",
+        aliases=("google",),
+        builder=lambda model, base_url=None, **kwargs: GoogleClient(model, base_url, **kwargs),
+    ),
+)
+
+
+def get_provider_spec(provider: str) -> ProviderSpec:
+    """Resolve a provider or alias to its canonical registry entry."""
+    provider_lower = provider.lower()
+    for spec in _PROVIDER_SPECS:
+        if provider_lower in spec.aliases:
+            return spec
+    raise ValueError(f"Unsupported LLM provider: {provider}")
+
+
+def get_supported_providers() -> tuple[str, ...]:
+    """Return canonical provider names exposed by the registry."""
+    return tuple(spec.canonical_name for spec in _PROVIDER_SPECS)
+
+
 def create_llm_client(
     provider: str,
     model: str,
@@ -33,17 +91,8 @@ def create_llm_client(
         ValueError: If provider is not supported
     """
     provider_lower = provider.lower()
 
+    provider_spec = get_provider_spec(provider_lower)
+    builder_kwargs = dict(kwargs)
-    if provider_lower in ("openai", "ollama", "openrouter"):
-        return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
-
-    if provider_lower == "xai":
-        return OpenAIClient(model, base_url, provider="xai", **kwargs)
-
-    if provider_lower == "anthropic":
-        return AnthropicClient(model, base_url, **kwargs)
-
-    if provider_lower == "google":
-        return GoogleClient(model, base_url, **kwargs)
-
-    raise ValueError(f"Unsupported LLM provider: {provider}")
+    builder_kwargs["provider"] = provider_lower
+    return provider_spec.builder(model, base_url, **builder_kwargs)
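The registry pattern that replaces the if/elif provider chain can be sketched with toy builders, string-returning lambdas standing in for the real client classes (names here are illustrative, not the project's API):

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ProviderSpec:
    canonical_name: str
    aliases: tuple
    builder: Callable[..., str]


# Toy registry: builders just return a tag instead of constructing a client.
SPECS = (
    ProviderSpec("openai", ("openai", "ollama", "openrouter"),
                 lambda model, **kw: f"openai:{model}"),
    ProviderSpec("anthropic", ("anthropic",),
                 lambda model, **kw: f"anthropic:{model}"),
)


def get_provider_spec(provider: str) -> ProviderSpec:
    # Alias lookup is case-insensitive; unknown names fail loudly.
    provider_lower = provider.lower()
    for spec in SPECS:
        if provider_lower in spec.aliases:
            return spec
    raise ValueError(f"Unsupported LLM provider: {provider}")


client = get_provider_spec("OLLAMA").builder("llama3")
```

Adding a provider now means appending one `ProviderSpec` entry; callers such as `create_llm_client` never grow another branch.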
@@ -334,3 +334,56 @@ def save_recommendation(date: str, ticker: str, data: dict):
     date_dir = RECOMMENDATIONS_DIR / date
     date_dir.mkdir(parents=True, exist_ok=True)
     (date_dir / f"{ticker}.json").write_text(json.dumps(data, ensure_ascii=False, indent=2))
+
+
+class LegacyPortfolioGateway:
+    """Compatibility gateway that exposes the current portfolio API as a service boundary."""
+
+    def get_watchlist(self) -> list:
+        return get_watchlist()
+
+    def add_to_watchlist(self, ticker: str, name: str) -> dict:
+        return add_to_watchlist(ticker, name)
+
+    def remove_from_watchlist(self, ticker: str) -> bool:
+        return remove_from_watchlist(ticker)
+
+    def get_accounts(self) -> dict:
+        return get_accounts()
+
+    def create_account(self, account_name: str) -> dict:
+        return create_account(account_name)
+
+    def delete_account(self, account_name: str) -> bool:
+        return delete_account(account_name)
+
+    async def get_positions(self, account: Optional[str] = None) -> list:
+        return await get_positions(account)
+
+    def add_position(
+        self,
+        ticker: str,
+        shares: float,
+        cost_price: float,
+        purchase_date: Optional[str],
+        notes: str,
+        account: str,
+    ) -> dict:
+        return add_position(ticker, shares, cost_price, purchase_date, notes, account)
+
+    def remove_position(self, ticker: str, position_id: str, account: Optional[str]) -> bool:
+        return remove_position(ticker, position_id, account)
+
+    def get_recommendations(self, date: Optional[str] = None, limit: int = DEFAULT_PAGE_SIZE, offset: int = 0) -> dict:
+        return get_recommendations(date, limit, offset)
+
+    def get_recommendation(self, date: str, ticker: str) -> Optional[dict]:
+        return get_recommendation(date, ticker)
+
+    def save_recommendation(self, date: str, ticker: str, data: dict):
+        save_recommendation(date, ticker, data)
+
+
+def create_legacy_portfolio_gateway() -> LegacyPortfolioGateway:
+    """Create a gateway instance for service-layer migration."""
+    return LegacyPortfolioGateway()
@@ -3,7 +3,6 @@ TradingAgents Web Dashboard Backend
 FastAPI REST API + WebSocket for real-time analysis progress
 """
 import asyncio
-import fcntl
 import hmac
 import json
 import os
@@ -17,12 +16,13 @@ from pathlib import Path
 from typing import Optional
 from contextlib import asynccontextmanager
 
-from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Query, Header
+from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect, Query, Header
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import Response, FileResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
-import os
 
+from services import AnalysisService, JobService, ResultStore, build_request_context, load_migration_flags
+
 # Path to TradingAgents repo root
 REPO_ROOT = Path(__file__).parent.parent.parent
@@ -30,6 +30,7 @@ REPO_ROOT = Path(__file__).parent.parent.parent
 ANALYSIS_PYTHON = Path(sys.executable)
 # Task state persistence directory
 TASK_STATUS_DIR = Path(__file__).parent / "data" / "task_status"
+CONFIG_PATH = Path(__file__).parent / "data" / "config.json"
 
 
 # ============== Lifespan ==============
@@ -40,15 +41,31 @@ async def lifespan(app: FastAPI):
     app.state.active_connections: dict[str, list[WebSocket]] = {}
     app.state.task_results: dict[str, dict] = {}
     app.state.analysis_tasks: dict[str, asyncio.Task] = {}
     app.state.processes: dict[str, asyncio.subprocess.Process | None] = {}
+    app.state.migration_flags = load_migration_flags()
+
+    portfolio_gateway = create_legacy_portfolio_gateway()
+    app.state.result_store = ResultStore(TASK_STATUS_DIR, portfolio_gateway)
+    app.state.job_service = JobService(
+        task_results=app.state.task_results,
+        analysis_tasks=app.state.analysis_tasks,
+        processes=app.state.processes,
+        persist_task=app.state.result_store.save_task_status,
+        delete_task=app.state.result_store.delete_task_status,
+    )
+    app.state.analysis_service = AnalysisService(
+        analysis_python=ANALYSIS_PYTHON,
+        repo_root=REPO_ROOT,
+        analysis_script_template=ANALYSIS_SCRIPT_TEMPLATE,
+        api_key_resolver=_get_analysis_api_key,
+        result_store=app.state.result_store,
+        job_service=app.state.job_service,
+        retry_count=MAX_RETRY_COUNT,
+        retry_base_delay_secs=RETRY_BASE_DELAY_SECS,
+    )
 
     # Restore persisted task states from disk
     TASK_STATUS_DIR.mkdir(parents=True, exist_ok=True)
-    for f in TASK_STATUS_DIR.glob("*.json"):
-        try:
-            data = json.loads(f.read_text())
-            app.state.task_results[data["task_id"]] = data
-        except Exception:
-            pass
+    app.state.job_service.restore_task_results(app.state.result_store.restore_task_results())
 
     yield
 
@@ -89,13 +106,19 @@ async def check_config():
     """Check if the app is configured (API key is set).
     The FastAPI backend receives ANTHROPIC_API_KEY as an env var when spawned by Tauri.
     """
-    configured = bool(os.environ.get("ANTHROPIC_API_KEY") or os.environ.get("MINIMAX_API_KEY"))
+    configured = bool(_get_analysis_api_key())
     return {"configured": configured}
 
 
 @app.post("/api/config/apikey")
-async def save_apikey(body: dict = None, api_key: Optional[str] = Header(None)):
-    """Save API key via Tauri command. Used by the setup wizard."""
+async def save_apikey(request: Request, body: dict = None, api_key: Optional[str] = Header(None)):
+    """Persist API key for local desktop/backend use."""
+    if _get_api_key():
+        if not _check_api_key(api_key):
+            _auth_error()
+    elif not _is_local_request(request):
+        raise HTTPException(status_code=403, detail="API key setup is only allowed from localhost")
+
     if not body or "api_key" not in body:
         raise HTTPException(status_code=400, detail="api_key is required")
 
@@ -104,8 +127,7 @@ async def save_apikey(body: dict = None, api_key: Optional[str] = Header(None)):
         raise HTTPException(status_code=400, detail="api_key cannot be empty")
 
     try:
-        result = _tauri_invoke("set_config", {"key": "api_key", "value": apikey})
-        # If we get here without error, the key was saved
+        _persist_analysis_api_key(apikey)
         return {"ok": True, "saved": True}
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Failed to save API key: {e}")
@@ -145,6 +167,39 @@ def _auth_error():
     raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required")
 
 
+def _load_saved_config() -> dict:
+    try:
+        if CONFIG_PATH.exists():
+            return json.loads(CONFIG_PATH.read_text())
+    except Exception:
+        pass
+    return {}
+
+
+def _persist_analysis_api_key(api_key_value: str):
+    global _api_key
+    CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
+    CONFIG_PATH.write_text(json.dumps({"api_key": api_key_value}, ensure_ascii=False))
+    os.chmod(CONFIG_PATH, 0o600)
+    os.environ["ANTHROPIC_API_KEY"] = api_key_value
+    _api_key = None
+
+
+def _get_analysis_api_key() -> Optional[str]:
+    return (
+        os.environ.get("ANTHROPIC_API_KEY")
+        or os.environ.get("MINIMAX_API_KEY")
+        or _load_saved_config().get("api_key")
+    )
+
+
+def _is_local_request(request: Request) -> bool:
+    client = request.client
+    if client is None:
+        return False
+    return client.host in {"127.0.0.1", "::1", "localhost", "testclient"}
+
+
 def _get_cache_path(mode: str) -> Path:
     return CACHE_DIR / f"screen_{mode}.json"
 
@@ -258,6 +313,7 @@ config = OrchestratorConfig(
         "max_debate_rounds": 1,
         "max_risk_discuss_rounds": 1,
         "project_dir": os.path.join(repo_root, "tradingagents"),
+        "results_dir": os.path.join(repo_root, "results"),
     }
 )
 
@@ -267,7 +323,11 @@ orchestrator = TradingOrchestrator(config)
 
 print("STAGE:trading", flush=True)
 
-result = orchestrator.get_combined_signal(ticker, date)
+try:
+    result = orchestrator.get_combined_signal(ticker, date)
+except ValueError as _e:
+    print("ANALYSIS_ERROR:" + str(_e), file=sys.stderr, flush=True)
+    sys.exit(1)
 
 print("STAGE:risk", flush=True)
 
@@ -334,7 +394,7 @@ async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Head
         _auth_error()
 
     # Validate ANTHROPIC_API_KEY for the analysis subprocess
-    anthropic_key = os.environ.get("ANTHROPIC_API_KEY") or os.environ.get("MINIMAX_API_KEY")
+    anthropic_key = _get_analysis_api_key()
     if not anthropic_key:
         raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
 
@@ -404,31 +464,6 @@ async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Head
             app.state.task_results[task_id]["progress"] = int((idx + 1) / 5 * 100)
             app.state.task_results[task_id]["current_stage"] = stage_name
 
-    async def monitor_subprocess(task_id: str, proc: asyncio.subprocess.Process, cancel_evt: asyncio.Event):
-        """Monitor subprocess stdout for stage markers and broadcast progress."""
-        # Set stdout to non-blocking
-        fd = proc.stdout.fileno()
-        fl = fcntl.fcntl(fd, fcntl.GETFL)
-        fcntl.fcntl(fd, fcntl.SETFL, fl | os.O_NONBLOCK)
-
-        while not cancel_evt.is_set():
-            if proc.returncode is not None:
-                break
-            await asyncio.sleep(5)
-            if cancel_evt.is_set():
-                break
-            try:
-                chunk = os.read(fd, 32768)
-                if chunk:
-                    for line in chunk.decode().splitlines():
-                        if line.startswith("STAGE:"):
-                            stage = line.split(":", 1)[1].strip()
-                            _update_task_stage(stage)
-                            await broadcast_progress(task_id, app.state.task_results[task_id])
-            except (BlockingIOError, OSError):
-                # No data available yet
-                pass
-
     async def run_analysis():
         """Run analysis subprocess and broadcast progress"""
         try:
@@ -450,17 +485,26 @@ async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Head
             )
             app.state.processes[task_id] = proc
 
-            # Start monitor coroutine alongside subprocess
-            monitor_task = asyncio.create_task(monitor_subprocess(task_id, proc, cancel_event))
+            # Read stdout line-by-line for real-time stage updates
+            stdout_lines = []
+            while True:
+                try:
+                    line_bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=300.0)
+                except asyncio.TimeoutError:
+                    break
+                if not line_bytes:
+                    break
+                line = line_bytes.decode(errors="replace").rstrip()
+                stdout_lines.append(line)
+                if line.startswith("STAGE:"):
+                    stage = line.split(":", 1)[1].strip()
+                    _update_task_stage(stage)
+                    await broadcast_progress(task_id, app.state.task_results[task_id])
+                if cancel_event.is_set():
+                    break
 
-            stdout, stderr = await proc.communicate()
-
-            # Signal monitor to stop and wait for it
-            cancel_event.set()
-            try:
-                await asyncio.wait_for(monitor_task, timeout=1.0)
-            except asyncio.TimeoutError:
-                monitor_task.cancel()
+            await proc.wait()
+            stderr_bytes = await proc.stderr.read()
 
             # Clean up script file
             try:
@@ -469,9 +513,9 @@ async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Head
                 pass
 
             if proc.returncode == 0:
-                output = stdout.decode()
+                output = "\n".join(stdout_lines)
                 decision = "HOLD"
-                for line in output.splitlines():
+                for line in stdout_lines:
                     if line.startswith("SIGNAL_DETAIL:"):
                         try:
                             detail = json.loads(line.split(":", 1)[1].strip())
@@ -492,7 +536,7 @@ async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Head
                     if not app.state.task_results[task_id]["stages"][i].get("completed_at"):
                         app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S")
             else:
-                error_msg = stderr.decode()[-1000:] if stderr else "Unknown error"
+                error_msg = stderr_bytes.decode(errors="replace")[-1000:] if stderr_bytes else "Unknown error"
                 app.state.task_results[task_id]["status"] = "failed"
                 app.state.task_results[task_id]["error"] = error_msg
 
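The subprocess protocol used above is just prefixed marker lines on stdout (`STAGE:`, `SIGNAL_DETAIL:`, `ANALYSIS_COMPLETE:`). The stage extraction the reader loop performs can be sketched synchronously (a minimal sketch over a captured line list; the sample lines are illustrative):

```python
def parse_stage_lines(lines):
    """Extract stage names from STAGE:-prefixed marker lines, preserving order."""
    stages = []
    for line in lines:
        if line.startswith("STAGE:"):
            # Split only on the first colon so stage names may contain colons.
            stages.append(line.split(":", 1)[1].strip())
    return stages


# Hypothetical stdout capture from one analysis subprocess run.
stdout_lines = [
    "loading data",
    "STAGE:trading",
    "some log line",
    "STAGE: risk",
    "ANALYSIS_COMPLETE:BUY",
]
stages = parse_stage_lines(stdout_lines)
```

Keeping the protocol line-oriented is what lets the backend switch from the old non-blocking `fcntl` polling to plain `readline()` without changing the analysis script.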
@@ -896,6 +940,7 @@ async def export_report_pdf(ticker: str, date: str, api_key: Optional[str] = Hea
 import sys
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 from api.portfolio import (
+    create_legacy_portfolio_gateway,
     get_watchlist, add_to_watchlist, remove_from_watchlist,
     get_positions, add_position, remove_position,
     get_accounts, create_account, delete_account,
@@ -968,7 +1013,9 @@ async def delete_account_endpoint(account_name: str, api_key: Optional[str] = He
 async def list_positions(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
     if not _check_api_key(api_key):
         _auth_error()
-    return {"positions": get_positions(account)}
+    if app.state.migration_flags.use_result_store:
+        return {"positions": await app.state.result_store.get_positions(account)}
+    return {"positions": await get_positions(account)}
 
 
 @app.post("/api/portfolio/positions")
@@ -1003,7 +1050,10 @@ async def delete_position(ticker: str, position_id: Optional[str] = Query(None),
 async def export_positions_csv(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
     if not _check_api_key(api_key):
         _auth_error()
-    positions = get_positions(account)
+    if app.state.migration_flags.use_result_store:
+        positions = await app.state.result_store.get_positions(account)
+    else:
+        positions = await get_positions(account)
     import csv
     import io
     output = io.StringIO()
@@ -1052,6 +1102,20 @@ async def start_portfolio_analysis(api_key: Optional[str] = Header(None)):
     date = datetime.now().strftime("%Y-%m-%d")
     task_id = f"port_{date}_{uuid.uuid4().hex[:6]}"
 
+    if app.state.migration_flags.use_application_services:
+        request_context = build_request_context(api_key=api_key)
+        try:
+            return await app.state.analysis_service.start_portfolio_analysis(
+                task_id=task_id,
+                date=date,
+                request_context=request_context,
+                broadcast_progress=broadcast_progress,
+            )
+        except ValueError as exc:
+            raise HTTPException(status_code=400, detail=str(exc))
+        except RuntimeError as exc:
+            raise HTTPException(status_code=500, detail=str(exc))
+
     watchlist = get_watchlist()
     if not watchlist:
         raise HTTPException(status_code=400, detail="自选股为空,请先添加股票")
@@ -1069,7 +1133,7 @@ async def start_portfolio_analysis(api_key: Optional[str] = Header(None)):
         "error": None,
     }
 
-    api_key = os.environ.get("ANTHROPIC_API_KEY") or os.environ.get("MINIMAX_API_KEY")
+    api_key = _get_analysis_api_key()
     if not api_key:
         raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
 
@@ -0,0 +1,15 @@
+from .analysis_service import AnalysisService
+from .job_service import JobService
+from .migration_flags import MigrationFlags, load_migration_flags
+from .request_context import RequestContext, build_request_context
+from .result_store import ResultStore
+
+__all__ = [
+    "AnalysisService",
+    "JobService",
+    "MigrationFlags",
+    "RequestContext",
+    "ResultStore",
+    "build_request_context",
+    "load_migration_flags",
+]
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import os
+import tempfile
+from datetime import datetime
+from pathlib import Path
+from typing import Awaitable, Callable, Optional
+
+from .request_context import RequestContext
+
+BroadcastFn = Callable[[str, dict], Awaitable[None]]
+
+
+class AnalysisService:
+    """Application service that orchestrates backend analysis jobs without owning strategy logic."""
+
+    def __init__(
+        self,
+        *,
+        analysis_python: Path,
+        repo_root: Path,
+        analysis_script_template: str,
+        api_key_resolver: Callable[[], Optional[str]],
+        result_store,
+        job_service,
+        retry_count: int = 2,
+        retry_base_delay_secs: int = 1,
+    ):
+        self.analysis_python = analysis_python
+        self.repo_root = repo_root
+        self.analysis_script_template = analysis_script_template
+        self.api_key_resolver = api_key_resolver
+        self.result_store = result_store
+        self.job_service = job_service
+        self.retry_count = retry_count
+        self.retry_base_delay_secs = retry_base_delay_secs
+
+    async def start_portfolio_analysis(
+        self,
+        *,
+        task_id: str,
+        date: str,
+        request_context: RequestContext,
+        broadcast_progress: BroadcastFn,
+    ) -> dict:
+        del request_context  # Reserved for future auditing/auth propagation.
+        watchlist = self.result_store.get_watchlist()
+        if not watchlist:
+            raise ValueError("自选股为空,请先添加股票")
+
+        analysis_api_key = self.api_key_resolver()
+        if not analysis_api_key:
+            raise RuntimeError("ANTHROPIC_API_KEY environment variable not set")
+
+        state = self.job_service.create_portfolio_job(task_id=task_id, total=len(watchlist))
+        await broadcast_progress(task_id, state)
+
+        task = asyncio.create_task(
+            self._run_portfolio_analysis(
+                task_id=task_id,
+                date=date,
+                watchlist=watchlist,
+                analysis_api_key=analysis_api_key,
+                broadcast_progress=broadcast_progress,
+            )
+        )
+        self.job_service.register_background_task(task_id, task)
+        return {
+            "task_id": task_id,
+            "total": len(watchlist),
+            "status": "running",
+        }
+
+    async def _run_portfolio_analysis(
+        self,
+        *,
+        task_id: str,
+        date: str,
+        watchlist: list[dict],
+        analysis_api_key: str,
+        broadcast_progress: BroadcastFn,
+    ) -> None:
+        try:
+            for index, stock in enumerate(watchlist):
+                stock = {**stock, "_idx": index}
+                ticker = stock["ticker"]
+                await broadcast_progress(
+                    task_id,
+                    self.job_service.update_portfolio_progress(task_id, ticker=ticker, completed=index),
+                )
+
+                success, rec = await self._run_single_portfolio_analysis(
+                    task_id=task_id,
+                    ticker=ticker,
+                    stock=stock,
+                    date=date,
+                    analysis_api_key=analysis_api_key,
+                )
+                if success and rec is not None:
+                    self.job_service.append_portfolio_result(task_id, rec)
+                else:
+                    self.job_service.mark_portfolio_failure(task_id)
+
+                await broadcast_progress(task_id, self.job_service.task_results[task_id])
+
+            self.job_service.complete_job(task_id)
+        except Exception as exc:
+            self.job_service.fail_job(task_id, str(exc))
+
+        await broadcast_progress(task_id, self.job_service.task_results[task_id])
+
+    async def _run_single_portfolio_analysis(
+        self,
+        *,
+        task_id: str,
+        ticker: str,
+        stock: dict,
+        date: str,
+        analysis_api_key: str,
+    ) -> tuple[bool, Optional[dict]]:
+        last_error: Optional[str] = None
+        for attempt in range(self.retry_count + 1):
+            script_path: Optional[Path] = None
+            try:
+                fd, script_path_str = tempfile.mkstemp(
+                    suffix=".py",
+                    prefix=f"analysis_{task_id}_{stock['_idx']}_",
+                )
+                script_path = Path(script_path_str)
+                os.chmod(script_path, 0o600)
+                with os.fdopen(fd, "w") as handle:
+                    handle.write(self.analysis_script_template)
+
+                clean_env = {
+                    key: value
+                    for key, value in os.environ.items()
+                    if not key.startswith(("PYTHON", "CONDA", "VIRTUAL"))
+                }
+                clean_env["ANTHROPIC_API_KEY"] = analysis_api_key
+                clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
+
+                proc = await asyncio.create_subprocess_exec(
+                    str(self.analysis_python),
+                    str(script_path),
+                    ticker,
+                    date,
+                    str(self.repo_root),
+                    stdout=asyncio.subprocess.PIPE,
+                    stderr=asyncio.subprocess.PIPE,
+                    env=clean_env,
+                )
+                self.job_service.register_process(task_id, proc)
+                stdout, stderr = await proc.communicate()
+
+                if proc.returncode == 0:
+                    rec = self._build_recommendation_record(
+                        stdout=stdout.decode(),
+                        ticker=ticker,
+                        stock=stock,
+                        date=date,
+                    )
+                    self.result_store.save_recommendation(date, ticker, rec)
+                    return True, rec
+
+                last_error = stderr.decode()[-500:] if stderr else f"exit {proc.returncode}"
+            except Exception as exc:
+                last_error = str(exc)
+            finally:
+                if script_path is not None:
+                    try:
+                        script_path.unlink()
+                    except Exception:
+                        pass
+
+            if attempt < self.retry_count:
+                await asyncio.sleep(self.retry_base_delay_secs ** attempt)
+
+        if last_error:
+            self.job_service.task_results[task_id]["last_error"] = last_error
+        return False, None
+
+    @staticmethod
+    def _build_recommendation_record(*, stdout: str, ticker: str, stock: dict, date: str) -> dict:
+        decision = "HOLD"
+        quant_signal = None
+        llm_signal = None
+        confidence = None
+        for line in stdout.splitlines():
+            if line.startswith("SIGNAL_DETAIL:"):
+                try:
+                    detail = json.loads(line.split(":", 1)[1].strip())
+                except Exception:
+                    continue
+                quant_signal = detail.get("quant_signal")
+                llm_signal = detail.get("llm_signal")
+                confidence = detail.get("confidence")
+            if line.startswith("ANALYSIS_COMPLETE:"):
+                decision = line.split(":", 1)[1].strip()
+
+        return {
+            "ticker": ticker,
+            "name": stock.get("name", ticker),
+            "analysis_date": date,
+            "decision": decision,
+            "quant_signal": quant_signal,
+            "llm_signal": llm_signal,
+            "confidence": confidence,
+            "created_at": datetime.now().isoformat(),
+        }
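The stdout-to-recommendation parsing in `_build_recommendation_record` can be sketched in isolation (a minimal sketch over the marker protocol; the sample output and the trimmed-down record fields are illustrative):

```python
import json


def build_recommendation(stdout: str) -> dict:
    """Fold SIGNAL_DETAIL/ANALYSIS_COMPLETE marker lines into one record."""
    decision = "HOLD"  # default when no ANALYSIS_COMPLETE marker is seen
    quant_signal = llm_signal = confidence = None
    for line in stdout.splitlines():
        if line.startswith("SIGNAL_DETAIL:"):
            try:
                detail = json.loads(line.split(":", 1)[1].strip())
            except Exception:
                continue  # malformed detail lines are skipped, not fatal
            quant_signal = detail.get("quant_signal")
            llm_signal = detail.get("llm_signal")
            confidence = detail.get("confidence")
        if line.startswith("ANALYSIS_COMPLETE:"):
            decision = line.split(":", 1)[1].strip()
    return {
        "decision": decision,
        "quant_signal": quant_signal,
        "llm_signal": llm_signal,
        "confidence": confidence,
    }


# Hypothetical subprocess output.
out = (
    "STAGE:trading\n"
    'SIGNAL_DETAIL: {"quant_signal": "BUY", "llm_signal": "HOLD", "confidence": 0.62}\n'
    "ANALYSIS_COMPLETE:BUY\n"
)
rec = build_recommendation(out)
```

Because the parser only folds markers into a dict, the strategy that produced those signals stays entirely inside the subprocess, which is what keeps the application service orchestration-only.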
@@ -0,0 +1,94 @@
from __future__ import annotations

import asyncio
from datetime import datetime
from typing import Any, Callable


class JobService:
    """Application-layer job state orchestrator with legacy-compatible payloads."""

    def __init__(
        self,
        *,
        task_results: dict[str, dict],
        analysis_tasks: dict[str, asyncio.Task],
        processes: dict[str, Any],
        persist_task: Callable[[str, dict], None],
        delete_task: Callable[[str], None],
    ):
        self.task_results = task_results
        self.analysis_tasks = analysis_tasks
        self.processes = processes
        self.persist_task = persist_task
        self.delete_task = delete_task

    def restore_task_results(self, restored: dict[str, dict]) -> None:
        self.task_results.update(restored)

    def create_portfolio_job(self, *, task_id: str, total: int) -> dict:
        state = {
            "task_id": task_id,
            "type": "portfolio",
            "status": "running",
            "total": total,
            "completed": 0,
            "failed": 0,
            "current_ticker": None,
            "results": [],
            "error": None,
            "created_at": datetime.now().isoformat(),
        }
        self.task_results[task_id] = state
        self.processes.setdefault(task_id, None)
        return state

    def update_portfolio_progress(self, task_id: str, *, ticker: str, completed: int) -> dict:
        state = self.task_results[task_id]
        state["current_ticker"] = ticker
        state["status"] = "running"
        state["completed"] = completed
        return state

    def append_portfolio_result(self, task_id: str, rec: dict) -> dict:
        state = self.task_results[task_id]
        state["completed"] += 1
        state["results"].append(rec)
        return state

    def mark_portfolio_failure(self, task_id: str) -> dict:
        state = self.task_results[task_id]
        state["failed"] += 1
        return state

    def complete_job(self, task_id: str) -> dict:
        state = self.task_results[task_id]
        state["status"] = "completed"
        state["current_ticker"] = None
        self.persist_task(task_id, state)
        return state

    def fail_job(self, task_id: str, error: str) -> dict:
        state = self.task_results[task_id]
        state["status"] = "failed"
        state["error"] = error
        self.persist_task(task_id, state)
        return state

    def register_background_task(self, task_id: str, task: asyncio.Task) -> None:
        self.analysis_tasks[task_id] = task

    def register_process(self, task_id: str, process: Any) -> None:
        self.processes[task_id] = process

    def cancel_job(self, task_id: str, error: str = "Cancelled by user") -> dict | None:
        task = self.analysis_tasks.get(task_id)
        if task:
            task.cancel()
        state = self.task_results.get(task_id)
        if not state:
            return None
        state["status"] = "failed"
        state["error"] = error
        self.persist_task(task_id, state)
        return state
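The portfolio-job lifecycle above is create, then per-ticker progress/result updates, then a terminal `complete_job` or `fail_job` that persists the final state. A trimmed stand-in that mirrors those method names (not the real class, which lives in `services/job_service.py`) shows the intended call sequence:

```python
class MiniJobService:
    """Trimmed stand-in mirroring JobService's portfolio-job lifecycle."""

    def __init__(self, task_results, persist_task):
        self.task_results = task_results
        self.persist_task = persist_task

    def create_portfolio_job(self, *, task_id, total):
        # Same legacy-compatible payload shape as the real service.
        self.task_results[task_id] = {
            "task_id": task_id, "status": "running", "total": total,
            "completed": 0, "results": [], "current_ticker": None,
        }
        return self.task_results[task_id]

    def append_portfolio_result(self, task_id, rec):
        state = self.task_results[task_id]
        state["completed"] += 1
        state["results"].append(rec)
        return state

    def complete_job(self, task_id):
        # Only terminal transitions persist; progress updates stay in memory.
        state = self.task_results[task_id]
        state["status"] = "completed"
        state["current_ticker"] = None
        self.persist_task(task_id, state)
        return state


persisted = {}
svc = MiniJobService({}, lambda tid, data: persisted.__setitem__(tid, dict(data)))
svc.create_portfolio_job(task_id="port_demo", total=1)
svc.append_portfolio_result("port_demo", {"ticker": "AAPL", "decision": "HOLD"})
final = svc.complete_job("port_demo")
```

Injecting `persist_task` as a callable keeps the service orchestration-only: it never decides *where* state goes, which matches the constraint that the application service carries no strategy or storage policy of its own.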
@@ -0,0 +1,29 @@
from __future__ import annotations

import os
from dataclasses import dataclass


def _env_flag(name: str, default: bool = False) -> bool:
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


@dataclass(frozen=True)
class MigrationFlags:
    """Feature flags for backend application-service migration."""

    use_application_services: bool = False
    use_result_store: bool = False
    use_request_context: bool = True


def load_migration_flags() -> MigrationFlags:
    """Load service migration flags from the environment."""
    return MigrationFlags(
        use_application_services=_env_flag("TRADINGAGENTS_USE_APPLICATION_SERVICES", default=False),
        use_result_store=_env_flag("TRADINGAGENTS_USE_RESULT_STORE", default=False),
        use_request_context=_env_flag("TRADINGAGENTS_USE_REQUEST_CONTEXT", default=True),
    )
@@ -0,0 +1,37 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional
from uuid import uuid4

from fastapi import Request


@dataclass(frozen=True)
class RequestContext:
    """Minimal request-scoped metadata passed into application services."""

    request_id: str
    api_key: Optional[str] = None
    client_host: Optional[str] = None
    is_local: bool = False
    metadata: dict[str, str] = field(default_factory=dict)


def build_request_context(
    request: Optional[Request] = None,
    *,
    api_key: Optional[str] = None,
    request_id: Optional[str] = None,
    metadata: Optional[dict[str, str]] = None,
) -> RequestContext:
    """Create a stable request context without leaking FastAPI internals into services."""
    client_host = request.client.host if request and request.client else None
    is_local = client_host in {"127.0.0.1", "::1", "localhost", "testclient"}
    return RequestContext(
        request_id=request_id or uuid4().hex,
        api_key=api_key,
        client_host=client_host,
        is_local=is_local,
        metadata=dict(metadata or {}),
    )
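The factory degrades gracefully when no `Request` is available (background jobs, unit tests): it still mints a `request_id` and leaves `client_host` as `None`, which in turn reads as non-local. A dependency-free sketch of that same defaulting logic, with no FastAPI import:

```python
from uuid import uuid4

# Same local-host set the factory above checks against.
LOCAL_HOSTS = {"127.0.0.1", "::1", "localhost", "testclient"}


def context_fields(client_host=None, request_id=None):
    """Mirror build_request_context's defaulting without FastAPI types."""
    return {
        "request_id": request_id or uuid4().hex,  # always populated
        "client_host": client_host,
        "is_local": client_host in LOCAL_HOSTS,
    }


background = context_fields()                    # no inbound request at all
local = context_fields(client_host="127.0.0.1")  # typical dev-loopback call
```

Keeping `Request` optional is what lets the same service code run under FastAPI endpoints and under direct invocation in tests, which is the point of the seam.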
@@ -0,0 +1,51 @@
from __future__ import annotations

import json
from pathlib import Path
from typing import Optional


class ResultStore:
    """Storage boundary for persisted task state and portfolio results."""

    def __init__(self, task_status_dir: Path, portfolio_gateway):
        self.task_status_dir = task_status_dir
        self.portfolio_gateway = portfolio_gateway

    def restore_task_results(self) -> dict[str, dict]:
        restored: dict[str, dict] = {}
        self.task_status_dir.mkdir(parents=True, exist_ok=True)
        for file_path in self.task_status_dir.glob("*.json"):
            try:
                data = json.loads(file_path.read_text())
            except Exception:
                continue
            task_id = data.get("task_id")
            if task_id:
                restored[task_id] = data
        return restored

    def save_task_status(self, task_id: str, data: dict) -> None:
        self.task_status_dir.mkdir(parents=True, exist_ok=True)
        (self.task_status_dir / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False))

    def delete_task_status(self, task_id: str) -> None:
        (self.task_status_dir / f"{task_id}.json").unlink(missing_ok=True)

    def get_watchlist(self) -> list:
        return self.portfolio_gateway.get_watchlist()

    def get_accounts(self) -> dict:
        return self.portfolio_gateway.get_accounts()

    async def get_positions(self, account: Optional[str] = None) -> list:
        return await self.portfolio_gateway.get_positions(account)

    def get_recommendations(self, date: Optional[str] = None, limit: int = 50, offset: int = 0) -> dict:
        return self.portfolio_gateway.get_recommendations(date, limit, offset)

    def get_recommendation(self, date: str, ticker: str) -> Optional[dict]:
        return self.portfolio_gateway.get_recommendation(date, ticker)

    def save_recommendation(self, date: str, ticker: str, data: dict) -> None:
        self.portfolio_gateway.save_recommendation(date, ticker, data)
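ResultStore persists each task as `<task_id>.json` and silently skips unreadable files on restore, so one corrupt file cannot block backend startup. A standalone sketch of that persistence round trip, using `tempfile` in place of the backend's configured status directory:

```python
import json
import tempfile
from pathlib import Path


def save_task_status(status_dir: Path, task_id: str, data: dict) -> None:
    status_dir.mkdir(parents=True, exist_ok=True)
    (status_dir / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False))


def restore_task_results(status_dir: Path) -> dict:
    restored = {}
    status_dir.mkdir(parents=True, exist_ok=True)
    for file_path in status_dir.glob("*.json"):
        try:
            data = json.loads(file_path.read_text())
        except Exception:
            continue  # corrupt files are skipped, mirroring ResultStore
        if data.get("task_id"):
            restored[data["task_id"]] = data
    return restored


with tempfile.TemporaryDirectory() as tmp:
    status_dir = Path(tmp) / "task_status"
    save_task_status(status_dir, "task-1", {"task_id": "task-1", "status": "running"})
    (status_dir / "broken.json").write_text("{not json")  # exercises the skip path
    restored = restore_task_results(status_dir)
```

Keying recovery on the `task_id` field inside the payload, rather than on the filename, also drops stray JSON files that happen to land in the directory.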
@@ -0,0 +1,55 @@
import importlib
import sys
from pathlib import Path

from fastapi.testclient import TestClient


def _load_main_module(monkeypatch):
    backend_dir = Path(__file__).resolve().parents[1]
    monkeypatch.syspath_prepend(str(backend_dir))
    sys.modules.pop("main", None)
    return importlib.import_module("main")


def test_config_check_smoke(monkeypatch):
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    monkeypatch.delenv("MINIMAX_API_KEY", raising=False)

    main = _load_main_module(monkeypatch)

    with TestClient(main.app) as client:
        response = client.get("/api/config/check")

    assert response.status_code == 200
    assert response.json() == {"configured": False}


def test_analysis_task_routes_smoke(monkeypatch):
    monkeypatch.delenv("DASHBOARD_API_KEY", raising=False)
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)

    main = _load_main_module(monkeypatch)

    seeded_task = {
        "task_id": "task-smoke",
        "ticker": "AAPL",
        "date": "2026-04-11",
        "status": "running",
        "created_at": "2026-04-11T10:00:00",
    }

    with TestClient(main.app) as client:
        main.app.state.task_results["task-smoke"] = seeded_task

        health_response = client.get("/health")
        tasks_response = client.get("/api/analysis/tasks")
        status_response = client.get("/api/analysis/status/task-smoke")

    assert health_response.status_code == 200
    assert health_response.json() == {"status": "ok"}
    assert tasks_response.status_code == 200
    assert tasks_response.json()["total"] >= 1
    assert any(task["task_id"] == "task-smoke" for task in tasks_response.json()["tasks"])
    assert status_response.status_code == 200
    assert status_response.json()["task_id"] == "task-smoke"
@@ -0,0 +1,105 @@
import asyncio
import json

from services.analysis_service import AnalysisService
from services.job_service import JobService
from services.migration_flags import load_migration_flags
from services.request_context import build_request_context
from services.result_store import ResultStore


class DummyPortfolioGateway:
    def __init__(self):
        self.saved = []

    def get_watchlist(self):
        return [{"ticker": "AAPL", "name": "Apple"}]

    async def get_positions(self, account=None):
        return [{"ticker": "AAPL", "account": account or "default account"}]

    def get_accounts(self):
        return {"accounts": {"default account": {}}}

    def get_recommendations(self, date=None, limit=50, offset=0):
        return {"recommendations": [], "total": 0, "limit": limit, "offset": offset}

    def get_recommendation(self, date, ticker):
        return None

    def save_recommendation(self, date, ticker, data):
        self.saved.append((date, ticker, data))


def test_load_migration_flags_from_env(monkeypatch):
    monkeypatch.setenv("TRADINGAGENTS_USE_APPLICATION_SERVICES", "1")
    monkeypatch.setenv("TRADINGAGENTS_USE_RESULT_STORE", "true")
    monkeypatch.setenv("TRADINGAGENTS_USE_REQUEST_CONTEXT", "0")

    flags = load_migration_flags()

    assert flags.use_application_services is True
    assert flags.use_result_store is True
    assert flags.use_request_context is False


def test_build_request_context_defaults():
    context = build_request_context(api_key="secret", metadata={"source": "test"})

    assert context.api_key == "secret"
    assert context.request_id
    assert context.metadata == {"source": "test"}


def test_result_store_round_trip(tmp_path):
    gateway = DummyPortfolioGateway()
    store = ResultStore(tmp_path / "task_status", gateway)

    store.save_task_status("task-1", {"task_id": "task-1", "status": "running"})

    restored = store.restore_task_results()
    positions = asyncio.run(store.get_positions("paper account"))

    assert restored["task-1"]["status"] == "running"
    assert positions == [{"ticker": "AAPL", "account": "paper account"}]


def test_job_service_create_and_fail_job():
    task_results = {}
    analysis_tasks = {}
    processes = {}
    persisted = {}

    service = JobService(
        task_results=task_results,
        analysis_tasks=analysis_tasks,
        processes=processes,
        persist_task=lambda task_id, data: persisted.setdefault(task_id, json.loads(json.dumps(data))),
        delete_task=lambda task_id: persisted.pop(task_id, None),
    )

    state = service.create_portfolio_job(task_id="port_1", total=2)
    assert state["total"] == 2
    assert processes["port_1"] is None

    failed = service.fail_job("port_1", "boom")
    assert failed["status"] == "failed"
    assert persisted["port_1"]["error"] == "boom"


def test_analysis_service_build_recommendation_record():
    rec = AnalysisService._build_recommendation_record(
        stdout="\n".join([
            'SIGNAL_DETAIL:{"quant_signal":"BUY","llm_signal":"HOLD","confidence":0.75}',
            "ANALYSIS_COMPLETE:OVERWEIGHT",
        ]),
        ticker="AAPL",
        stock={"name": "Apple"},
        date="2026-04-13",
    )

    assert rec["ticker"] == "AAPL"
    assert rec["decision"] == "OVERWEIGHT"
    assert rec["quant_signal"] == "BUY"
    assert rec["llm_signal"] == "HOLD"
    assert rec["confidence"] == 0.75