Merge 90482e8b39 into 589b351f2a

commit 0d4db43be8
@@ -3,4 +3,5 @@ OPENAI_API_KEY=
 GOOGLE_API_KEY=
 ANTHROPIC_API_KEY=
 XAI_API_KEY=
+PERPLEXITY_API_KEY=
 OPENROUTER_API_KEY=
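For reference, the new key is consumed like the existing provider keys; a minimal sketch of reading it, assuming the project loads its `.env` with python-dotenv (the loader itself is not part of this diff):

```python
import os

from dotenv import load_dotenv  # assumption: python-dotenv is already a dependency

load_dotenv()  # pulls PERPLEXITY_API_KEY (and the other keys) into the process env
api_key = os.getenv("PERPLEXITY_API_KEY")
if not api_key:
    raise RuntimeError("PERPLEXITY_API_KEY is not set")
```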
@@ -217,3 +217,7 @@ __marimo__/

 # Cache
 **/data_cache/
+
+reports/
cli/utils.py (10 changes)
@@ -162,6 +162,10 @@ def select_shallow_thinking_agent(provider) -> str:
             ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
             ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
         ],
+        "perplexity": [
+            ("Sonar - Fast online search", "sonar"),
+            ("Sonar Pro - Advanced online search", "sonar-pro"),
+        ],
         "openrouter": [
             ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
             ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
@@ -229,6 +233,11 @@ def select_deep_thinking_agent(provider) -> str:
             ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
             ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
         ],
+        "perplexity": [
+            ("Sonar Pro - Advanced online search", "sonar-pro"),
+            ("Sonar Reasoning Pro - Advanced reasoning with search", "sonar-reasoning-pro"),
+            ("Sonar Deep Research - In-depth research", "sonar-deep-research"),
+        ],
         "openrouter": [
             ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
             ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
@@ -270,6 +279,7 @@ def select_llm_provider() -> tuple[str, str]:
         ("Google", "https://generativelanguage.googleapis.com/v1"),
         ("Anthropic", "https://api.anthropic.com/"),
         ("xAI", "https://api.x.ai/v1"),
+        ("Perplexity", "https://api.perplexity.ai"),
         ("Openrouter", "https://openrouter.ai/api/v1"),
         ("Ollama", "http://localhost:11434/v1"),
     ]
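These provider entries are (label, base_url) pairs; a minimal sketch of how such a list can drive a picker, with `prompt_for_choice` as a hypothetical stand-in for whatever prompt library the CLI actually uses:

```python
PROVIDERS = [
    ("xAI", "https://api.x.ai/v1"),
    ("Perplexity", "https://api.perplexity.ai"),
    ("Openrouter", "https://openrouter.ai/api/v1"),
]

def prompt_for_choice(options):  # hypothetical stand-in for the real prompt library
    for i, (label, _) in enumerate(options, 1):
        print(f"{i}. {label}")
    return options[int(input("Select provider: ")) - 1]

name, base_url = prompt_for_choice(PROVIDERS)
print(f"Using {name} at {base_url}")
```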
@@ -1,6 +1,6 @@
+from langchain_core.messages import HumanMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-import time
-import json
 from tradingagents.agents.utils.agent_utils import (
     build_instrument_context,
     get_balance_sheet,
@@ -8,6 +8,8 @@ from tradingagents.agents.utils.agent_utils import (
     get_fundamentals,
     get_income_statement,
     get_insider_transactions,
+    prefetch_tool_data,
+    supports_tool_calling,
 )
 from tradingagents.dataflows.config import get_config
@@ -52,13 +54,24 @@ def create_fundamentals_analyst(llm):
         prompt = prompt.partial(current_date=current_date)
         prompt = prompt.partial(instrument_context=instrument_context)

-        chain = prompt | llm.bind_tools(tools)
-        result = chain.invoke(state["messages"])
+        if supports_tool_calling():
+            chain = prompt | llm.bind_tools(tools)
+            result = chain.invoke(state["messages"])
+        else:
+            ticker = state["company_of_interest"]
+            tool_data = prefetch_tool_data(tools, [
+                {"ticker": ticker, "curr_date": current_date},
+                {"ticker": ticker, "freq": "quarterly", "curr_date": current_date},
+                {"ticker": ticker, "freq": "quarterly", "curr_date": current_date},
+                {"ticker": ticker, "freq": "quarterly", "curr_date": current_date},
+            ])
+            result = (prompt | llm).invoke([
+                HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched fundamental data:\n\n{tool_data}\n\nWrite your comprehensive report.")
+            ])

         report = ""

-        if len(result.tool_calls) == 0:
+        if not getattr(result, "tool_calls", None):
             report = result.content

         return {
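The same gate-and-prefetch pattern repeats in the market, news, and social analysts below; a condensed sketch of the control flow it implements, with `llm`, `prompt`, `tools`, and `state` assumed to be built as in the surrounding code (the real per-tool argument dicts differ per analyst):

```python
from langchain_core.messages import HumanMessage

from tradingagents.agents.utils.agent_utils import prefetch_tool_data, supports_tool_calling

def run_analyst(llm, prompt, tools, state, current_date):
    """Sketch: use native tool calling when the provider supports it,
    otherwise pre-fetch tool outputs and inject them in one message."""
    if supports_tool_calling():
        chain = prompt | llm.bind_tools(tools)
        return chain.invoke(state["messages"])
    ticker = state["company_of_interest"]
    # Illustrative only: one identical arg dict per tool.
    tool_data = prefetch_tool_data(tools, [{"ticker": ticker, "curr_date": current_date}] * len(tools))
    return (prompt | llm).invoke([
        HumanMessage(content=f"Analyze {ticker}.\n\nPre-fetched data:\n\n{tool_data}")
    ])
```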
@@ -1,10 +1,14 @@
+from datetime import datetime, timedelta
+
+from langchain_core.messages import HumanMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-import time
-import json
 from tradingagents.agents.utils.agent_utils import (
     build_instrument_context,
     get_indicators,
     get_stock_data,
+    prefetch_tool_data,
+    supports_tool_calling,
 )
 from tradingagents.dataflows.config import get_config
@@ -71,13 +75,23 @@ Volume-Based Indicators:
         prompt = prompt.partial(current_date=current_date)
         prompt = prompt.partial(instrument_context=instrument_context)

-        chain = prompt | llm.bind_tools(tools)
-        result = chain.invoke(state["messages"])
+        if supports_tool_calling():
+            chain = prompt | llm.bind_tools(tools)
+            result = chain.invoke(state["messages"])
+        else:
+            ticker = state["company_of_interest"]
+            start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=30)).strftime("%Y-%m-%d")
+            tool_data = prefetch_tool_data(tools, [
+                {"symbol": ticker, "start_date": start_date, "end_date": current_date},
+                {"symbol": ticker, "indicator": "rsi,macd,boll,boll_ub,boll_lb,atr,vwma,close_50_sma", "curr_date": current_date, "look_back_days": 30},
+            ])
+            result = (prompt | llm).invoke([
+                HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched market data:\n\n{tool_data}\n\nWrite your comprehensive report.")
+            ])

         report = ""

-        if len(result.tool_calls) == 0:
+        if not getattr(result, "tool_calls", None):
             report = result.content

         return {
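The 30-day lookback above is plain stdlib date arithmetic; a quick worked check of the expression:

```python
from datetime import datetime, timedelta

current_date = "2025-01-31"
start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=30)).strftime("%Y-%m-%d")
assert start_date == "2025-01-01"  # 30 days before Jan 31 is Jan 1
```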
@@ -1,10 +1,14 @@
+from datetime import datetime, timedelta
+
+from langchain_core.messages import HumanMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-import time
-import json
 from tradingagents.agents.utils.agent_utils import (
     build_instrument_context,
     get_global_news,
     get_news,
+    prefetch_tool_data,
+    supports_tool_calling,
 )
 from tradingagents.dataflows.config import get_config
@@ -46,12 +50,23 @@ def create_news_analyst(llm):
         prompt = prompt.partial(current_date=current_date)
         prompt = prompt.partial(instrument_context=instrument_context)

-        chain = prompt | llm.bind_tools(tools)
-        result = chain.invoke(state["messages"])
+        if supports_tool_calling():
+            chain = prompt | llm.bind_tools(tools)
+            result = chain.invoke(state["messages"])
+        else:
+            ticker = state["company_of_interest"]
+            start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d")
+            tool_data = prefetch_tool_data(tools, [
+                {"ticker": ticker, "start_date": start_date, "end_date": current_date},
+                {"curr_date": current_date, "look_back_days": 7, "limit": 5},
+            ])
+            result = (prompt | llm).invoke([
+                HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched news data:\n\n{tool_data}\n\nWrite your comprehensive report.")
+            ])

         report = ""

-        if len(result.tool_calls) == 0:
+        if not getattr(result, "tool_calls", None):
             report = result.content

         return {
@@ -1,7 +1,14 @@
+from datetime import datetime, timedelta
+
+from langchain_core.messages import HumanMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-import time
-import json
-from tradingagents.agents.utils.agent_utils import build_instrument_context, get_news
+from tradingagents.agents.utils.agent_utils import (
+    build_instrument_context,
+    get_news,
+    prefetch_tool_data,
+    supports_tool_calling,
+)
 from tradingagents.dataflows.config import get_config
@@ -41,13 +48,22 @@ def create_social_media_analyst(llm):
         prompt = prompt.partial(current_date=current_date)
         prompt = prompt.partial(instrument_context=instrument_context)

-        chain = prompt | llm.bind_tools(tools)
-        result = chain.invoke(state["messages"])
+        if supports_tool_calling():
+            chain = prompt | llm.bind_tools(tools)
+            result = chain.invoke(state["messages"])
+        else:
+            ticker = state["company_of_interest"]
+            start_date = (datetime.strptime(current_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d")
+            tool_data = prefetch_tool_data(tools, [
+                {"ticker": ticker, "start_date": start_date, "end_date": current_date},
+            ])
+            result = (prompt | llm).invoke([
+                HumanMessage(content=f"Analyze {ticker}.\n\nHere is the pre-fetched social media and news data:\n\n{tool_data}\n\nWrite your comprehensive report.")
+            ])

         report = ""

-        if len(result.tool_calls) == 0:
+        if not getattr(result, "tool_calls", None):
             report = result.content

         return {
@@ -1,5 +1,32 @@
 from langchain_core.messages import HumanMessage, RemoveMessage

+from tradingagents.dataflows.config import get_config
+
+# Providers whose APIs do not support tool/function calling.
+_NO_TOOL_CALLING_PROVIDERS = frozenset({"perplexity"})
+
+
+def supports_tool_calling() -> bool:
+    """Check if the current LLM provider supports tool calling."""
+    config = get_config()
+    return config.get("llm_provider", "openai").lower() not in _NO_TOOL_CALLING_PROVIDERS
+
+
+def prefetch_tool_data(tools, tool_args_list) -> str:
+    """Pre-call tools and return formatted results for prompt injection.
+
+    Used as a fallback for providers that don't support tool calling.
+    """
+    results = []
+    for tool, args in zip(tools, tool_args_list):
+        try:
+            data = tool.invoke(args)
+            results.append(f"=== {tool.name} ===\n{data}")
+        except Exception as e:
+            results.append(f"=== {tool.name} ===\n[Error fetching data: {e}]")
+    return "\n\n".join(results)
+
 # Import tools from separate utility files
 from tradingagents.agents.utils.core_stock_tools import (
     get_stock_data
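A small usage sketch for the new helper, using a toy LangChain tool (the `get_quote` tool here is invented purely for illustration):

```python
from langchain_core.tools import tool

from tradingagents.agents.utils.agent_utils import prefetch_tool_data

@tool
def get_quote(ticker: str, curr_date: str) -> str:
    """Return a dummy quote string for ticker on curr_date."""
    return f"{ticker} @ 100.00 on {curr_date}"

blob = prefetch_tool_data([get_quote], [{"ticker": "AAPL", "curr_date": "2025-01-31"}])
print(blob)
# === get_quote ===
# AAPL @ 100.00 on 2025-01-31
```

Because each tool call is wrapped in try/except, one failing data source degrades to an inline error marker instead of aborting the whole report.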
@@ -15,7 +15,7 @@ def create_llm_client(
     """Create an LLM client for the specified provider.

     Args:
-        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
+        provider: LLM provider (openai, anthropic, google, xai, perplexity, ollama, openrouter)
         model: Model name/identifier
         base_url: Optional base URL for API endpoint
         **kwargs: Additional provider-specific arguments
@@ -37,8 +37,8 @@ def create_llm_client(
     if provider_lower in ("openai", "ollama", "openrouter"):
         return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)

-    if provider_lower == "xai":
-        return OpenAIClient(model, base_url, provider="xai", **kwargs)
+    if provider_lower in ("xai", "perplexity"):
+        return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)

     if provider_lower == "anthropic":
         return AnthropicClient(model, base_url, **kwargs)
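Routing Perplexity through `OpenAIClient` works because Perplexity exposes an OpenAI-compatible chat-completions endpoint; a minimal direct call with the official `openai` SDK, independent of this repo's wrapper, using the `sonar` model added to `VALID_MODELS` below:

```python
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://api.perplexity.ai",
    api_key=os.environ["PERPLEXITY_API_KEY"],
)
resp = client.chat.completions.create(
    model="sonar",
    messages=[{"role": "user", "content": "One-sentence summary of today's US market news."}],
)
print(resp.choices[0].message.content)
```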
@@ -27,13 +27,14 @@ _PASSTHROUGH_KWARGS = (
 # Provider base URLs and API key env vars
 _PROVIDER_CONFIG = {
     "xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
+    "perplexity": ("https://api.perplexity.ai", "PERPLEXITY_API_KEY"),
     "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
     "ollama": ("http://localhost:11434/v1", None),
 }


 class OpenAIClient(BaseLLMClient):
-    """Client for OpenAI, Ollama, OpenRouter, and xAI providers.
+    """Client for OpenAI, Ollama, OpenRouter, xAI, and Perplexity providers.

     For native OpenAI models, uses the Responses API (/v1/responses) which
     supports reasoning_effort with function tools across all model families
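A sketch of how the `(base_url, env_var)` tuples in `_PROVIDER_CONFIG` are presumably consumed; the `resolve_provider` helper is illustrative, not the repo's actual code (which is outside this hunk):

```python
import os

def resolve_provider(provider: str, config: dict) -> tuple[str, str | None]:
    """Look up base URL and API key; the key may be None (e.g. ollama)."""
    base_url, env_var = config[provider]
    api_key = os.getenv(env_var) if env_var else None
    if env_var and not api_key:
        raise RuntimeError(f"{env_var} is not set for provider '{provider}'")
    return base_url, api_key

# Mirrors the dict shown in the hunk above.
resolve_provider("perplexity", {"perplexity": ("https://api.perplexity.ai", "PERPLEXITY_API_KEY")})
```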
@@ -48,6 +48,15 @@ VALID_MODELS = {
         "grok-4-fast-reasoning",
         "grok-4-fast-non-reasoning",
     ],
+    "perplexity": [
+        # Sonar Pro series
+        "sonar-pro",
+        "sonar-reasoning-pro",
+        # Sonar series
+        "sonar",
+        # Deep Research
+        "sonar-deep-research",
+    ],
 }
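And a sketch of the kind of whitelist check `VALID_MODELS` presumably backs; `validate_model` is hypothetical, with a minimal copy of the dict so the snippet runs standalone:

```python
VALID_MODELS = {
    "perplexity": ["sonar-pro", "sonar-reasoning-pro", "sonar", "sonar-deep-research"],
}

def validate_model(provider: str, model: str) -> None:
    """Raise if the model is not whitelisted for the provider (sketch)."""
    allowed = VALID_MODELS.get(provider)
    if allowed is not None and model not in allowed:
        raise ValueError(f"Unknown {provider} model '{model}'; expected one of {allowed}")

validate_model("perplexity", "sonar-pro")   # ok
# validate_model("perplexity", "gpt-4o")    # would raise ValueError
```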