From 7c8d523680ee100c84529e397f289114a6c74ebd Mon Sep 17 00:00:00 2001 From: dtarkent2-sys Date: Mon, 13 Apr 2026 22:20:16 -0400 Subject: [PATCH] Round 1: import upstream additions (TauricResearch/TradingAgents) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pure file additions from upstream — zero conflict, zero risk to local customizations. Cherry-picked individually rather than via merge. New files: - tradingagents/llm_clients/model_catalog.py — centralized model registry - tradingagents/agents/managers/portfolio_manager.py — new PM (paired with upstream's risk_manager.py removal, deferred to a later round) - tradingagents/__init__.py — package init - tests/test_google_api_key.py - tests/test_model_validation.py - tests/test_ticker_symbol_handling.py - .env.enterprise.example — reference template Skipped (deferred to later rounds): - Round 2: bug-fix cherry-picks (yfinance retry, look-ahead bias fix, PM-reads-trader fix, indicator normalization, etc.) - Round 3: feat:add DeepSeek/Qwen/GLM provider support (b0f6058) — conflicts with local llm_clients/ customizations, needs careful merge - Round 4: portfolio manager restructure (b8b2825/318adda) — replaces risk_manager.py which has local customizations, needs hand-port Skipped entirely: - Docker support (anti-Docker stack) - Multi-language output (English-only arena) - GPT-5.4 default flip (would undo a913416 glm-5.1 default) - azure_client.py (not in arena roster) - OpenRouter dynamic model selection (not in provider mix) Co-Authored-By: Claude Opus 4.6 (1M context) --- .env.enterprise.example | 5 + tests/test_google_api_key.py | 28 ++++ tests/test_model_validation.py | 52 +++++++ tests/test_ticker_symbol_handling.py | 18 +++ tradingagents/__init__.py | 2 + .../agents/managers/portfolio_manager.py | 77 ++++++++++ tradingagents/llm_clients/model_catalog.py | 134 ++++++++++++++++++ 7 files changed, 316 insertions(+) create mode 100644 .env.enterprise.example create mode 100644 tests/test_google_api_key.py create mode 100644 tests/test_model_validation.py create mode 100644 tests/test_ticker_symbol_handling.py create mode 100644 tradingagents/__init__.py create mode 100644 tradingagents/agents/managers/portfolio_manager.py create mode 100644 tradingagents/llm_clients/model_catalog.py diff --git a/.env.enterprise.example b/.env.enterprise.example new file mode 100644 index 00000000..4f7bda3d --- /dev/null +++ b/.env.enterprise.example @@ -0,0 +1,5 @@ +# Azure OpenAI +AZURE_OPENAI_API_KEY= +AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/ +AZURE_OPENAI_DEPLOYMENT_NAME= +# OPENAI_API_VERSION=2024-10-21 # optional, required for non-v1 API diff --git a/tests/test_google_api_key.py b/tests/test_google_api_key.py new file mode 100644 index 00000000..e1607c49 --- /dev/null +++ b/tests/test_google_api_key.py @@ -0,0 +1,28 @@ +import unittest +from unittest.mock import patch + +from tradingagents.llm_clients.google_client import GoogleClient + + +class TestGoogleApiKeyStandardization(unittest.TestCase): + """Verify GoogleClient accepts unified api_key parameter.""" + + @patch("tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI") + def test_api_key_handling(self, mock_chat): + test_cases = [ + ("unified api_key is mapped", {"api_key": "test-key-123"}, "test-key-123"), + ("legacy google_api_key still works", {"google_api_key": "legacy-key-456"}, "legacy-key-456"), + ("unified api_key takes precedence", {"api_key": "unified", "google_api_key": "legacy"}, "unified"), + ] + + for msg, kwargs, expected_key in test_cases: + with self.subTest(msg=msg): + mock_chat.reset_mock() + client = GoogleClient("gemini-2.5-flash", **kwargs) + client.get_llm() + call_kwargs = mock_chat.call_args[1] + self.assertEqual(call_kwargs.get("google_api_key"), expected_key) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_model_validation.py b/tests/test_model_validation.py new file mode 100644 index 00000000..50f26318 --- /dev/null +++ b/tests/test_model_validation.py @@ -0,0 +1,52 @@ +import unittest +import warnings + +from tradingagents.llm_clients.base_client import BaseLLMClient +from tradingagents.llm_clients.model_catalog import get_known_models +from tradingagents.llm_clients.validators import validate_model + + +class DummyLLMClient(BaseLLMClient): + def __init__(self, provider: str, model: str): + self.provider = provider + super().__init__(model) + + def get_llm(self): + self.warn_if_unknown_model() + return object() + + def validate_model(self) -> bool: + return validate_model(self.provider, self.model) + + +class ModelValidationTests(unittest.TestCase): + def test_cli_catalog_models_are_all_validator_approved(self): + for provider, models in get_known_models().items(): + if provider in ("ollama", "openrouter"): + continue + + for model in models: + with self.subTest(provider=provider, model=model): + self.assertTrue(validate_model(provider, model)) + + def test_unknown_model_emits_warning_for_strict_provider(self): + client = DummyLLMClient("openai", "not-a-real-openai-model") + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + client.get_llm() + + self.assertEqual(len(caught), 1) + self.assertIn("not-a-real-openai-model", str(caught[0].message)) + self.assertIn("openai", str(caught[0].message)) + + def test_openrouter_and_ollama_accept_custom_models_without_warning(self): + for provider in ("openrouter", "ollama"): + client = DummyLLMClient(provider, "custom-model-name") + + with self.subTest(provider=provider): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + client.get_llm() + + self.assertEqual(caught, []) diff --git a/tests/test_ticker_symbol_handling.py b/tests/test_ticker_symbol_handling.py new file mode 100644 index 00000000..858d26cd --- /dev/null +++ b/tests/test_ticker_symbol_handling.py @@ -0,0 +1,18 @@ +import unittest + +from cli.utils import normalize_ticker_symbol +from tradingagents.agents.utils.agent_utils import build_instrument_context + + +class TickerSymbolHandlingTests(unittest.TestCase): + def test_normalize_ticker_symbol_preserves_exchange_suffix(self): + self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO") + + def test_build_instrument_context_mentions_exact_symbol(self): + context = build_instrument_context("7203.T") + self.assertIn("7203.T", context) + self.assertIn("exchange suffix", context) + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/__init__.py b/tradingagents/__init__.py new file mode 100644 index 00000000..43a2b439 --- /dev/null +++ b/tradingagents/__init__.py @@ -0,0 +1,2 @@ +import os +os.environ.setdefault("PYTHONUTF8", "1") diff --git a/tradingagents/agents/managers/portfolio_manager.py b/tradingagents/agents/managers/portfolio_manager.py new file mode 100644 index 00000000..6c69ae9f --- /dev/null +++ b/tradingagents/agents/managers/portfolio_manager.py @@ -0,0 +1,77 @@ +from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction + + +def create_portfolio_manager(llm, memory): + def portfolio_manager_node(state) -> dict: + + instrument_context = build_instrument_context(state["company_of_interest"]) + + history = state["risk_debate_state"]["history"] + risk_debate_state = state["risk_debate_state"] + market_research_report = state["market_report"] + news_report = state["news_report"] + fundamentals_report = state["fundamentals_report"] + sentiment_report = state["sentiment_report"] + research_plan = state["investment_plan"] + trader_plan = state["trader_investment_plan"] + + curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + past_memories = memory.get_memories(curr_situation, n_matches=2) + + past_memory_str = "" + for i, rec in enumerate(past_memories, 1): + past_memory_str += rec["recommendation"] + "\n\n" + + prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision. + +{instrument_context} + +--- + +**Rating Scale** (use exactly one): +- **Buy**: Strong conviction to enter or add to position +- **Overweight**: Favorable outlook, gradually increase exposure +- **Hold**: Maintain current position, no action needed +- **Underweight**: Reduce exposure, take partial profits +- **Sell**: Exit position or avoid entry + +**Context:** +- Research Manager's investment plan: **{research_plan}** +- Trader's transaction proposal: **{trader_plan}** +- Lessons from past decisions: **{past_memory_str}** + +**Required Output Structure:** +1. **Rating**: State one of Buy / Overweight / Hold / Underweight / Sell. +2. **Executive Summary**: A concise action plan covering entry strategy, position sizing, key risk levels, and time horizon. +3. **Investment Thesis**: Detailed reasoning anchored in the analysts' debate and past reflections. + +--- + +**Risk Analysts Debate History:** +{history} + +--- + +Be decisive and ground every conclusion in specific evidence from the analysts.{get_language_instruction()}""" + + response = llm.invoke(prompt) + + new_risk_debate_state = { + "judge_decision": response.content, + "history": risk_debate_state["history"], + "aggressive_history": risk_debate_state["aggressive_history"], + "conservative_history": risk_debate_state["conservative_history"], + "neutral_history": risk_debate_state["neutral_history"], + "latest_speaker": "Judge", + "current_aggressive_response": risk_debate_state["current_aggressive_response"], + "current_conservative_response": risk_debate_state["current_conservative_response"], + "current_neutral_response": risk_debate_state["current_neutral_response"], + "count": risk_debate_state["count"], + } + + return { + "risk_debate_state": new_risk_debate_state, + "final_trade_decision": response.content, + } + + return portfolio_manager_node diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py new file mode 100644 index 00000000..a2c57ed8 --- /dev/null +++ b/tradingagents/llm_clients/model_catalog.py @@ -0,0 +1,134 @@ +"""Shared model catalog for CLI selections and validation.""" + +from __future__ import annotations + +from typing import Dict, List, Tuple + +ModelOption = Tuple[str, str] +ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]] + + +MODEL_OPTIONS: ProviderModeOptions = { + "openai": { + "quick": [ + ("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"), + ("GPT-5.4 Nano - Cheapest, high-volume tasks", "gpt-5.4-nano"), + ("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"), + ("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"), + ], + "deep": [ + ("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"), + ("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"), + ("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"), + ("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"), + ], + }, + "anthropic": { + "quick": [ + ("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"), + ("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"), + ("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"), + ], + "deep": [ + ("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"), + ("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"), + ("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"), + ("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"), + ], + }, + "google": { + "quick": [ + ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), + ("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"), + ("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"), + ("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"), + ], + "deep": [ + ("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"), + ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), + ("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"), + ("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"), + ], + }, + "xai": { + "quick": [ + ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), + ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), + ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ], + "deep": [ + ("Grok 4 - Flagship model", "grok-4-0709"), + ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), + ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), + ], + }, + "deepseek": { + "quick": [ + ("DeepSeek V3.2", "deepseek-chat"), + ("Custom model ID", "custom"), + ], + "deep": [ + ("DeepSeek V3.2 (thinking)", "deepseek-reasoner"), + ("DeepSeek V3.2", "deepseek-chat"), + ("Custom model ID", "custom"), + ], + }, + "qwen": { + "quick": [ + ("Qwen 3.5 Flash", "qwen3.5-flash"), + ("Qwen Plus", "qwen-plus"), + ("Custom model ID", "custom"), + ], + "deep": [ + ("Qwen 3.6 Plus", "qwen3.6-plus"), + ("Qwen 3.5 Plus", "qwen3.5-plus"), + ("Qwen 3 Max", "qwen3-max"), + ("Custom model ID", "custom"), + ], + }, + "glm": { + "quick": [ + ("GLM-4.7", "glm-4.7"), + ("GLM-5", "glm-5"), + ("Custom model ID", "custom"), + ], + "deep": [ + ("GLM-5.1", "glm-5.1"), + ("GLM-5", "glm-5"), + ("Custom model ID", "custom"), + ], + }, + # OpenRouter: fetched dynamically. Azure: any deployed model name. + "ollama": { + "quick": [ + ("Qwen3:latest (8B, local)", "qwen3:latest"), + ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), + ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), + ], + "deep": [ + ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), + ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), + ("Qwen3:latest (8B, local)", "qwen3:latest"), + ], + }, +} + + +def get_model_options(provider: str, mode: str) -> List[ModelOption]: + """Return shared model options for a provider and selection mode.""" + return MODEL_OPTIONS[provider.lower()][mode] + + +def get_known_models() -> Dict[str, List[str]]: + """Build known model names from the shared CLI catalog.""" + return { + provider: sorted( + { + value + for options in mode_options.values() + for _, value in options + } + ) + for provider, mode_options in MODEL_OPTIONS.items() + }