Compare commits

...

7 Commits

16 changed files with 462 additions and 148 deletions

.env.enterprise.example (new file, 5 additions)
View File

@@ -0,0 +1,5 @@
# Azure OpenAI
AZURE_OPENAI_API_KEY=
AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/
AZURE_OPENAI_DEPLOYMENT_NAME=
# OPENAI_API_VERSION=2024-10-21 # optional; required only for the non-v1 API
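These keys are consumed via the layered dotenv loading added in `cli/main.py` further down in this diff. A minimal sketch of the precedence, following python-dotenv's `override` semantics:

```python
from dotenv import load_dotenv

# Sketch of the load order used by the CLI: .env is read first, then
# .env.enterprise fills in only the keys that are still unset, so values
# from .env or the shell environment always win.
load_dotenv()
load_dotenv(".env.enterprise", override=False)
```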

View File

@@ -3,4 +3,7 @@ OPENAI_API_KEY=
GOOGLE_API_KEY=
ANTHROPIC_API_KEY=
XAI_API_KEY=
DEEPSEEK_API_KEY=
DASHSCOPE_API_KEY=
ZHIPU_API_KEY=
OPENROUTER_API_KEY=

View File

@@ -140,10 +140,15 @@ export OPENAI_API_KEY=... # OpenAI (GPT)
export GOOGLE_API_KEY=... # Google (Gemini)
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
export XAI_API_KEY=... # xAI (Grok)
export DEEPSEEK_API_KEY=... # DeepSeek
export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope)
export ZHIPU_API_KEY=... # GLM (Zhipu)
export OPENROUTER_API_KEY=... # OpenRouter
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
```
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
Alternatively, copy `.env.example` to `.env` and fill in your keys:
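To illustrate the Ollama note above: a minimal sketch, assuming `TradingAgentsGraph` accepts a `config` mapping based on `DEFAULT_CONFIG` (the constructor arguments are not shown in this hunk; the model ID comes from the Ollama options later in this diff):

```python
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph

config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "ollama"          # local models, no API key required
config["deep_think_llm"] = "qwen3:latest"  # a model from the Ollama options below
graph = TradingAgentsGraph(config=config)
```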

View File

@@ -6,8 +6,9 @@ from functools import wraps
from rich.console import Console
from dotenv import load_dotenv
# Load environment variables from .env file
# Load environment variables
load_dotenv()
load_dotenv(".env.enterprise", override=False)
from rich.panel import Panel
from rich.spinner import Spinner
from rich.live import Live
@@ -24,6 +25,12 @@ from rich.align import Align
from rich.rule import Rule
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.graph.analyst_execution import (
AnalystWallTimeTracker,
build_analyst_execution_plan,
get_initial_analyst_node,
sync_analyst_tracker_from_chunk,
)
from tradingagents.default_config import DEFAULT_CONFIG
from cli.models import AnalystType
from cli.utils import *
@@ -79,7 +86,7 @@ class MessageBuffer:
self.current_agent = None
self.report_sections = {}
self.selected_analysts = []
self._last_message_id = None
self._processed_message_ids = set()
def init_for_analysis(self, selected_analysts):
"""Initialize agent status and report sections based on selected analysts.
@@ -114,7 +121,7 @@
self.current_agent = None
self.messages.clear()
self.tool_calls.clear()
self._last_message_id = None
self._processed_message_ids.clear()
def get_completed_reports_count(self):
"""Count reports that are finalized (their finalizing agent is completed).
@@ -809,7 +816,7 @@ ANALYST_REPORT_MAP = {
}
def update_analyst_statuses(message_buffer, chunk):
def update_analyst_statuses(message_buffer, chunk, wall_time_tracker=None):
"""Update analyst statuses based on accumulated report state.
Logic:
@@ -823,6 +830,9 @@ def update_analyst_statuses(message_buffer, chunk):
selected = message_buffer.selected_analysts
found_active = False
if wall_time_tracker is not None:
sync_analyst_tracker_from_chunk(wall_time_tracker, chunk)
for analyst_key in ANALYST_ORDER:
if analyst_key not in selected:
continue
@@ -949,6 +959,11 @@ def run_analysis():
# Normalize analyst selection to predefined order (selection is a 'set', order is fixed)
selected_set = {analyst.value for analyst in selections["analysts"]}
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
analyst_execution_plan = build_analyst_execution_plan(
selected_analyst_keys,
concurrency_limit=config["analyst_concurrency_limit"],
)
analyst_wall_time_tracker = AnalystWallTimeTracker(analyst_execution_plan)
# Initialize the graph with callbacks bound to LLMs
graph = TradingAgentsGraph(
@@ -1031,8 +1046,9 @@ def run_analysis():
update_display(layout, stats_handler=stats_handler, start_time=start_time)
# Update agent status to in_progress for the first analyst
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
first_analyst = get_initial_analyst_node(analyst_execution_plan)
message_buffer.update_agent_status(first_analyst, "in_progress")
analyst_wall_time_tracker.mark_started(selected_analyst_keys[0])
update_display(layout, stats_handler=stats_handler, start_time=start_time)
# Create spinner text
@@ -1052,31 +1068,31 @@
# Stream the analysis
trace = []
for chunk in graph.graph.stream(init_agent_state, **args):
# Process messages if present (skip duplicates via message ID)
if len(chunk["messages"]) > 0:
last_message = chunk["messages"][-1]
msg_id = getattr(last_message, "id", None)
# Process all messages in chunk, deduplicating by message ID
for message in chunk.get("messages", []):
msg_id = getattr(message, "id", None)
if msg_id is not None:
if msg_id in message_buffer._processed_message_ids:
continue
message_buffer._processed_message_ids.add(msg_id)
if msg_id != message_buffer._last_message_id:
message_buffer._last_message_id = msg_id
msg_type, content = classify_message_type(message)
if content and content.strip():
message_buffer.add_message(msg_type, content)
# Add message to buffer
msg_type, content = classify_message_type(last_message)
if content and content.strip():
message_buffer.add_message(msg_type, content)
# Handle tool calls
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict):
message_buffer.add_tool_call(
tool_call["name"], tool_call["args"]
)
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
if isinstance(tool_call, dict):
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
# Update analyst statuses based on report state (runs on every chunk)
update_analyst_statuses(message_buffer, chunk)
update_analyst_statuses(
message_buffer,
chunk,
wall_time_tracker=analyst_wall_time_tracker,
)
# Research Team - Handle Investment Debate State
if chunk.get("investment_debate_state"):
@@ -1165,6 +1181,7 @@ def run_analysis():
message_buffer.add_message(
"System", f"Completed analysis for {selections['analysis_date']}"
)
message_buffer.add_message("System", analyst_wall_time_tracker.format_summary())
# Update final report sections
for section in message_buffer.report_sections.keys():
@@ -1175,6 +1192,7 @@
# Post-analysis prompts (outside Live context for clean interaction)
console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n")
console.print(f"[dim]{analyst_wall_time_tracker.format_summary()}[/dim]")
# Prompt to save report
save_choice = typer.prompt("Save report?", default="Y").strip().upper()
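The switch from a single `_last_message_id` to a `_processed_message_ids` set matters because one stream chunk can carry several messages and later chunks can repeat earlier ones. A self-contained toy repro (the class and data are hypothetical stand-ins):

```python
class Msg:
    # Stand-in for a LangChain message with an `id` attribute.
    def __init__(self, id, content):
        self.id, self.content = id, content

processed = set()
chunks = [
    {"messages": [Msg("a", "hi"), Msg("b", "news")]},    # two messages in one chunk
    {"messages": [Msg("b", "news"), Msg("c", "done")]},  # "b" repeats in the next chunk
]
for chunk in chunks:
    for m in chunk.get("messages", []):
        if m.id in processed:
            continue
        processed.add(m.id)
        print(m.content)  # prints hi, news, done — each exactly once
```

Tracking only the last ID would have skipped "hi" in the first chunk and re-emitted "news" in the second.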

View File

@@ -174,17 +174,30 @@ def select_openrouter_model() -> str:
return choice
def select_shallow_thinking_agent(provider) -> str:
"""Select shallow thinking llm engine using an interactive selection."""
def _prompt_custom_model_id() -> str:
"""Prompt user to type a custom model ID."""
return questionary.text(
"Enter model ID:",
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
).ask().strip()
def _select_model(provider: str, mode: str) -> str:
"""Select a model for the given provider and mode (quick/deep)."""
if provider.lower() == "openrouter":
return select_openrouter_model()
if provider.lower() == "azure":
return questionary.text(
f"Enter Azure deployment name ({mode}-thinking):",
validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
).ask().strip()
choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:",
f"Select Your [{mode.title()}-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in get_model_options(provider, "quick")
for display, value in get_model_options(provider, mode)
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
@@ -197,58 +210,45 @@ def select_shallow_thinking_agent(provider) -> str:
).ask()
if choice is None:
console.print(
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
)
console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
exit(1)
if choice == "custom":
return _prompt_custom_model_id()
return choice
def select_shallow_thinking_agent(provider) -> str:
"""Select shallow thinking llm engine using an interactive selection."""
return _select_model(provider, "quick")
def select_deep_thinking_agent(provider) -> str:
"""Select deep thinking llm engine using an interactive selection."""
if provider.lower() == "openrouter":
return select_openrouter_model()
choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in get_model_options(provider, "deep")
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
if choice is None:
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
exit(1)
return choice
return _select_model(provider, "deep")
def select_llm_provider() -> tuple[str, str | None]:
"""Select the LLM provider and its API endpoint."""
BASE_URLS = [
("OpenAI", "https://api.openai.com/v1"),
("Google", None), # google-genai SDK manages its own endpoint
("Anthropic", "https://api.anthropic.com/"),
("xAI", "https://api.x.ai/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
# (display_name, provider_key, base_url)
PROVIDERS = [
("OpenAI", "openai", "https://api.openai.com/v1"),
("Google", "google", None),
("Anthropic", "anthropic", "https://api.anthropic.com/"),
("xAI", "xai", "https://api.x.ai/v1"),
("DeepSeek", "deepseek", "https://api.deepseek.com"),
("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
("Azure OpenAI", "azure", None),
("Ollama", "ollama", "http://localhost:11434/v1"),
]
choice = questionary.select(
"Select your LLM Provider:",
choices=[
questionary.Choice(display, value=(display, value))
for display, value in BASE_URLS
questionary.Choice(display, value=(provider_key, url))
for display, provider_key, url in PROVIDERS
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
@@ -261,13 +261,11 @@ def select_llm_provider() -> tuple[str, str | None]:
).ask()
if choice is None:
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
exit(1)
display_name, url = choice
print(f"You selected: {display_name}\tURL: {url}")
return display_name, url
provider, url = choice
return provider, url
def ask_openai_reasoning_effort() -> str:
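Since `select_llm_provider` now returns the lowercase provider key rather than the display name, the key can flow straight into the client factory. A hypothetical wiring sketch; the real call sites are outside this diff:

```python
# Illustrative only — these prompts are interactive.
provider, base_url = select_llm_provider()       # e.g. ("deepseek", "https://api.deepseek.com")
quick = select_shallow_thinking_agent(provider)  # may be a typed-in custom model ID
deep = select_deep_thinking_agent(provider)
```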

View File

@@ -4,7 +4,7 @@ services:
env_file:
- .env
volumes:
- ./results:/home/appuser/app/results
- tradingagents_data:/home/appuser/.tradingagents
tty: true
stdin_open: true
@@ -22,7 +22,7 @@ services:
environment:
- LLM_PROVIDER=ollama
volumes:
- ./results:/home/appuser/app/results
- tradingagents_data:/home/appuser/.tradingagents
depends_on:
- ollama
tty: true
@@ -31,4 +31,5 @@ services:
- ollama
volumes:
tradingagents_data:
ollama_data:

View File

@@ -0,0 +1,84 @@
import unittest
from tradingagents.graph.analyst_execution import (
AnalystWallTimeTracker,
build_analyst_execution_plan,
get_initial_analyst_node,
sync_analyst_tracker_from_chunk,
)
class AnalystExecutionPlanTests(unittest.TestCase):
def test_build_plan_preserves_selected_order(self):
plan = build_analyst_execution_plan(["news", "market"], concurrency_limit=2)
self.assertEqual([spec.key for spec in plan.specs], ["news", "market"])
self.assertEqual(plan.concurrency_limit, 2)
self.assertEqual(plan.specs[0].agent_node, "News Analyst")
self.assertEqual(plan.specs[0].tool_node, "tools_news")
self.assertEqual(plan.specs[0].clear_node, "Msg Clear News")
def test_rejects_unknown_analyst_keys(self):
with self.assertRaises(ValueError):
build_analyst_execution_plan(["market", "macro"])
def test_requires_positive_concurrency_limit(self):
with self.assertRaises(ValueError):
build_analyst_execution_plan(["market"], concurrency_limit=0)
def test_get_initial_analyst_node_uses_plan_metadata(self):
plan = build_analyst_execution_plan(["fundamentals", "news"])
self.assertEqual(
get_initial_analyst_node(plan),
"Fundamentals Analyst",
)
class AnalystWallTimeTrackerTests(unittest.TestCase):
def test_records_wall_time_when_analyst_completes(self):
plan = build_analyst_execution_plan(["market", "news"])
tracker = AnalystWallTimeTracker(plan)
tracker.mark_started("market", started_at=10.0)
tracker.mark_completed("market", completed_at=13.5)
self.assertEqual(tracker.get_wall_times(), {"market": 3.5})
def test_formats_summary_in_plan_order(self):
plan = build_analyst_execution_plan(["news", "market"])
tracker = AnalystWallTimeTracker(plan)
tracker.mark_started("market", started_at=20.0)
tracker.mark_completed("market", completed_at=22.25)
tracker.mark_started("news", started_at=10.0)
tracker.mark_completed("news", completed_at=14.0)
self.assertEqual(
tracker.format_summary(),
"Analyst wall time: News 4.00s | Market 2.25s",
)
def test_syncs_wall_time_from_sequential_chunks(self):
plan = build_analyst_execution_plan(["market", "news"])
tracker = AnalystWallTimeTracker(plan)
sync_analyst_tracker_from_chunk(tracker, {}, now=10.0)
self.assertEqual(tracker.get_wall_times(), {})
sync_analyst_tracker_from_chunk(
tracker,
{"market_report": "done"},
now=13.0,
)
self.assertEqual(tracker.get_wall_times(), {"market": 3.0})
sync_analyst_tracker_from_chunk(
tracker,
{"market_report": "done", "news_report": "done"},
now=18.0,
)
self.assertEqual(
tracker.get_wall_times(),
{"market": 3.0, "news": 5.0},
)

View File

@@ -78,7 +78,7 @@ class FinancialSituationMemory:
# Build results
results = []
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0
for idx in top_indices:
# Normalize score to 0-1 range for consistency
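The guarded normalization handles both an empty result set and all-zero scores; a sketch assuming `scores` is a NumPy array (which the `.max()` call implies):

```python
import numpy as np

def norm_factor(scores: np.ndarray) -> float:
    # Fall back to 1.0 so the later score division never divides by zero
    # and never calls max() on an empty array.
    return float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0

print(norm_factor(np.array([0.2, 0.8])))  # 0.8
print(norm_factor(np.array([])))          # 1.0 (old code raised ValueError here)
print(norm_factor(np.zeros(3)))           # 1.0 (avoids divide-by-zero)
```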

View File

@@ -1,12 +1,11 @@
import os
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "gpt-5.4",
@@ -23,6 +22,7 @@ DEFAULT_CONFIG = {
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
"max_recur_limit": 100,
"analyst_concurrency_limit": 1,
# Data vendor configuration
# Category-level configuration (default for all tools in category)
"data_vendors": {

View File

@@ -0,0 +1,136 @@
from dataclasses import dataclass
from time import monotonic
from typing import Dict, Iterable, List, Optional
@dataclass(frozen=True)
class AnalystNodeSpec:
key: str
agent_node: str
clear_node: str
tool_node: str
report_key: str
@dataclass(frozen=True)
class AnalystExecutionPlan:
specs: List[AnalystNodeSpec]
concurrency_limit: int
ANALYST_NODE_SPECS: Dict[str, AnalystNodeSpec] = {
"market": AnalystNodeSpec(
key="market",
agent_node="Market Analyst",
clear_node="Msg Clear Market",
tool_node="tools_market",
report_key="market_report",
),
"social": AnalystNodeSpec(
key="social",
agent_node="Social Analyst",
clear_node="Msg Clear Social",
tool_node="tools_social",
report_key="sentiment_report",
),
"news": AnalystNodeSpec(
key="news",
agent_node="News Analyst",
clear_node="Msg Clear News",
tool_node="tools_news",
report_key="news_report",
),
"fundamentals": AnalystNodeSpec(
key="fundamentals",
agent_node="Fundamentals Analyst",
clear_node="Msg Clear Fundamentals",
tool_node="tools_fundamentals",
report_key="fundamentals_report",
),
}
def build_analyst_execution_plan(
selected_analysts: Iterable[str],
concurrency_limit: int = 1,
) -> AnalystExecutionPlan:
if concurrency_limit < 1:
raise ValueError("analyst concurrency limit must be >= 1")
specs: List[AnalystNodeSpec] = []
for analyst_key in selected_analysts:
spec = ANALYST_NODE_SPECS.get(analyst_key)
if spec is None:
raise ValueError(f"unknown analyst key: {analyst_key}")
specs.append(spec)
if not specs:
raise ValueError("at least one analyst must be selected")
return AnalystExecutionPlan(specs=specs, concurrency_limit=concurrency_limit)
def get_initial_analyst_node(plan: AnalystExecutionPlan) -> str:
return plan.specs[0].agent_node
class AnalystWallTimeTracker:
def __init__(self, plan: AnalystExecutionPlan):
self.plan = plan
self._started_at: Dict[str, float] = {}
self._wall_times: Dict[str, float] = {}
def mark_started(self, analyst_key: str, started_at: Optional[float] = None) -> None:
if analyst_key not in ANALYST_NODE_SPECS:
raise ValueError(f"unknown analyst key: {analyst_key}")
self._started_at.setdefault(analyst_key, monotonic() if started_at is None else started_at)
def mark_completed(
self,
analyst_key: str,
completed_at: Optional[float] = None,
) -> None:
if analyst_key not in ANALYST_NODE_SPECS:
raise ValueError(f"unknown analyst key: {analyst_key}")
if analyst_key in self._wall_times:
return
started_at = self._started_at.get(analyst_key)
if started_at is None:
return
finished_at = monotonic() if completed_at is None else completed_at
self._wall_times[analyst_key] = max(0.0, finished_at - started_at)
def get_wall_times(self) -> Dict[str, float]:
return dict(self._wall_times)
def format_summary(self) -> str:
parts = []
for spec in self.plan.specs:
duration = self._wall_times.get(spec.key)
if duration is not None:
label = spec.agent_node.removesuffix(" Analyst")
parts.append(f"{label} {duration:.2f}s")
if not parts:
return "Analyst wall time: pending"
return "Analyst wall time: " + " | ".join(parts)
def sync_analyst_tracker_from_chunk(
tracker: AnalystWallTimeTracker,
chunk: Dict[str, str],
now: Optional[float] = None,
) -> None:
current_time = monotonic() if now is None else now
active_found = False
for spec in tracker.plan.specs:
has_report = bool(chunk.get(spec.report_key))
if has_report:
tracker.mark_started(spec.key, started_at=current_time)
tracker.mark_completed(spec.key, completed_at=current_time)
continue
if not active_found:
tracker.mark_started(spec.key, started_at=current_time)
active_found = True
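A minimal usage sketch of this module, mirroring the unit tests earlier in the diff:

```python
from tradingagents.graph.analyst_execution import (
    AnalystWallTimeTracker,
    build_analyst_execution_plan,
)

plan = build_analyst_execution_plan(["market", "news"], concurrency_limit=1)
tracker = AnalystWallTimeTracker(plan)
tracker.mark_started("market", started_at=0.0)
tracker.mark_completed("market", completed_at=2.5)
print(tracker.format_summary())  # Analyst wall time: Market 2.50s
print(tracker.get_wall_times())  # {'market': 2.5}
```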

View File

@@ -7,6 +7,7 @@ from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.agents.utils.agent_states import AgentState
from .analyst_execution import build_analyst_execution_plan
from .conditional_logic import ConditionalLogic
@@ -24,6 +25,7 @@ class GraphSetup:
invest_judge_memory,
portfolio_manager_memory,
conditional_logic: ConditionalLogic,
analyst_concurrency_limit: int = 1,
):
"""Initialize with required components."""
self.quick_thinking_llm = quick_thinking_llm
@@ -35,6 +37,7 @@
self.invest_judge_memory = invest_judge_memory
self.portfolio_manager_memory = portfolio_manager_memory
self.conditional_logic = conditional_logic
self.analyst_concurrency_limit = analyst_concurrency_limit
def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
@@ -48,41 +51,17 @@
- "news": News analyst
- "fundamentals": Fundamentals analyst
"""
if len(selected_analysts) == 0:
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
plan = build_analyst_execution_plan(
selected_analysts,
concurrency_limit=self.analyst_concurrency_limit,
)
# Create analyst nodes
analyst_nodes = {}
delete_nodes = {}
tool_nodes = {}
if "market" in selected_analysts:
analyst_nodes["market"] = create_market_analyst(
self.quick_thinking_llm
)
delete_nodes["market"] = create_msg_delete()
tool_nodes["market"] = self.tool_nodes["market"]
if "social" in selected_analysts:
analyst_nodes["social"] = create_social_media_analyst(
self.quick_thinking_llm
)
delete_nodes["social"] = create_msg_delete()
tool_nodes["social"] = self.tool_nodes["social"]
if "news" in selected_analysts:
analyst_nodes["news"] = create_news_analyst(
self.quick_thinking_llm
)
delete_nodes["news"] = create_msg_delete()
tool_nodes["news"] = self.tool_nodes["news"]
if "fundamentals" in selected_analysts:
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
self.quick_thinking_llm
)
delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
analyst_factories = {
"market": lambda: create_market_analyst(self.quick_thinking_llm),
"social": lambda: create_social_media_analyst(self.quick_thinking_llm),
"news": lambda: create_news_analyst(self.quick_thinking_llm),
"fundamentals": lambda: create_fundamentals_analyst(self.quick_thinking_llm),
}
# Create researcher and manager nodes
bull_researcher_node = create_bull_researcher(
@@ -108,12 +87,10 @@ class GraphSetup:
workflow = StateGraph(AgentState)
# Add analyst nodes to the graph
for analyst_type, node in analyst_nodes.items():
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
workflow.add_node(
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
)
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
for spec in plan.specs:
workflow.add_node(spec.agent_node, analyst_factories[spec.key]())
workflow.add_node(spec.clear_node, create_msg_delete())
workflow.add_node(spec.tool_node, self.tool_nodes[spec.key])
# Add other nodes
workflow.add_node("Bull Researcher", bull_researcher_node)
@@ -127,27 +104,25 @@
# Define edges
# Start with the first analyst
first_analyst = selected_analysts[0]
workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst")
workflow.add_edge(START, plan.specs[0].agent_node)
# Connect analysts in sequence
for i, analyst_type in enumerate(selected_analysts):
current_analyst = f"{analyst_type.capitalize()} Analyst"
current_tools = f"tools_{analyst_type}"
current_clear = f"Msg Clear {analyst_type.capitalize()}"
for i, spec in enumerate(plan.specs):
current_analyst = spec.agent_node
current_tools = spec.tool_node
current_clear = spec.clear_node
# Add conditional edges for current analyst
workflow.add_conditional_edges(
current_analyst,
getattr(self.conditional_logic, f"should_continue_{analyst_type}"),
getattr(self.conditional_logic, f"should_continue_{spec.key}"),
[current_tools, current_clear],
)
workflow.add_edge(current_tools, current_analyst)
# Connect to next analyst or to Bull Researcher if this is the last analyst
if i < len(selected_analysts) - 1:
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
workflow.add_edge(current_clear, next_analyst)
if i < len(plan.specs) - 1:
workflow.add_edge(current_clear, plan.specs[i + 1].agent_node)
else:
workflow.add_edge(current_clear, "Bull Researcher")
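With the spec-driven wiring, node names come from the `ANALYST_NODE_SPECS` table instead of `capitalize()` string-building, so graph node names and CLI status labels can't drift apart. For instance:

```python
from tradingagents.graph.analyst_execution import build_analyst_execution_plan

plan = build_analyst_execution_plan(["market", "social"])
for spec in plan.specs:
    print(spec.agent_node, "->", spec.tool_node, "->", spec.clear_node)
# Market Analyst -> tools_market -> Msg Clear Market
# Social Analyst -> tools_social -> Msg Clear Social
```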

View File

@@ -66,10 +66,8 @@ class TradingAgentsGraph:
set_config(self.config)
# Create necessary directories
os.makedirs(
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
exist_ok=True,
)
os.makedirs(self.config["data_cache_dir"], exist_ok=True)
os.makedirs(self.config["results_dir"], exist_ok=True)
# Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs()
@@ -119,6 +117,7 @@ class TradingAgentsGraph:
self.invest_judge_memory,
self.portfolio_manager_memory,
self.conditional_logic,
analyst_concurrency_limit=self.config.get("analyst_concurrency_limit", 1),
)
self.propagator = Propagator()

View File

@@ -0,0 +1,52 @@
import os
from typing import Any, Optional
from langchain_openai import AzureChatOpenAI
from .base_client import BaseLLMClient, normalize_content
from .validators import validate_model
_PASSTHROUGH_KWARGS = (
"timeout", "max_retries", "api_key", "reasoning_effort",
"callbacks", "http_client", "http_async_client",
)
class NormalizedAzureChatOpenAI(AzureChatOpenAI):
"""AzureChatOpenAI with normalized content output."""
def invoke(self, input, config=None, **kwargs):
return normalize_content(super().invoke(input, config, **kwargs))
class AzureOpenAIClient(BaseLLMClient):
"""Client for Azure OpenAI deployments.
Environment variables:
AZURE_OPENAI_API_KEY: API key (required)
AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://<resource>.openai.azure.com/) (required)
AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name (defaults to the model name)
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview); required only for the non-v1 API
"""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:
"""Return configured AzureChatOpenAI instance."""
self.warn_if_unknown_model()
llm_kwargs = {
"model": self.model,
"azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", self.model),
}
for key in _PASSTHROUGH_KWARGS:
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
return NormalizedAzureChatOpenAI(**llm_kwargs)
def validate_model(self) -> bool:
"""Azure accepts any deployed model name."""
return True
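A hypothetical smoke test for this client, assuming the variables from `.env.enterprise.example` are set (the deployment name here is illustrative):

```python
import os

# AzureOpenAIClient is the class defined above; its module path is not shown in this diff.
os.environ.setdefault("AZURE_OPENAI_API_KEY", "<key>")
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://your-resource-name.openai.azure.com/")
os.environ.setdefault("AZURE_OPENAI_DEPLOYMENT_NAME", "my-gpt4o-deployment")  # hypothetical
os.environ.setdefault("OPENAI_API_VERSION", "2025-03-01-preview")

client = AzureOpenAIClient(model="gpt-4o", timeout=60)
llm = client.get_llm()  # NormalizedAzureChatOpenAI bound to the deployment
```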

View File

@@ -4,6 +4,12 @@ from .base_client import BaseLLMClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
from .azure_client import AzureOpenAIClient
# Providers that use the OpenAI-compatible chat completions API
_OPENAI_COMPATIBLE = (
"openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
)
def create_llm_client(
@@ -15,16 +21,10 @@
"""Create an LLM client for the specified provider.
Args:
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
provider: LLM provider name
model: Model name/identifier
base_url: Optional base URL for API endpoint
**kwargs: Additional provider-specific arguments
- http_client: Custom httpx.Client for SSL proxy or certificate customization
- http_async_client: Custom httpx.AsyncClient for async operations
- timeout: Request timeout in seconds
- max_retries: Maximum retry attempts
- api_key: API key for the provider
- callbacks: LangChain callbacks
Returns:
Configured BaseLLMClient instance
@@ -34,16 +34,16 @@
"""
provider_lower = provider.lower()
if provider_lower in ("openai", "ollama", "openrouter"):
if provider_lower in _OPENAI_COMPATIBLE:
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
if provider_lower == "xai":
return OpenAIClient(model, base_url, provider="xai", **kwargs)
if provider_lower == "anthropic":
return AnthropicClient(model, base_url, **kwargs)
if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs)
if provider_lower == "azure":
return AzureOpenAIClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}")
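Sketch of the dispatch after this change — all OpenAI-compatible providers share `OpenAIClient`, so adding DeepSeek/Qwen/GLM required no new client classes:

```python
# Using create_llm_client as defined above; model IDs come from MODEL_OPTIONS below.
create_llm_client("deepseek", "deepseek-chat")  # -> OpenAIClient (OpenAI-compatible)
create_llm_client("glm", "glm-5")               # -> OpenAIClient (OpenAI-compatible)
create_llm_client("azure", "gpt-4o")            # -> AzureOpenAIClient
```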

View File

@@ -63,8 +63,43 @@ MODEL_OPTIONS: ProviderModeOptions = {
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
],
},
# OpenRouter models are fetched dynamically at CLI runtime.
# No static entries needed; any model ID is accepted by the validator.
"deepseek": {
"quick": [
("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),
],
"deep": [
("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),
],
},
"qwen": {
"quick": [
("Qwen 3.5 Flash", "qwen3.5-flash"),
("Qwen Plus", "qwen-plus"),
("Custom model ID", "custom"),
],
"deep": [
("Qwen 3.6 Plus", "qwen3.6-plus"),
("Qwen 3.5 Plus", "qwen3.5-plus"),
("Qwen 3 Max", "qwen3-max"),
("Custom model ID", "custom"),
],
},
"glm": {
"quick": [
("GLM-4.7", "glm-4.7"),
("GLM-5", "glm-5"),
("Custom model ID", "custom"),
],
"deep": [
("GLM-5.1", "glm-5.1"),
("GLM-5", "glm-5"),
("Custom model ID", "custom"),
],
},
# OpenRouter: fetched dynamically. Azure: any deployed model name.
"ollama": {
"quick": [
("Qwen3:latest (8B, local)", "qwen3:latest"),

View File

@@ -27,6 +27,9 @@ _PASSTHROUGH_KWARGS = (
# Provider base URLs and API key env vars
_PROVIDER_CONFIG = {
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
"deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
"qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
"glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
"ollama": ("http://localhost:11434/v1", None),
}