Support for Perplexity API
This commit is contained in:
parent
589b351f2a
commit
23c4ad3988
|
|
@ -3,4 +3,5 @@ OPENAI_API_KEY=
|
|||
GOOGLE_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
XAI_API_KEY=
|
||||
PERPLEXITY_API_KEY=
|
||||
OPENROUTER_API_KEY=
|
||||
|
|
|
|||
|
|
@ -217,3 +217,7 @@ __marimo__/
|
|||
|
||||
# Cache
|
||||
**/data_cache/
|
||||
|
||||
reports/
|
||||
|
||||
|
||||
|
|
|
|||
12
cli/utils.py
12
cli/utils.py
|
|
@ -162,6 +162,11 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
],
|
||||
"perplexity": [
|
||||
("Sonar - Fast online search", "sonar"),
|
||||
("Sonar Reasoning - Reasoning with search", "sonar-reasoning"),
|
||||
("Sonar Pro - Advanced online search", "sonar-pro"),
|
||||
],
|
||||
"openrouter": [
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
|
|
@ -229,6 +234,12 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
],
|
||||
"perplexity": [
|
||||
("Sonar Pro - Advanced online search", "sonar-pro"),
|
||||
("Sonar Reasoning Pro - Advanced reasoning with search", "sonar-reasoning-pro"),
|
||||
("Sonar Reasoning - Reasoning with search", "sonar-reasoning"),
|
||||
("Sonar Deep Research - In-depth research", "sonar-deep-research"),
|
||||
],
|
||||
"openrouter": [
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
|
|
@ -270,6 +281,7 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
||||
("Anthropic", "https://api.anthropic.com/"),
|
||||
("xAI", "https://api.x.ai/v1"),
|
||||
("Perplexity", "https://api.perplexity.ai"),
|
||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_balance_sheet,
|
||||
|
|
@ -8,6 +8,8 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_fundamentals,
|
||||
get_income_statement,
|
||||
get_insider_transactions,
|
||||
prefetch_tool_data,
|
||||
supports_tool_calling,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
|
@ -52,13 +54,24 @@ def create_fundamentals_analyst(llm):
|
|||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
if supports_tool_calling():
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
else:
|
||||
ticker = state["company_of_interest"]
|
||||
tool_data = prefetch_tool_data(tools, [
|
||||
{"ticker": ticker, "curr_date": current_date},
|
||||
{"ticker": ticker, "freq": "quarterly", "curr_date": current_date},
|
||||
{"ticker": ticker, "freq": "quarterly", "curr_date": current_date},
|
||||
{"ticker": ticker, "freq": "quarterly", "curr_date": current_date},
|
||||
])
|
||||
result = (prompt | llm).invoke([
|
||||
HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched fundamental data:\n\n{tool_data}\n\nWrite your comprehensive report.")
|
||||
])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
if not getattr(result, "tool_calls", None):
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_indicators,
|
||||
get_stock_data,
|
||||
prefetch_tool_data,
|
||||
supports_tool_calling,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
|
@ -71,13 +75,23 @@ Volume-Based Indicators:
|
|||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
if supports_tool_calling():
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
else:
|
||||
ticker = state["company_of_interest"]
|
||||
start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=30)).strftime("%Y-%m-%d")
|
||||
tool_data = prefetch_tool_data(tools, [
|
||||
{"symbol": ticker, "start_date": start_date, "end_date": current_date},
|
||||
{"symbol": ticker, "indicator": "rsi,macd,boll,boll_ub,boll_lb,atr,vwma,close_50_sma", "curr_date": current_date, "look_back_days": 30},
|
||||
])
|
||||
result = (prompt | llm).invoke([
|
||||
HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched market data:\n\n{tool_data}\n\nWrite your comprehensive report.")
|
||||
])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
if not getattr(result, "tool_calls", None):
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_global_news,
|
||||
get_news,
|
||||
prefetch_tool_data,
|
||||
supports_tool_calling,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
|
@ -46,12 +50,23 @@ def create_news_analyst(llm):
|
|||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
if supports_tool_calling():
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
else:
|
||||
ticker = state["company_of_interest"]
|
||||
start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
tool_data = prefetch_tool_data(tools, [
|
||||
{"ticker": ticker, "start_date": start_date, "end_date": current_date},
|
||||
{"curr_date": current_date, "look_back_days": 7, "limit": 5},
|
||||
])
|
||||
result = (prompt | llm).invoke([
|
||||
HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched news data:\n\n{tool_data}\n\nWrite your comprehensive report.")
|
||||
])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
if not getattr(result, "tool_calls", None):
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,14 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_news
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_news,
|
||||
prefetch_tool_data,
|
||||
supports_tool_calling,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -41,13 +48,22 @@ def create_social_media_analyst(llm):
|
|||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
if supports_tool_calling():
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
else:
|
||||
ticker = state["company_of_interest"]
|
||||
start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
tool_data = prefetch_tool_data(tools, [
|
||||
{"ticker": ticker, "start_date": start_date, "end_date": current_date},
|
||||
])
|
||||
result = (prompt | llm).invoke([
|
||||
HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched social media and news data:\n\n{tool_data}\n\nWrite your comprehensive report.")
|
||||
])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
if not getattr(result, "tool_calls", None):
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,32 @@
|
|||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
# Providers whose APIs do not support tool/function calling.
|
||||
_NO_TOOL_CALLING_PROVIDERS = frozenset({"perplexity"})
|
||||
|
||||
|
||||
def supports_tool_calling() -> bool:
|
||||
"""Check if the current LLM provider supports tool calling."""
|
||||
config = get_config()
|
||||
return config.get("llm_provider", "openai").lower() not in _NO_TOOL_CALLING_PROVIDERS
|
||||
|
||||
|
||||
def prefetch_tool_data(tools, tool_args_list) -> str:
|
||||
"""Pre-call tools and return formatted results for prompt injection.
|
||||
|
||||
Used as a fallback for providers that don't support tool calling.
|
||||
"""
|
||||
results = []
|
||||
for tool, args in zip(tools, tool_args_list):
|
||||
try:
|
||||
data = tool.invoke(args)
|
||||
results.append(f"=== {tool.name} ===\n{data}")
|
||||
except Exception as e:
|
||||
results.append(f"=== {tool.name} ===\n[Error fetching data: {e}]")
|
||||
return "\n\n".join(results)
|
||||
|
||||
|
||||
# Import tools from separate utility files
|
||||
from tradingagents.agents.utils.core_stock_tools import (
|
||||
get_stock_data
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ def create_llm_client(
|
|||
"""Create an LLM client for the specified provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
||||
provider: LLM provider (openai, anthropic, google, xai, perplexity, ollama, openrouter)
|
||||
model: Model name/identifier
|
||||
base_url: Optional base URL for API endpoint
|
||||
**kwargs: Additional provider-specific arguments
|
||||
|
|
@ -37,8 +37,8 @@ def create_llm_client(
|
|||
if provider_lower in ("openai", "ollama", "openrouter"):
|
||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
|
||||
if provider_lower == "xai":
|
||||
return OpenAIClient(model, base_url, provider="xai", **kwargs)
|
||||
if provider_lower in ("xai", "perplexity"):
|
||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
|
||||
if provider_lower == "anthropic":
|
||||
return AnthropicClient(model, base_url, **kwargs)
|
||||
|
|
|
|||
|
|
@ -27,13 +27,14 @@ _PASSTHROUGH_KWARGS = (
|
|||
# Provider base URLs and API key env vars
|
||||
_PROVIDER_CONFIG = {
|
||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
||||
"perplexity": ("https://api.perplexity.ai", "PERPLEXITY_API_KEY"),
|
||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
||||
"ollama": ("http://localhost:11434/v1", None),
|
||||
}
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMClient):
|
||||
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers.
|
||||
"""Client for OpenAI, Ollama, OpenRouter, xAI, and Perplexity providers.
|
||||
|
||||
For native OpenAI models, uses the Responses API (/v1/responses) which
|
||||
supports reasoning_effort with function tools across all model families
|
||||
|
|
|
|||
|
|
@ -48,6 +48,16 @@ VALID_MODELS = {
|
|||
"grok-4-fast-reasoning",
|
||||
"grok-4-fast-non-reasoning",
|
||||
],
|
||||
"perplexity": [
|
||||
# Sonar Pro series
|
||||
"sonar-pro",
|
||||
"sonar-reasoning-pro",
|
||||
# Sonar series
|
||||
"sonar",
|
||||
"sonar-reasoning",
|
||||
# Deep Research
|
||||
"sonar-deep-research",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue