From 23c4ad398807e0412699f5d305d52837f1450358 Mon Sep 17 00:00:00 2001 From: Santiago de Diego Date: Wed, 25 Mar 2026 19:16:10 +0100 Subject: [PATCH] Support for Perplexity API --- .env.example | 1 + .gitignore | 4 +++ cli/utils.py | 12 ++++++++ .../agents/analysts/fundamentals_analyst.py | 25 ++++++++++++---- .../agents/analysts/market_analyst.py | 26 ++++++++++++---- tradingagents/agents/analysts/news_analyst.py | 25 ++++++++++++---- .../agents/analysts/social_media_analyst.py | 30 ++++++++++++++----- tradingagents/agents/utils/agent_utils.py | 27 +++++++++++++++++ tradingagents/llm_clients/factory.py | 6 ++-- tradingagents/llm_clients/openai_client.py | 3 +- tradingagents/llm_clients/validators.py | 10 +++++++ 11 files changed, 141 insertions(+), 28 deletions(-) diff --git a/.env.example b/.env.example index 1328b838..044f2445 100644 --- a/.env.example +++ b/.env.example @@ -3,4 +3,5 @@ OPENAI_API_KEY= GOOGLE_API_KEY= ANTHROPIC_API_KEY= XAI_API_KEY= +PERPLEXITY_API_KEY= OPENROUTER_API_KEY= diff --git a/.gitignore b/.gitignore index 9a2904a9..28d87076 100644 --- a/.gitignore +++ b/.gitignore @@ -217,3 +217,7 @@ __marimo__/ # Cache **/data_cache/ + +reports/ + + diff --git a/cli/utils.py b/cli/utils.py index 18abc3a7..749423e1 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -162,6 +162,11 @@ def select_shallow_thinking_agent(provider) -> str: ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), ], + "perplexity": [ + ("Sonar - Fast online search", "sonar"), + ("Sonar Reasoning - Reasoning with search", "sonar-reasoning"), + ("Sonar Pro - Advanced online search", "sonar-pro"), + ], "openrouter": [ ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), @@ -229,6 +234,12 @@ def select_deep_thinking_agent(provider) -> str: ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), ], + "perplexity": [ + ("Sonar Pro - Advanced online search", "sonar-pro"), + ("Sonar Reasoning Pro - Advanced reasoning with search", "sonar-reasoning-pro"), + ("Sonar Reasoning - Reasoning with search", "sonar-reasoning"), + ("Sonar Deep Research - In-depth research", "sonar-deep-research"), + ], "openrouter": [ ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), @@ -270,6 +281,7 @@ def select_llm_provider() -> tuple[str, str]: ("Google", "https://generativelanguage.googleapis.com/v1"), ("Anthropic", "https://api.anthropic.com/"), ("xAI", "https://api.x.ai/v1"), + ("Perplexity", "https://api.perplexity.ai"), ("Openrouter", "https://openrouter.ai/api/v1"), ("Ollama", "http://localhost:11434/v1"), ] diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index 990398a6..dd460415 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -1,6 +1,6 @@ +from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json + from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_balance_sheet, @@ -8,6 +8,8 @@ from tradingagents.agents.utils.agent_utils import ( get_fundamentals, get_income_statement, get_insider_transactions, + prefetch_tool_data, + supports_tool_calling, ) from tradingagents.dataflows.config import get_config @@ -52,13 +54,24 @@ def create_fundamentals_analyst(llm): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) - chain = prompt | llm.bind_tools(tools) - - result = chain.invoke(state["messages"]) + if supports_tool_calling(): + chain = prompt | llm.bind_tools(tools) + result = chain.invoke(state["messages"]) + else: + ticker = state["company_of_interest"] + tool_data = prefetch_tool_data(tools, [ + {"ticker": ticker, "curr_date": current_date}, + {"ticker": ticker, "freq": "quarterly", "curr_date": current_date}, + {"ticker": ticker, "freq": "quarterly", "curr_date": current_date}, + {"ticker": ticker, "freq": "quarterly", "curr_date": current_date}, + ]) + result = (prompt | llm).invoke([ + HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched fundamental data:\n\n{tool_data}\n\nWrite your comprehensive report.") + ]) report = "" - if len(result.tool_calls) == 0: + if not getattr(result, "tool_calls", None): report = result.content return { diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index f5d17acd..dc4969c8 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -1,10 +1,14 @@ +from datetime import datetime, timedelta + +from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json + from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_indicators, get_stock_data, + prefetch_tool_data, + supports_tool_calling, ) from tradingagents.dataflows.config import get_config @@ -71,13 +75,23 @@ Volume-Based Indicators: prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) - chain = prompt | llm.bind_tools(tools) - - result = chain.invoke(state["messages"]) + if supports_tool_calling(): + chain = prompt | llm.bind_tools(tools) + result = chain.invoke(state["messages"]) + else: + ticker = state["company_of_interest"] + start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=30)).strftime("%Y-%m-%d") + tool_data = prefetch_tool_data(tools, [ + {"symbol": ticker, "start_date": start_date, "end_date": current_date}, + {"symbol": ticker, "indicator": "rsi,macd,boll,boll_ub,boll_lb,atr,vwma,close_50_sma", "curr_date": current_date, "look_back_days": 30}, + ]) + result = (prompt | llm).invoke([ + HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched market data:\n\n{tool_data}\n\nWrite your comprehensive report.") + ]) report = "" - if len(result.tool_calls) == 0: + if not getattr(result, "tool_calls", None): report = result.content return { diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index 3697c6f6..4d4c0cf1 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -1,10 +1,14 @@ +from datetime import datetime, timedelta + +from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json + from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_global_news, get_news, + prefetch_tool_data, + supports_tool_calling, ) from tradingagents.dataflows.config import get_config @@ -46,12 +50,23 @@ def create_news_analyst(llm): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) - chain = prompt | llm.bind_tools(tools) - result = chain.invoke(state["messages"]) + if supports_tool_calling(): + chain = prompt | llm.bind_tools(tools) + result = chain.invoke(state["messages"]) + else: + ticker = state["company_of_interest"] + start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d") + tool_data = prefetch_tool_data(tools, [ + {"ticker": ticker, "start_date": start_date, "end_date": current_date}, + {"curr_date": current_date, "look_back_days": 7, "limit": 5}, + ]) + result = (prompt | llm).invoke([ + HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched news data:\n\n{tool_data}\n\nWrite your comprehensive report.") + ]) report = "" - if len(result.tool_calls) == 0: + if not getattr(result, "tool_calls", None): report = result.content return { diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 43df2258..0ef9ed45 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,7 +1,14 @@ +from datetime import datetime, timedelta + +from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json -from tradingagents.agents.utils.agent_utils import build_instrument_context, get_news + +from tradingagents.agents.utils.agent_utils import ( + build_instrument_context, + get_news, + prefetch_tool_data, + supports_tool_calling, +) from tradingagents.dataflows.config import get_config @@ -41,13 +48,22 @@ def create_social_media_analyst(llm): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) - chain = prompt | llm.bind_tools(tools) - - result = chain.invoke(state["messages"]) + if supports_tool_calling(): + chain = prompt | llm.bind_tools(tools) + result = chain.invoke(state["messages"]) + else: + ticker = state["company_of_interest"] + start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d") + tool_data = prefetch_tool_data(tools, [ + {"ticker": ticker, "start_date": start_date, "end_date": current_date}, + ]) + result = (prompt | llm).invoke([ + HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched social media and news data:\n\n{tool_data}\n\nWrite your comprehensive report.") + ]) report = "" - if len(result.tool_calls) == 0: + if not getattr(result, "tool_calls", None): report = result.content return { diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index e4abc4cd..8e4d7d91 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -1,5 +1,32 @@ from langchain_core.messages import HumanMessage, RemoveMessage +from tradingagents.dataflows.config import get_config + +# Providers whose APIs do not support tool/function calling. +_NO_TOOL_CALLING_PROVIDERS = frozenset({"perplexity"}) + + +def supports_tool_calling() -> bool: + """Check if the current LLM provider supports tool calling.""" + config = get_config() + return config.get("llm_provider", "openai").lower() not in _NO_TOOL_CALLING_PROVIDERS + + +def prefetch_tool_data(tools, tool_args_list) -> str: + """Pre-call tools and return formatted results for prompt injection. + + Used as a fallback for providers that don't support tool calling. + """ + results = [] + for tool, args in zip(tools, tool_args_list): + try: + data = tool.invoke(args) + results.append(f"=== {tool.name} ===\n{data}") + except Exception as e: + results.append(f"=== {tool.name} ===\n[Error fetching data: {e}]") + return "\n\n".join(results) + + # Import tools from separate utility files from tradingagents.agents.utils.core_stock_tools import ( get_stock_data diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index 93c2a7d3..35566ca1 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -15,7 +15,7 @@ def create_llm_client( """Create an LLM client for the specified provider. Args: - provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter) + provider: LLM provider (openai, anthropic, google, xai, perplexity, ollama, openrouter) model: Model name/identifier base_url: Optional base URL for API endpoint **kwargs: Additional provider-specific arguments @@ -37,8 +37,8 @@ def create_llm_client( if provider_lower in ("openai", "ollama", "openrouter"): return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) - if provider_lower == "xai": - return OpenAIClient(model, base_url, provider="xai", **kwargs) + if provider_lower in ("xai", "perplexity"): + return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) if provider_lower == "anthropic": return AnthropicClient(model, base_url, **kwargs) diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index fd9b4e33..0bcba2ac 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -27,13 +27,14 @@ _PASSTHROUGH_KWARGS = ( # Provider base URLs and API key env vars _PROVIDER_CONFIG = { "xai": ("https://api.x.ai/v1", "XAI_API_KEY"), + "perplexity": ("https://api.perplexity.ai", "PERPLEXITY_API_KEY"), "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"), "ollama": ("http://localhost:11434/v1", None), } class OpenAIClient(BaseLLMClient): - """Client for OpenAI, Ollama, OpenRouter, and xAI providers. + """Client for OpenAI, Ollama, OpenRouter, xAI, and Perplexity providers. For native OpenAI models, uses the Responses API (/v1/responses) which supports reasoning_effort with function tools across all model families diff --git a/tradingagents/llm_clients/validators.py b/tradingagents/llm_clients/validators.py index 1e2388b3..c4d6f675 100644 --- a/tradingagents/llm_clients/validators.py +++ b/tradingagents/llm_clients/validators.py @@ -48,6 +48,16 @@ VALID_MODELS = { "grok-4-fast-reasoning", "grok-4-fast-non-reasoning", ], + "perplexity": [ + # Sonar Pro series + "sonar-pro", + "sonar-reasoning-pro", + # Sonar series + "sonar", + "sonar-reasoning", + # Deep Research + "sonar-deep-research", + ], }