This commit is contained in:
Santiago de Diego 2026-03-26 22:11:58 +01:00 committed by GitHub
commit 0d4db43be8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 138 additions and 28 deletions

View File

@ -3,4 +3,5 @@ OPENAI_API_KEY=
GOOGLE_API_KEY=
ANTHROPIC_API_KEY=
XAI_API_KEY=
PERPLEXITY_API_KEY=
OPENROUTER_API_KEY=

4
.gitignore vendored
View File

@ -217,3 +217,7 @@ __marimo__/
# Cache
**/data_cache/
reports/

View File

@ -162,6 +162,10 @@ def select_shallow_thinking_agent(provider) -> str:
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
],
"perplexity": [
("Sonar - Fast online search", "sonar"),
("Sonar Pro - Advanced online search", "sonar-pro"),
],
"openrouter": [
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
@ -229,6 +233,11 @@ def select_deep_thinking_agent(provider) -> str:
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
],
"perplexity": [
("Sonar Pro - Advanced online search", "sonar-pro"),
("Sonar Reasoning Pro - Advanced reasoning with search", "sonar-reasoning-pro"),
("Sonar Deep Research - In-depth research", "sonar-deep-research"),
],
"openrouter": [
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
@ -270,6 +279,7 @@ def select_llm_provider() -> tuple[str, str]:
("Google", "https://generativelanguage.googleapis.com/v1"),
("Anthropic", "https://api.anthropic.com/"),
("xAI", "https://api.x.ai/v1"),
("Perplexity", "https://api.perplexity.ai"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
]

View File

@ -1,6 +1,6 @@
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_balance_sheet,
@ -8,6 +8,8 @@ from tradingagents.agents.utils.agent_utils import (
get_fundamentals,
get_income_statement,
get_insider_transactions,
prefetch_tool_data,
supports_tool_calling,
)
from tradingagents.dataflows.config import get_config
@ -52,13 +54,24 @@ def create_fundamentals_analyst(llm):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
if supports_tool_calling():
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
else:
ticker = state["company_of_interest"]
tool_data = prefetch_tool_data(tools, [
{"ticker": ticker, "curr_date": current_date},
{"ticker": ticker, "freq": "quarterly", "curr_date": current_date},
{"ticker": ticker, "freq": "quarterly", "curr_date": current_date},
{"ticker": ticker, "freq": "quarterly", "curr_date": current_date},
])
result = (prompt | llm).invoke([
HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched fundamental data:\n\n{tool_data}\n\nWrite your comprehensive report.")
])
report = ""
if len(result.tool_calls) == 0:
if not getattr(result, "tool_calls", None):
report = result.content
return {

View File

@ -1,10 +1,14 @@
from datetime import datetime, timedelta
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_indicators,
get_stock_data,
prefetch_tool_data,
supports_tool_calling,
)
from tradingagents.dataflows.config import get_config
@ -71,13 +75,23 @@ Volume-Based Indicators:
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
if supports_tool_calling():
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
else:
ticker = state["company_of_interest"]
start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=30)).strftime("%Y-%m-%d")
tool_data = prefetch_tool_data(tools, [
{"symbol": ticker, "start_date": start_date, "end_date": current_date},
{"symbol": ticker, "indicator": "rsi,macd,boll,boll_ub,boll_lb,atr,vwma,close_50_sma", "curr_date": current_date, "look_back_days": 30},
])
result = (prompt | llm).invoke([
HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched market data:\n\n{tool_data}\n\nWrite your comprehensive report.")
])
report = ""
if len(result.tool_calls) == 0:
if not getattr(result, "tool_calls", None):
report = result.content
return {

View File

@ -1,10 +1,14 @@
from datetime import datetime, timedelta
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_global_news,
get_news,
prefetch_tool_data,
supports_tool_calling,
)
from tradingagents.dataflows.config import get_config
@ -46,12 +50,23 @@ def create_news_analyst(llm):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
if supports_tool_calling():
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
else:
ticker = state["company_of_interest"]
start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d")
tool_data = prefetch_tool_data(tools, [
{"ticker": ticker, "start_date": start_date, "end_date": current_date},
{"curr_date": current_date, "look_back_days": 7, "limit": 5},
])
result = (prompt | llm).invoke([
HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched news data:\n\n{tool_data}\n\nWrite your comprehensive report.")
])
report = ""
if len(result.tool_calls) == 0:
if not getattr(result, "tool_calls", None):
report = result.content
return {

View File

@ -1,7 +1,14 @@
from datetime import datetime, timedelta
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_news
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_news,
prefetch_tool_data,
supports_tool_calling,
)
from tradingagents.dataflows.config import get_config
@ -41,13 +48,22 @@ def create_social_media_analyst(llm):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
if supports_tool_calling():
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
else:
ticker = state["company_of_interest"]
start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d")
tool_data = prefetch_tool_data(tools, [
{"ticker": ticker, "start_date": start_date, "end_date": current_date},
])
result = (prompt | llm).invoke([
HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched social media and news data:\n\n{tool_data}\n\nWrite your comprehensive report.")
])
report = ""
if len(result.tool_calls) == 0:
if not getattr(result, "tool_calls", None):
report = result.content
return {

View File

@ -1,5 +1,32 @@
from langchain_core.messages import HumanMessage, RemoveMessage
from tradingagents.dataflows.config import get_config
# Providers whose APIs do not support tool/function calling.
_NO_TOOL_CALLING_PROVIDERS = frozenset({"perplexity"})
def supports_tool_calling() -> bool:
"""Check if the current LLM provider supports tool calling."""
config = get_config()
return config.get("llm_provider", "openai").lower() not in _NO_TOOL_CALLING_PROVIDERS
def prefetch_tool_data(tools, tool_args_list) -> str:
"""Pre-call tools and return formatted results for prompt injection.
Used as a fallback for providers that don't support tool calling.
"""
results = []
for tool, args in zip(tools, tool_args_list):
try:
data = tool.invoke(args)
results.append(f"=== {tool.name} ===\n{data}")
except Exception as e:
results.append(f"=== {tool.name} ===\n[Error fetching data: {e}]")
return "\n\n".join(results)
# Import tools from separate utility files
from tradingagents.agents.utils.core_stock_tools import (
get_stock_data

View File

@ -15,7 +15,7 @@ def create_llm_client(
"""Create an LLM client for the specified provider.
Args:
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
provider: LLM provider (openai, anthropic, google, xai, perplexity, ollama, openrouter)
model: Model name/identifier
base_url: Optional base URL for API endpoint
**kwargs: Additional provider-specific arguments
@ -37,8 +37,8 @@ def create_llm_client(
if provider_lower in ("openai", "ollama", "openrouter"):
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
if provider_lower == "xai":
return OpenAIClient(model, base_url, provider="xai", **kwargs)
if provider_lower in ("xai", "perplexity"):
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
if provider_lower == "anthropic":
return AnthropicClient(model, base_url, **kwargs)

View File

@ -27,13 +27,14 @@ _PASSTHROUGH_KWARGS = (
# Provider base URLs and API key env vars
_PROVIDER_CONFIG = {
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
"perplexity": ("https://api.perplexity.ai", "PERPLEXITY_API_KEY"),
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
"ollama": ("http://localhost:11434/v1", None),
}
class OpenAIClient(BaseLLMClient):
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers.
"""Client for OpenAI, Ollama, OpenRouter, xAI, and Perplexity providers.
For native OpenAI models, uses the Responses API (/v1/responses) which
supports reasoning_effort with function tools across all model families

View File

@ -48,6 +48,15 @@ VALID_MODELS = {
"grok-4-fast-reasoning",
"grok-4-fast-non-reasoning",
],
"perplexity": [
# Sonar Pro series
"sonar-pro",
"sonar-reasoning-pro",
# Sonar series
"sonar",
# Deep Research
"sonar-deep-research",
],
}