TradingAgents v0.2.0: Multi-Provider LLM Support & Optimizations (#331)

Release v0.2.0: Multi-Provider LLM Support
This commit is contained in:
Yijia Xiao 2026-02-03 23:13:43 -08:00 committed by GitHub
commit e9470b69c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
49 changed files with 1999 additions and 3357 deletions

View File

@ -1,2 +1,6 @@
ALPHA_VANTAGE_API_KEY=alpha_vantage_api_key_placeholder
OPENAI_API_KEY=openai_api_key_placeholder
# LLM Providers (set the one you use)
OPENAI_API_KEY=
GOOGLE_API_KEY=
ANTHROPIC_API_KEY=
XAI_API_KEY=
OPENROUTER_API_KEY=

224
.gitignore vendored
View File

@ -1,11 +1,219 @@
.venv
results
env/
# Byte-compiled / optimized / DLL files
__pycache__/
.DS_Store
*.csv
src/
eval_results/
eval_data/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# Redis
*.rdb
*.aof
*.pid
# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/
# ActiveMQ
activemq-data/
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Streamlit
.streamlit/secrets.toml
# Cache
**/data_cache/

View File

@ -25,11 +25,11 @@
---
# TradingAgents: Multi-Agents LLM Financial Trading Framework
# TradingAgents: Multi-Agents LLM Financial Trading Framework
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
>
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
## News
- [2026-02] **TradingAgents v0.2.0** released with multi-provider LLM support (GPT-5.x, Gemini 3.x, Claude 4.x, Grok 4.x) and improved system architecture.
- [2026-01] **Trading-R1** [Technical Report](https://arxiv.org/abs/2509.11420) released, with [Terminal](https://github.com/TauricResearch/Trading-R1) expected to land soon.
<div align="center">
<a href="https://www.star-history.com/#TauricResearch/TradingAgents&Date">
@ -41,6 +41,10 @@
</a>
</div>
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
>
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
<div align="center">
🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
@ -114,21 +118,24 @@ pip install -r requirements.txt
### Required APIs
You will need the OpenAI API for all the agents, and [Alpha Vantage API](https://www.alphavantage.co/support/#api-key) for fundamental and news data (default configuration).
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
```bash
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
export ALPHA_VANTAGE_API_KEY=$YOUR_ALPHA_VANTAGE_API_KEY
export OPENAI_API_KEY=... # OpenAI (GPT)
export GOOGLE_API_KEY=... # Google (Gemini)
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
export XAI_API_KEY=... # xAI (Grok)
export OPENROUTER_API_KEY=... # OpenRouter
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
```
Alternatively, you can create a `.env` file in the project root with your API keys (see `.env.example` for reference):
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
Alternatively, copy `.env.example` to `.env` and fill in your keys:
```bash
cp .env.example .env
# Edit .env with your actual API keys
```
**Note:** We are happy to partner with Alpha Vantage to provide robust API support for TradingAgents. You can get a free AlphaVantage API [here](https://www.alphavantage.co/support/#api-key), TradingAgents-sourced requests also have increased rate limits to 60 requests per minute with no daily limits. Typically the quota is sufficient for performing complex tasks with TradingAgents thanks to Alpha Vantages open-source support program. If you prefer to use OpenAI for these data sources instead, you can modify the data vendor settings in `tradingagents/default_config.py`.
### CLI Usage
You can also try out the CLI directly by running:
@ -155,7 +162,7 @@ An interface will appear showing results as they load, letting you track the age
### Implementation Details
We built TradingAgents with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, OpenRouter, and Ollama.
### Python Usage
@ -168,7 +175,7 @@ from tradingagents.default_config import DEFAULT_CONFIG
ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy())
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
_, decision = ta.propagate("NVDA", "2026-01-15")
print(decision)
```
@ -178,31 +185,18 @@ You can also adjust the default configuration to set your own choice of LLMs, de
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds
config["llm_provider"] = "openai" # openai, google, anthropic, xai, openrouter, ollama
config["deep_think_llm"] = "gpt-5.2" # Model for complex reasoning
config["quick_think_llm"] = "gpt-5-mini" # Model for quick tasks
config["max_debate_rounds"] = 2
# Configure data vendors (default uses yfinance and Alpha Vantage)
config["data_vendors"] = {
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
"fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
}
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
_, decision = ta.propagate("NVDA", "2026-01-15")
print(decision)
```
> The default configuration uses yfinance for stock price and technical data, and Alpha Vantage for fundamental and news data. For production use or if you encounter rate limits, consider upgrading to [Alpha Vantage Premium](https://www.alphavantage.co/premium/) for more stable and reliable data access. For offline experimentation, there's a local data vendor option that uses our **Tauric TradingDB**, a curated dataset for backtesting, though this is still in development. We're currently refining this dataset and plan to release it soon alongside our upcoming projects. Stay tuned!
You can view the full list of configurations in `tradingagents/default_config.py`.
See `tradingagents/default_config.py` for all configuration options.
## Contributing

51
cli/announcements.py Normal file
View File

@ -0,0 +1,51 @@
import getpass
import requests
from rich.console import Console
from rich.panel import Panel
from cli.config import CLI_CONFIG
def fetch_announcements(url: str = None, timeout: float = None) -> dict:
"""Fetch announcements from endpoint. Returns dict with announcements and settings."""
endpoint = url or CLI_CONFIG["announcements_url"]
timeout = timeout or CLI_CONFIG["announcements_timeout"]
fallback = CLI_CONFIG["announcements_fallback"]
try:
response = requests.get(endpoint, timeout=timeout)
response.raise_for_status()
data = response.json()
return {
"announcements": data.get("announcements", [fallback]),
"require_attention": data.get("require_attention", False),
}
except Exception:
return {
"announcements": [fallback],
"require_attention": False,
}
def display_announcements(console: Console, data: dict) -> None:
"""Display announcements panel. Prompts for Enter if require_attention is True."""
announcements = data.get("announcements", [])
require_attention = data.get("require_attention", False)
if not announcements:
return
content = "\n".join(announcements)
panel = Panel(
content,
border_style="cyan",
padding=(1, 2),
title="Announcements",
)
console.print(panel)
if require_attention:
getpass.getpass("Press Enter to continue...")
else:
console.print()

6
cli/config.py Normal file
View File

@ -0,0 +1,6 @@
CLI_CONFIG = {
# Announcements
"announcements_url": "https://api.tauric.ai/v1/announcements",
"announcements_timeout": 1.0,
"announcements_fallback": "[cyan]For more information, please visit[/cyan] [link=https://github.com/TauricResearch]https://github.com/TauricResearch[/link]",
}

File diff suppressed because it is too large Load Diff

76
cli/stats_handler.py Normal file
View File

@ -0,0 +1,76 @@
import threading
from typing import Any, Dict, List, Union
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.messages import AIMessage
class StatsCallbackHandler(BaseCallbackHandler):
"""Callback handler that tracks LLM calls, tool calls, and token usage."""
def __init__(self) -> None:
super().__init__()
self._lock = threading.Lock()
self.llm_calls = 0
self.tool_calls = 0
self.tokens_in = 0
self.tokens_out = 0
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
**kwargs: Any,
) -> None:
"""Increment LLM call counter when an LLM starts."""
with self._lock:
self.llm_calls += 1
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[Any]],
**kwargs: Any,
) -> None:
"""Increment LLM call counter when a chat model starts."""
with self._lock:
self.llm_calls += 1
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Extract token usage from LLM response."""
try:
generation = response.generations[0][0]
except (IndexError, TypeError):
return
usage_metadata = None
if hasattr(generation, "message"):
message = generation.message
if isinstance(message, AIMessage) and hasattr(message, "usage_metadata"):
usage_metadata = message.usage_metadata
if usage_metadata:
with self._lock:
self.tokens_in += usage_metadata.get("input_tokens", 0)
self.tokens_out += usage_metadata.get("output_tokens", 0)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Increment tool call counter when a tool starts."""
with self._lock:
self.tool_calls += 1
def get_stats(self) -> Dict[str, Any]:
"""Return current statistics."""
with self._lock:
return {
"llm_calls": self.llm_calls,
"tool_calls": self.tool_calls,
"tokens_in": self.tokens_in,
"tokens_out": self.tokens_out,
}

View File

@ -128,31 +128,38 @@ def select_shallow_thinking_agent(provider) -> str:
# Define shallow thinking llm engine options with their corresponding model names
SHALLOW_AGENT_OPTIONS = {
"openai": [
("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"),
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"),
("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"),
("GPT-5.2 - Latest flagship", "gpt-5.2"),
("GPT-5.1 - Flexible reasoning", "gpt-5.1"),
("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"),
],
"anthropic": [
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"),
("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"),
("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"),
],
"google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"),
("Gemini 3 Pro - Reasoning-first", "gemini-3-pro-preview"),
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
],
"xai": [
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
],
"openrouter": [
("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"),
("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"),
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
],
"ollama": [
("llama3.1 local", "llama3.1"),
("llama3.2 local", "llama3.2"),
]
("Qwen3:latest (8B, local)", "qwen3:latest"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
],
}
choice = questionary.select(
@ -186,37 +193,43 @@ def select_deep_thinking_agent(provider) -> str:
# Define deep thinking llm engine options with their corresponding model names
DEEP_AGENT_OPTIONS = {
"openai": [
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
("o4-mini - Specialized reasoning model (compact)", "o4-mini"),
("o3-mini - Advanced reasoning model (lightweight)", "o3-mini"),
("o3 - Full advanced reasoning model", "o3"),
("o1 - Premier reasoning and problem-solving model", "o1"),
("GPT-5.2 - Latest flagship", "gpt-5.2"),
("GPT-5.1 - Flexible reasoning", "gpt-5.1"),
("GPT-5 - Advanced reasoning", "gpt-5"),
("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"),
("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"),
("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"),
],
"anthropic": [
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
("Claude Opus 4 - Most powerful Anthropic model", " claude-opus-4-0"),
("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"),
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
("Claude Opus 4.1 - Most capable model", "claude-opus-4-1-20250805"),
("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"),
("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"),
],
"google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
("Gemini 2.5 Pro", "gemini-2.5-pro-preview-06-05"),
("Gemini 3 Pro - Reasoning-first", "gemini-3-pro-preview"),
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"),
],
"xai": [
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
("Grok 4 - Flagship model", "grok-4-0709"),
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
],
"openrouter": [
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
],
"ollama": [
("llama3.1 local", "llama3.1"),
("qwen3", "qwen3"),
]
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("Qwen3:latest (8B, local)", "qwen3:latest"),
],
}
choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
@ -244,10 +257,11 @@ def select_llm_provider() -> tuple[str, str]:
# Define OpenAI api options with their corresponding endpoints
BASE_URLS = [
("OpenAI", "https://api.openai.com/v1"),
("Anthropic", "https://api.anthropic.com/"),
("Google", "https://generativelanguage.googleapis.com/v1"),
("Anthropic", "https://api.anthropic.com/"),
("xAI", "https://api.x.ai/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
("Ollama", "http://localhost:11434/v1"),
]
choice = questionary.select(
@ -272,5 +286,43 @@ def select_llm_provider() -> tuple[str, str]:
display_name, url = choice
print(f"You selected: {display_name}\tURL: {url}")
return display_name, url
def ask_openai_reasoning_effort() -> str:
"""Ask for OpenAI reasoning effort level."""
choices = [
questionary.Choice("Medium (Default)", "medium"),
questionary.Choice("High (More thorough)", "high"),
questionary.Choice("Low (Faster)", "low"),
]
return questionary.select(
"Select Reasoning Effort:",
choices=choices,
style=questionary.Style([
("selected", "fg:cyan noinherit"),
("highlighted", "fg:cyan noinherit"),
("pointer", "fg:cyan noinherit"),
]),
).ask()
def ask_gemini_thinking_config() -> str | None:
"""Ask for Gemini thinking configuration.
Returns thinking_level: "high" or "minimal".
Client maps to appropriate API param based on model series.
"""
return questionary.select(
"Select Thinking Mode:",
choices=[
questionary.Choice("Enable Thinking (recommended)", "high"),
questionary.Choice("Minimal/Disable Thinking", "minimal"),
],
style=questionary.Style([
("selected", "fg:green noinherit"),
("highlighted", "fg:green noinherit"),
("pointer", "fg:green noinherit"),
]),
).ask()

14
main.py
View File

@ -8,16 +8,16 @@ load_dotenv()
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-4o-mini" # Use a different model
config["quick_think_llm"] = "gpt-4o-mini" # Use a different model
config["deep_think_llm"] = "gpt-5-mini" # Use a different model
config["quick_think_llm"] = "gpt-5-mini" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds
# Configure data vendors (default uses yfinance and alpha_vantage)
# Configure data vendors (default uses yfinance, no extra API keys needed)
config["data_vendors"] = {
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
"fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
"core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance
"technical_indicators": "yfinance", # Options: alpha_vantage, yfinance
"fundamental_data": "yfinance", # Options: alpha_vantage, yfinance
"news_data": "yfinance", # Options: alpha_vantage, yfinance
}
# Initialize with custom config

View File

@ -5,14 +5,8 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"akshare>=1.16.98",
"backtrader>=1.9.78.123",
"chainlit>=2.5.5",
"chromadb>=1.0.12",
"eodhd>=1.0.32",
"feedparser>=6.0.11",
"finnhub-python>=2.4.23",
"grip>=4.6.2",
"langchain-anthropic>=0.3.15",
"langchain-experimental>=0.3.4",
"langchain-google-genai>=2.1.5",
@ -20,16 +14,16 @@ dependencies = [
"langgraph>=0.4.8",
"pandas>=2.3.0",
"parsel>=1.10.0",
"praw>=7.8.1",
"pytz>=2025.2",
"questionary>=2.1.0",
"rank-bm25>=0.2.2",
"redis>=6.2.0",
"requests>=2.32.4",
"rich>=14.0.0",
"typer>=0.21.0",
"setuptools>=80.9.0",
"stockstats>=0.6.5",
"tqdm>=4.67.1",
"tushare>=1.4.21",
"typing-extensions>=4.14.0",
"yfinance>=0.2.63",
]

View File

@ -3,17 +3,11 @@ langchain-openai
langchain-experimental
pandas
yfinance
praw
feedparser
stockstats
eodhd
langgraph
chromadb
rank-bm25
setuptools
backtrader
akshare
tushare
finnhub-python
parsel
requests
tqdm
@ -21,6 +15,7 @@ pytz
redis
chainlit
rich
typer
questionary
langchain_anthropic
langchain-google-genai

View File

@ -10,8 +10,8 @@ from .analysts.social_media_analyst import create_social_media_analyst
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
from .risk_mgmt.aggresive_debator import create_risky_debator
from .risk_mgmt.conservative_debator import create_safe_debator
from .risk_mgmt.aggressive_debator import create_aggressive_debator
from .risk_mgmt.conservative_debator import create_conservative_debator
from .risk_mgmt.neutral_debator import create_neutral_debator
from .managers.research_manager import create_research_manager
@ -32,9 +32,9 @@ __all__ = [
"create_market_analyst",
"create_neutral_debator",
"create_news_analyst",
"create_risky_debator",
"create_aggressive_debator",
"create_risk_manager",
"create_safe_debator",
"create_conservative_debator",
"create_social_media_analyst",
"create_trader",
]

View File

@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, get_insider_sentiment, get_insider_transactions
from tradingagents.agents.utils.agent_utils import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, get_insider_transactions
from tradingagents.dataflows.config import get_config

View File

@ -76,7 +76,7 @@ Volume-Based Indicators:
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"market_report": report,

View File

@ -22,7 +22,7 @@ def create_risk_manager(llm, memory):
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Aggressive, Neutral, and Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
Guidelines for Decision-Making:
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
@ -48,12 +48,12 @@ Focus on actionable insights and continuous improvement. Build on past lessons,
new_risk_debate_state = {
"judge_decision": response.content,
"history": risk_debate_state["history"],
"risky_history": risk_debate_state["risky_history"],
"safe_history": risk_debate_state["safe_history"],
"aggressive_history": risk_debate_state["aggressive_history"],
"conservative_history": risk_debate_state["conservative_history"],
"neutral_history": risk_debate_state["neutral_history"],
"latest_speaker": "Judge",
"current_risky_response": risk_debate_state["current_risky_response"],
"current_safe_response": risk_debate_state["current_safe_response"],
"current_aggressive_response": risk_debate_state["current_aggressive_response"],
"current_conservative_response": risk_debate_state["current_conservative_response"],
"current_neutral_response": risk_debate_state["current_neutral_response"],
"count": risk_debate_state["count"],
}

View File

@ -2,13 +2,13 @@ import time
import json
def create_risky_debator(llm):
def risky_node(state) -> dict:
def create_aggressive_debator(llm):
def aggressive_node(state) -> dict:
risk_debate_state = state["risk_debate_state"]
history = risk_debate_state.get("history", "")
risky_history = risk_debate_state.get("risky_history", "")
aggressive_history = risk_debate_state.get("aggressive_history", "")
current_safe_response = risk_debate_state.get("current_safe_response", "")
current_conservative_response = risk_debate_state.get("current_conservative_response", "")
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
market_research_report = state["market_report"]
@ -18,7 +18,7 @@ def create_risky_debator(llm):
trader_decision = state["trader_investment_plan"]
prompt = f"""As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
prompt = f"""As the Aggressive Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
{trader_decision}
@ -28,22 +28,22 @@ Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_conservative_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
response = llm.invoke(prompt)
argument = f"Risky Analyst: {response.content}"
argument = f"Aggressive Analyst: {response.content}"
new_risk_debate_state = {
"history": history + "\n" + argument,
"risky_history": risky_history + "\n" + argument,
"safe_history": risk_debate_state.get("safe_history", ""),
"aggressive_history": aggressive_history + "\n" + argument,
"conservative_history": risk_debate_state.get("conservative_history", ""),
"neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Risky",
"current_risky_response": argument,
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
"latest_speaker": "Aggressive",
"current_aggressive_response": argument,
"current_conservative_response": risk_debate_state.get("current_conservative_response", ""),
"current_neutral_response": risk_debate_state.get(
"current_neutral_response", ""
),
@ -52,4 +52,4 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
return {"risk_debate_state": new_risk_debate_state}
return risky_node
return aggressive_node

View File

@ -3,13 +3,13 @@ import time
import json
def create_safe_debator(llm):
def safe_node(state) -> dict:
def create_conservative_debator(llm):
def conservative_node(state) -> dict:
risk_debate_state = state["risk_debate_state"]
history = risk_debate_state.get("history", "")
safe_history = risk_debate_state.get("safe_history", "")
conservative_history = risk_debate_state.get("conservative_history", "")
current_risky_response = risk_debate_state.get("current_risky_response", "")
current_aggressive_response = risk_debate_state.get("current_aggressive_response", "")
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
market_research_report = state["market_report"]
@ -19,34 +19,34 @@ def create_safe_debator(llm):
trader_decision = state["trader_investment_plan"]
prompt = f"""As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
prompt = f"""As the Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
{trader_decision}
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
Your task is to actively counter the arguments of the Aggressive and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
response = llm.invoke(prompt)
argument = f"Safe Analyst: {response.content}"
argument = f"Conservative Analyst: {response.content}"
new_risk_debate_state = {
"history": history + "\n" + argument,
"risky_history": risk_debate_state.get("risky_history", ""),
"safe_history": safe_history + "\n" + argument,
"aggressive_history": risk_debate_state.get("aggressive_history", ""),
"conservative_history": conservative_history + "\n" + argument,
"neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Safe",
"current_risky_response": risk_debate_state.get(
"current_risky_response", ""
"latest_speaker": "Conservative",
"current_aggressive_response": risk_debate_state.get(
"current_aggressive_response", ""
),
"current_safe_response": argument,
"current_conservative_response": argument,
"current_neutral_response": risk_debate_state.get(
"current_neutral_response", ""
),
@ -55,4 +55,4 @@ Engage by questioning their optimism and emphasizing the potential downsides the
return {"risk_debate_state": new_risk_debate_state}
return safe_node
return conservative_node

View File

@ -8,8 +8,8 @@ def create_neutral_debator(llm):
history = risk_debate_state.get("history", "")
neutral_history = risk_debate_state.get("neutral_history", "")
current_risky_response = risk_debate_state.get("current_risky_response", "")
current_safe_response = risk_debate_state.get("current_safe_response", "")
current_aggressive_response = risk_debate_state.get("current_aggressive_response", "")
current_conservative_response = risk_debate_state.get("current_conservative_response", "")
market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"]
@ -22,15 +22,15 @@ def create_neutral_debator(llm):
{trader_decision}
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
Your task is to challenge both the Aggressive and Conservative Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the conservative analyst: {current_conservative_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
response = llm.invoke(prompt)
@ -38,14 +38,14 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
new_risk_debate_state = {
"history": history + "\n" + argument,
"risky_history": risk_debate_state.get("risky_history", ""),
"safe_history": risk_debate_state.get("safe_history", ""),
"aggressive_history": risk_debate_state.get("aggressive_history", ""),
"conservative_history": risk_debate_state.get("conservative_history", ""),
"neutral_history": neutral_history + "\n" + argument,
"latest_speaker": "Neutral",
"current_risky_response": risk_debate_state.get(
"current_risky_response", ""
"current_aggressive_response": risk_debate_state.get(
"current_aggressive_response", ""
),
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_conservative_response": risk_debate_state.get("current_conservative_response", ""),
"current_neutral_response": argument,
"count": risk_debate_state["count"] + 1,
}

View File

@ -23,22 +23,22 @@ class InvestDebateState(TypedDict):
# Risk management team state
class RiskDebateState(TypedDict):
risky_history: Annotated[
str, "Risky Agent's Conversation history"
aggressive_history: Annotated[
str, "Aggressive Agent's Conversation history"
] # Conversation history
safe_history: Annotated[
str, "Safe Agent's Conversation history"
conservative_history: Annotated[
str, "Conservative Agent's Conversation history"
] # Conversation history
neutral_history: Annotated[
str, "Neutral Agent's Conversation history"
] # Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
latest_speaker: Annotated[str, "Analyst that spoke last"]
current_risky_response: Annotated[
str, "Latest response by the risky analyst"
current_aggressive_response: Annotated[
str, "Latest response by the aggressive analyst"
] # Last response
current_safe_response: Annotated[
str, "Latest response by the safe analyst"
current_conservative_response: Annotated[
str, "Latest response by the conservative analyst"
] # Last response
current_neutral_response: Annotated[
str, "Latest response by the neutral analyst"

View File

@ -15,7 +15,6 @@ from tradingagents.agents.utils.fundamental_data_tools import (
)
from tradingagents.agents.utils.news_data_tools import (
get_news,
get_insider_sentiment,
get_insider_transactions,
get_global_news
)
@ -24,15 +23,15 @@ def create_msg_delete():
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
messages = state["messages"]
# Remove all messages
removal_operations = [RemoveMessage(id=m.id) for m in messages]
# Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]}
return delete_messages

View File

@ -1,75 +1,106 @@
import chromadb
from chromadb.config import Settings
from openai import OpenAI
"""Financial situation memory using BM25 for lexical similarity matching.
Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
no token limits, works offline with any LLM provider.
"""
from rank_bm25 import BM25Okapi
from typing import List, Tuple
import re
class FinancialSituationMemory:
def __init__(self, name, config):
if config["backend_url"] == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text"
"""Memory system for storing and retrieving financial situations using BM25."""
def __init__(self, name: str, config: dict = None):
"""Initialize the memory system.
Args:
name: Name identifier for this memory instance
config: Configuration dict (kept for API compatibility, not used for BM25)
"""
self.name = name
self.documents: List[str] = []
self.recommendations: List[str] = []
self.bm25 = None
def _tokenize(self, text: str) -> List[str]:
"""Tokenize text for BM25 indexing.
Simple whitespace + punctuation tokenization with lowercasing.
"""
# Lowercase and split on non-alphanumeric characters
tokens = re.findall(r'\b\w+\b', text.lower())
return tokens
def _rebuild_index(self):
"""Rebuild the BM25 index after adding documents."""
if self.documents:
tokenized_docs = [self._tokenize(doc) for doc in self.documents]
self.bm25 = BM25Okapi(tokenized_docs)
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)
self.bm25 = None
def get_embedding(self, text):
"""Get OpenAI embedding for a text"""
response = self.client.embeddings.create(
model=self.embedding, input=text
)
return response.data[0].embedding
def add_situations(self, situations_and_advice: List[Tuple[str, str]]):
"""Add financial situations and their corresponding advice.
def add_situations(self, situations_and_advice):
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
Args:
situations_and_advice: List of tuples (situation, recommendation)
"""
for situation, recommendation in situations_and_advice:
self.documents.append(situation)
self.recommendations.append(recommendation)
situations = []
advice = []
ids = []
embeddings = []
# Rebuild BM25 index with new documents
self._rebuild_index()
offset = self.situation_collection.count()
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
"""Find matching recommendations using BM25 similarity.
for i, (situation, recommendation) in enumerate(situations_and_advice):
situations.append(situation)
advice.append(recommendation)
ids.append(str(offset + i))
embeddings.append(self.get_embedding(situation))
Args:
current_situation: The current financial situation to match against
n_matches: Number of top matches to return
self.situation_collection.add(
documents=situations,
metadatas=[{"recommendation": rec} for rec in advice],
embeddings=embeddings,
ids=ids,
)
Returns:
List of dicts with matched_situation, recommendation, and similarity_score
"""
if not self.documents or self.bm25 is None:
return []
def get_memories(self, current_situation, n_matches=1):
"""Find matching recommendations using OpenAI embeddings"""
query_embedding = self.get_embedding(current_situation)
# Tokenize query
query_tokens = self._tokenize(current_situation)
results = self.situation_collection.query(
query_embeddings=[query_embedding],
n_results=n_matches,
include=["metadatas", "documents", "distances"],
)
# Get BM25 scores for all documents
scores = self.bm25.get_scores(query_tokens)
matched_results = []
for i in range(len(results["documents"][0])):
matched_results.append(
{
"matched_situation": results["documents"][0][i],
"recommendation": results["metadatas"][0][i]["recommendation"],
"similarity_score": 1 - results["distances"][0][i],
}
)
# Get top-n indices sorted by score (descending)
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_matches]
return matched_results
# Build results
results = []
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
for idx in top_indices:
# Normalize score to 0-1 range for consistency
normalized_score = scores[idx] / max_score if max_score > 0 else 0
results.append({
"matched_situation": self.documents[idx],
"recommendation": self.recommendations[idx],
"similarity_score": normalized_score,
})
return results
def clear(self):
"""Clear all stored memories."""
self.documents = []
self.recommendations = []
self.bm25 = None
if __name__ == "__main__":
# Example usage
matcher = FinancialSituationMemory()
matcher = FinancialSituationMemory("test_memory")
# Example data
example_data = [
@ -96,7 +127,7 @@ if __name__ == "__main__":
# Example query
current_situation = """
Market showing increased volatility in tech sector, with institutional investors
Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations
"""

View File

@ -38,34 +38,16 @@ def get_global_news(
"""
return route_to_vendor("get_global_news", curr_date, look_back_days, limit)
@tool
def get_insider_sentiment(
ticker: Annotated[str, "ticker symbol for the company"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
) -> str:
"""
Retrieve insider sentiment information about a company.
Uses the configured news_data vendor.
Args:
ticker (str): Ticker symbol of the company
curr_date (str): Current date you are trading at, yyyy-mm-dd
Returns:
str: A report of insider sentiment data
"""
return route_to_vendor("get_insider_sentiment", ticker, curr_date)
@tool
def get_insider_transactions(
ticker: Annotated[str, "ticker symbol"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
) -> str:
"""
Retrieve insider transaction information about a company.
Uses the configured news_data vendor.
Args:
ticker (str): Ticker symbol of the company
curr_date (str): Current date you are trading at, yyyy-mm-dd
Returns:
str: A report of insider transaction data
"""
return route_to_vendor("get_insider_transactions", ticker, curr_date)
return route_to_vendor("get_insider_transactions", ticker)

View File

@ -2,4 +2,4 @@
from .alpha_vantage_stock import get_stock
from .alpha_vantage_indicator import get_indicator
from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
from .alpha_vantage_news import get_news, get_insider_transactions
from .alpha_vantage_news import get_news, get_global_news, get_insider_transactions

View File

@ -18,12 +18,40 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
"tickers": ticker,
"time_from": format_datetime_for_api(start_date),
"time_to": format_datetime_for_api(end_date),
"sort": "LATEST",
"limit": "50",
}
return _make_api_request("NEWS_SENTIMENT", params)
def get_global_news(curr_date, look_back_days: int = 7, limit: int = 50) -> dict[str, str] | str:
"""Returns global market news & sentiment data without ticker-specific filtering.
Covers broad market topics like financial markets, economy, and more.
Args:
curr_date: Current date in yyyy-mm-dd format.
look_back_days: Number of days to look back (default 7).
limit: Maximum number of articles (default 50).
Returns:
Dictionary containing global news sentiment data or JSON string.
"""
from datetime import datetime, timedelta
# Calculate start date
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
start_dt = curr_dt - timedelta(days=look_back_days)
start_date = start_dt.strftime("%Y-%m-%d")
params = {
"topics": "financial_markets,economy_macro,economy_monetary",
"time_from": format_datetime_for_api(start_date),
"time_to": format_datetime_for_api(curr_date),
"limit": str(limit),
}
return _make_api_request("NEWS_SENTIMENT", params)
def get_insider_transactions(symbol: str) -> dict[str, str] | str:
"""Returns latest and historical insider transactions by key stakeholders.

View File

@ -3,24 +3,21 @@ from typing import Dict, Optional
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
DATA_DIR: Optional[str] = None
def initialize_config():
"""Initialize the configuration with default values."""
global _config, DATA_DIR
global _config
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
DATA_DIR = _config["data_dir"]
def set_config(config: Dict):
"""Update the configuration with custom values."""
global _config, DATA_DIR
global _config
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
_config.update(config)
DATA_DIR = _config["data_dir"]
def get_config() -> Dict:

View File

@ -1,30 +0,0 @@
from typing import Annotated
from datetime import datetime
from dateutil.relativedelta import relativedelta
from .googlenews_utils import getNewsData
def get_google_news(
query: Annotated[str, "Query to search with"],
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
) -> str:
query = query.replace(" ", "+")
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
news_results = getNewsData(query, before, curr_date)
news_str = ""
for news in news_results:
news_str += (
f"### {news['title']} (source: {news['source']}) \n\n{news['snippet']}\n\n"
)
if len(news_results) == 0:
return ""
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"

View File

@ -1,108 +0,0 @@
import json
import requests
from bs4 import BeautifulSoup
from datetime import datetime
import time
import random
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
retry_if_result,
)
def is_rate_limited(response):
"""Check if the response indicates rate limiting (status code 429)"""
return response.status_code == 429
@retry(
retry=(retry_if_result(is_rate_limited)),
wait=wait_exponential(multiplier=1, min=4, max=60),
stop=stop_after_attempt(5),
)
def make_request(url, headers):
"""Make a request with retry logic for rate limiting"""
# Random delay before each request to avoid detection
time.sleep(random.uniform(2, 6))
response = requests.get(url, headers=headers)
return response
def getNewsData(query, start_date, end_date):
"""
Scrape Google News search results for a given query and date range.
query: str - search query
start_date: str - start date in the format yyyy-mm-dd or mm/dd/yyyy
end_date: str - end date in the format yyyy-mm-dd or mm/dd/yyyy
"""
if "-" in start_date:
start_date = datetime.strptime(start_date, "%Y-%m-%d")
start_date = start_date.strftime("%m/%d/%Y")
if "-" in end_date:
end_date = datetime.strptime(end_date, "%Y-%m-%d")
end_date = end_date.strftime("%m/%d/%Y")
headers = {
"User-Agent": (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/101.0.4951.54 Safari/537.36"
)
}
news_results = []
page = 0
while True:
offset = page * 10
url = (
f"https://www.google.com/search?q={query}"
f"&tbs=cdr:1,cd_min:{start_date},cd_max:{end_date}"
f"&tbm=nws&start={offset}"
)
try:
response = make_request(url, headers)
soup = BeautifulSoup(response.content, "html.parser")
results_on_page = soup.select("div.SoaBEf")
if not results_on_page:
break # No more results found
for el in results_on_page:
try:
link = el.find("a")["href"]
title = el.select_one("div.MBeuO").get_text()
snippet = el.select_one(".GI74Re").get_text()
date = el.select_one(".LfVVr").get_text()
source = el.select_one(".NUnG9d span").get_text()
news_results.append(
{
"link": link,
"title": title,
"snippet": snippet,
"date": date,
"source": source,
}
)
except Exception as e:
print(f"Error processing result: {e}")
# If one of the fields is not found, skip this result
continue
# Update the progress bar with the current count of results scraped
# Check for the "Next" link (pagination)
next_link = soup.find("a", id="pnnext")
if not next_link:
break
page += 1
except Exception as e:
print(f"Failed after multiple retries: {e}")
break
return news_results

View File

@ -1,10 +1,16 @@
from typing import Annotated
# Import from vendor-specific modules
from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news
from .y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
from .google import get_google_news
from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai
from .y_finance import (
get_YFin_data_online,
get_stock_stats_indicators_window,
get_fundamentals as get_yfinance_fundamentals,
get_balance_sheet as get_yfinance_balance_sheet,
get_cashflow as get_yfinance_cashflow,
get_income_statement as get_yfinance_income_statement,
get_insider_transactions as get_yfinance_insider_transactions,
)
from .yfinance_news import get_news_yfinance, get_global_news_yfinance
from .alpha_vantage import (
get_stock as get_alpha_vantage_stock,
get_indicator as get_alpha_vantage_indicator,
@ -13,7 +19,8 @@ from .alpha_vantage import (
get_cashflow as get_alpha_vantage_cashflow,
get_income_statement as get_alpha_vantage_income_statement,
get_insider_transactions as get_alpha_vantage_insider_transactions,
get_news as get_alpha_vantage_news
get_news as get_alpha_vantage_news,
get_global_news as get_alpha_vantage_global_news,
)
from .alpha_vantage_common import AlphaVantageRateLimitError
@ -44,21 +51,18 @@ TOOLS_CATEGORIES = {
]
},
"news_data": {
"description": "News (public/insiders, original/processed)",
"description": "News and insider data",
"tools": [
"get_news",
"get_global_news",
"get_insider_sentiment",
"get_insider_transactions",
]
}
}
VENDOR_LIST = [
"local",
"yfinance",
"openai",
"google"
"alpha_vantage",
]
# Mapping of methods to their vendor-specific implementations
@ -67,52 +71,41 @@ VENDOR_METHODS = {
"get_stock_data": {
"alpha_vantage": get_alpha_vantage_stock,
"yfinance": get_YFin_data_online,
"local": get_YFin_data,
},
# technical_indicators
"get_indicators": {
"alpha_vantage": get_alpha_vantage_indicator,
"yfinance": get_stock_stats_indicators_window,
"local": get_stock_stats_indicators_window
},
# fundamental_data
"get_fundamentals": {
"alpha_vantage": get_alpha_vantage_fundamentals,
"openai": get_fundamentals_openai,
"yfinance": get_yfinance_fundamentals,
},
"get_balance_sheet": {
"alpha_vantage": get_alpha_vantage_balance_sheet,
"yfinance": get_yfinance_balance_sheet,
"local": get_simfin_balance_sheet,
},
"get_cashflow": {
"alpha_vantage": get_alpha_vantage_cashflow,
"yfinance": get_yfinance_cashflow,
"local": get_simfin_cashflow,
},
"get_income_statement": {
"alpha_vantage": get_alpha_vantage_income_statement,
"yfinance": get_yfinance_income_statement,
"local": get_simfin_income_statements,
},
# news_data
"get_news": {
"alpha_vantage": get_alpha_vantage_news,
"openai": get_stock_news_openai,
"google": get_google_news,
"local": [get_finnhub_news, get_reddit_company_news, get_google_news],
"yfinance": get_news_yfinance,
},
"get_global_news": {
"openai": get_global_news_openai,
"local": get_reddit_global_news
},
"get_insider_sentiment": {
"local": get_finnhub_company_insider_sentiment
"yfinance": get_global_news_yfinance,
"alpha_vantage": get_alpha_vantage_global_news,
},
"get_insider_transactions": {
"alpha_vantage": get_alpha_vantage_insider_transactions,
"yfinance": get_yfinance_insider_transactions,
"local": get_finnhub_company_insider_transactions,
},
}
@ -142,103 +135,28 @@ def route_to_vendor(method: str, *args, **kwargs):
"""Route method calls to appropriate vendor implementation with fallback support."""
category = get_category_for_method(method)
vendor_config = get_vendor(category, method)
# Handle comma-separated vendors
primary_vendors = [v.strip() for v in vendor_config.split(',')]
if method not in VENDOR_METHODS:
raise ValueError(f"Method '{method}' not supported")
# Get all available vendors for this method for fallback
# Build fallback chain: primary vendors first, then remaining available vendors
all_available_vendors = list(VENDOR_METHODS[method].keys())
# Create fallback vendor list: primary vendors first, then remaining vendors as fallbacks
fallback_vendors = primary_vendors.copy()
for vendor in all_available_vendors:
if vendor not in fallback_vendors:
fallback_vendors.append(vendor)
# Debug: Print fallback ordering
primary_str = "".join(primary_vendors)
fallback_str = "".join(fallback_vendors)
print(f"DEBUG: {method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]")
# Track results and execution state
results = []
vendor_attempt_count = 0
any_primary_vendor_attempted = False
successful_vendor = None
for vendor in fallback_vendors:
if vendor not in VENDOR_METHODS[method]:
if vendor in primary_vendors:
print(f"INFO: Vendor '{vendor}' not supported for method '{method}', falling back to next vendor")
continue
vendor_impl = VENDOR_METHODS[method][vendor]
is_primary_vendor = vendor in primary_vendors
vendor_attempt_count += 1
impl_func = vendor_impl[0] if isinstance(vendor_impl, list) else vendor_impl
# Track if we attempted any primary vendor
if is_primary_vendor:
any_primary_vendor_attempted = True
try:
return impl_func(*args, **kwargs)
except AlphaVantageRateLimitError:
continue # Only rate limits trigger fallback
# Debug: Print current attempt
vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK"
print(f"DEBUG: Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})")
# Handle list of methods for a vendor
if isinstance(vendor_impl, list):
vendor_methods = [(impl, vendor) for impl in vendor_impl]
print(f"DEBUG: Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions")
else:
vendor_methods = [(vendor_impl, vendor)]
# Run methods for this vendor
vendor_results = []
for impl_func, vendor_name in vendor_methods:
try:
print(f"DEBUG: Calling {impl_func.__name__} from vendor '{vendor_name}'...")
result = impl_func(*args, **kwargs)
vendor_results.append(result)
print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully")
except AlphaVantageRateLimitError as e:
if vendor == "alpha_vantage":
print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor")
print(f"DEBUG: Rate limit details: {e}")
# Continue to next vendor for fallback
continue
except Exception as e:
# Log error but continue with other implementations
print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}")
continue
# Add this vendor's results
if vendor_results:
results.extend(vendor_results)
successful_vendor = vendor
result_summary = f"Got {len(vendor_results)} result(s)"
print(f"SUCCESS: Vendor '{vendor}' succeeded - {result_summary}")
# Stopping logic: Stop after first successful vendor for single-vendor configs
# Multiple vendor configs (comma-separated) may want to collect from multiple sources
if len(primary_vendors) == 1:
print(f"DEBUG: Stopping after successful vendor '{vendor}' (single-vendor config)")
break
else:
print(f"FAILED: Vendor '{vendor}' produced no results")
# Final result summary
if not results:
print(f"FAILURE: All {vendor_attempt_count} vendor attempts failed for method '{method}'")
raise RuntimeError(f"All vendor implementations failed for method '{method}'")
else:
print(f"FINAL: Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)")
# Return single result if only one, otherwise concatenate as string
if len(results) == 1:
return results[0]
else:
# Convert all results to strings and concatenate
return '\n'.join(str(result) for result in results)
raise RuntimeError(f"No available vendor for '{method}'")

View File

@ -1,475 +0,0 @@
from typing import Annotated
import pandas as pd
import os
from .config import DATA_DIR
from datetime import datetime
from dateutil.relativedelta import relativedelta
import json
from .reddit_utils import fetch_top_from_category
from tqdm import tqdm
def get_YFin_data_window(
symbol: Annotated[str, "ticker symbol of the company"],
curr_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
) -> str:
# calculate past days
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=look_back_days)
start_date = before.strftime("%Y-%m-%d")
# read in data
data = pd.read_csv(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
# Extract just the date part for comparison
data["DateOnly"] = data["Date"].str[:10]
# Filter data between the start and end dates (inclusive)
filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date)
]
# Drop the temporary column we created
filtered_data = filtered_data.drop("DateOnly", axis=1)
# Set pandas display options to show the full DataFrame
with pd.option_context(
"display.max_rows", None, "display.max_columns", None, "display.width", None
):
df_string = filtered_data.to_string()
return (
f"## Raw Market Data for {symbol} from {start_date} to {curr_date}:\n\n"
+ df_string
)
def get_YFin_data(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
# read in data
data = pd.read_csv(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
if end_date > "2025-03-25":
raise Exception(
f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
)
# Extract just the date part for comparison
data["DateOnly"] = data["Date"].str[:10]
# Filter data between the start and end dates (inclusive)
filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date)
]
# Drop the temporary column we created
filtered_data = filtered_data.drop("DateOnly", axis=1)
# remove the index from the dataframe
filtered_data = filtered_data.reset_index(drop=True)
return filtered_data
def get_finnhub_news(
query: Annotated[str, "Search query or ticker symbol"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):
"""
Retrieve news about a company within a time frame
Args
query (str): Search query or ticker symbol
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns
str: dataframe containing the news of the company in the time frame
"""
result = get_data_in_range(query, start_date, end_date, "news_data", DATA_DIR)
if len(result) == 0:
return ""
combined_result = ""
for day, data in result.items():
if len(data) == 0:
continue
for entry in data:
current_news = (
"### " + entry["headline"] + f" ({day})" + "\n" + entry["summary"]
)
combined_result += current_news + "\n\n"
return f"## {query} News, from {start_date} to {end_date}:\n" + str(combined_result)
def get_finnhub_company_insider_sentiment(
ticker: Annotated[str, "ticker symbol for the company"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
"""
Retrieve insider sentiment about a company (retrieved from public SEC information) for the past 15 days
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading on, yyyy-mm-dd
Returns:
str: a report of the sentiment in the past 15 days starting at curr_date
"""
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=15) # Default 15 days lookback
before = before.strftime("%Y-%m-%d")
data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR)
if len(data) == 0:
return ""
result_str = ""
seen_dicts = []
for date, senti_list in data.items():
for entry in senti_list:
if entry not in seen_dicts:
result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n"
seen_dicts.append(entry)
return (
f"## {ticker} Insider Sentiment Data for {before} to {curr_date}:\n"
+ result_str
+ "The change field refers to the net buying/selling from all insiders' transactions. The mspr field refers to monthly share purchase ratio."
)
def get_finnhub_company_insider_transactions(
ticker: Annotated[str, "ticker symbol"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
"""
Retrieve insider transcaction information about a company (retrieved from public SEC information) for the past 15 days
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's insider transaction/trading informtaion in the past 15 days
"""
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=15) # Default 15 days lookback
before = before.strftime("%Y-%m-%d")
data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR)
if len(data) == 0:
return ""
result_str = ""
seen_dicts = []
for date, senti_list in data.items():
for entry in senti_list:
if entry not in seen_dicts:
result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n"
seen_dicts.append(entry)
return (
f"## {ticker} insider transactions from {before} to {curr_date}:\n"
+ result_str
+ "The change field reflects the variation in share count—here a negative number indicates a reduction in holdings—while share specifies the total number of shares involved. The transactionPrice denotes the per-share price at which the trade was executed, and transactionDate marks when the transaction occurred. The name field identifies the insider making the trade, and transactionCode (e.g., S for sale) clarifies the nature of the transaction. FilingDate records when the transaction was officially reported, and the unique id links to the specific SEC filing, as indicated by the source. Additionally, the symbol ties the transaction to a particular company, isDerivative flags whether the trade involves derivative securities, and currency notes the currency context of the transaction."
)
def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None):
"""
Gets finnhub data saved and processed on disk.
Args:
start_date (str): Start date in YYYY-MM-DD format.
end_date (str): End date in YYYY-MM-DD format.
data_type (str): Type of data from finnhub to fetch. Can be insider_trans, SEC_filings, news_data, insider_senti, or fin_as_reported.
data_dir (str): Directory where the data is saved.
period (str): Default to none, if there is a period specified, should be annual or quarterly.
"""
if period:
data_path = os.path.join(
data_dir,
"finnhub_data",
data_type,
f"{ticker}_{period}_data_formatted.json",
)
else:
data_path = os.path.join(
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
)
data = open(data_path, "r")
data = json.load(data)
# filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD)
filtered_data = {}
for key, value in data.items():
if start_date <= key <= end_date and len(value) > 0:
filtered_data[key] = value
return filtered_data
def get_simfin_balance_sheet(
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[
str,
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
"simfin_data_all",
"balance_sheet",
"companies",
"us",
f"us-balance-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
# Convert date strings to datetime objects and remove any time components
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
# Convert the current date to datetime and normalize
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No balance sheet available before the given current date.")
return ""
# Get the most recent balance sheet by selecting the row with the latest Publish Date
latest_balance_sheet = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
# drop the SimFinID column
latest_balance_sheet = latest_balance_sheet.drop("SimFinId")
return (
f"## {freq} balance sheet for {ticker} released on {str(latest_balance_sheet['Publish Date'])[0:10]}: \n"
+ str(latest_balance_sheet)
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of assets, liabilities, and equity. Assets are grouped as current (liquid items like cash and receivables) and noncurrent (long-term investments and property). Liabilities are split between short-term obligations and long-term debts, while equity reflects shareholder funds such as paid-in capital and retained earnings. Together, these components ensure that total assets equal the sum of liabilities and equity."
)
def get_simfin_cashflow(
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[
str,
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
"simfin_data_all",
"cash_flow",
"companies",
"us",
f"us-cashflow-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
# Convert date strings to datetime objects and remove any time components
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
# Convert the current date to datetime and normalize
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No cash flow statement available before the given current date.")
return ""
# Get the most recent cash flow statement by selecting the row with the latest Publish Date
latest_cash_flow = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
# drop the SimFinID column
latest_cash_flow = latest_cash_flow.drop("SimFinId")
return (
f"## {freq} cash flow statement for {ticker} released on {str(latest_cash_flow['Publish Date'])[0:10]}: \n"
+ str(latest_cash_flow)
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of cash movements. Operating activities show cash generated from core business operations, including net income adjustments for non-cash items and working capital changes. Investing activities cover asset acquisitions/disposals and investments. Financing activities include debt transactions, equity issuances/repurchases, and dividend payments. The net change in cash represents the overall increase or decrease in the company's cash position during the reporting period."
)
def get_simfin_income_statements(
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[
str,
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
"simfin_data_all",
"income_statements",
"companies",
"us",
f"us-income-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
# Convert date strings to datetime objects and remove any time components
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
# Convert the current date to datetime and normalize
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No income statement available before the given current date.")
return ""
# Get the most recent income statement by selecting the row with the latest Publish Date
latest_income = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
# drop the SimFinID column
latest_income = latest_income.drop("SimFinId")
return (
f"## {freq} income statement for {ticker} released on {str(latest_income['Publish Date'])[0:10]}: \n"
+ str(latest_income)
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a comprehensive breakdown of the company's financial performance. Starting with Revenue, it shows Cost of Revenue and resulting Gross Profit. Operating Expenses are detailed, including SG&A, R&D, and Depreciation. The statement then shows Operating Income, followed by non-operating items and Interest Expense, leading to Pretax Income. After accounting for Income Tax and any Extraordinary items, it concludes with Net Income, representing the company's bottom-line profit or loss for the period."
)
def get_reddit_global_news(
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "Number of days to look back"] = 7,
limit: Annotated[int, "Maximum number of articles to return"] = 5,
) -> str:
"""
Retrieve the latest top reddit news
Args:
curr_date: Current date in yyyy-mm-dd format
look_back_days: Number of days to look back (default 7)
limit: Maximum number of articles to return (default 5)
Returns:
str: A formatted string containing the latest news articles posts on reddit
"""
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
before = curr_date_dt - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
posts = []
# iterate from before to curr_date
curr_iter_date = datetime.strptime(before, "%Y-%m-%d")
total_iterations = (curr_date_dt - curr_iter_date).days + 1
pbar = tqdm(desc=f"Getting Global News on {curr_date}", total=total_iterations)
while curr_iter_date <= curr_date_dt:
curr_date_str = curr_iter_date.strftime("%Y-%m-%d")
fetch_result = fetch_top_from_category(
"global_news",
curr_date_str,
limit,
data_path=os.path.join(DATA_DIR, "reddit_data"),
)
posts.extend(fetch_result)
curr_iter_date += relativedelta(days=1)
pbar.update(1)
pbar.close()
if len(posts) == 0:
return ""
news_str = ""
for post in posts:
if post["content"] == "":
news_str += f"### {post['title']}\n\n"
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
return f"## Global News Reddit, from {before} to {curr_date}:\n{news_str}"
def get_reddit_company_news(
query: Annotated[str, "Search query or ticker symbol"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve the latest top reddit news
Args:
query: Search query or ticker symbol
start_date: Start date in yyyy-mm-dd format
end_date: End date in yyyy-mm-dd format
Returns:
str: A formatted string containing news articles posts on reddit
"""
start_date_dt = datetime.strptime(start_date, "%Y-%m-%d")
end_date_dt = datetime.strptime(end_date, "%Y-%m-%d")
posts = []
# iterate from start_date to end_date
curr_date = start_date_dt
total_iterations = (end_date_dt - curr_date).days + 1
pbar = tqdm(
desc=f"Getting Company News for {query} from {start_date} to {end_date}",
total=total_iterations,
)
while curr_date <= end_date_dt:
curr_date_str = curr_date.strftime("%Y-%m-%d")
fetch_result = fetch_top_from_category(
"company_news",
curr_date_str,
10, # max limit per day
query,
data_path=os.path.join(DATA_DIR, "reddit_data"),
)
posts.extend(fetch_result)
curr_date += relativedelta(days=1)
pbar.update(1)
pbar.close()
if len(posts) == 0:
return ""
news_str = ""
for post in posts:
if post["content"] == "":
news_str += f"### {post['title']}\n\n"
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
return f"##{query} News Reddit, from {start_date} to {end_date}:\n\n{news_str}"

View File

@ -1,107 +0,0 @@
from openai import OpenAI
from .config import get_config
def get_stock_news_openai(query, start_date, end_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search Social Media for {query} from {start_date} to {end_date}? Make sure you only get the data posted during that period.",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
def get_global_news_openai(curr_date, look_back_days=7, limit=5):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search global or macroeconomics news from {look_back_days} days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period. Limit the results to {limit} articles.",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
def get_fundamentals_openai(ticker, curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text

View File

@ -1,135 +0,0 @@
import requests
import time
import json
from datetime import datetime, timedelta
from contextlib import contextmanager
from typing import Annotated
import os
import re
ticker_to_company = {
"AAPL": "Apple",
"MSFT": "Microsoft",
"GOOGL": "Google",
"AMZN": "Amazon",
"TSLA": "Tesla",
"NVDA": "Nvidia",
"TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC",
"JPM": "JPMorgan Chase OR JP Morgan",
"JNJ": "Johnson & Johnson OR JNJ",
"V": "Visa",
"WMT": "Walmart",
"META": "Meta OR Facebook",
"AMD": "AMD",
"INTC": "Intel",
"QCOM": "Qualcomm",
"BABA": "Alibaba",
"ADBE": "Adobe",
"NFLX": "Netflix",
"CRM": "Salesforce",
"PYPL": "PayPal",
"PLTR": "Palantir",
"MU": "Micron",
"SQ": "Block OR Square",
"ZM": "Zoom",
"CSCO": "Cisco",
"SHOP": "Shopify",
"ORCL": "Oracle",
"X": "Twitter OR X",
"SPOT": "Spotify",
"AVGO": "Broadcom",
"ASML": "ASML ",
"TWLO": "Twilio",
"SNAP": "Snap Inc.",
"TEAM": "Atlassian",
"SQSP": "Squarespace",
"UBER": "Uber",
"ROKU": "Roku",
"PINS": "Pinterest",
}
def fetch_top_from_category(
category: Annotated[
str, "Category to fetch top post from. Collection of subreddits."
],
date: Annotated[str, "Date to fetch top posts from."],
max_limit: Annotated[int, "Maximum number of posts to fetch."],
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
data_path: Annotated[
str,
"Path to the data folder. Default is 'reddit_data'.",
] = "reddit_data",
):
base_path = data_path
all_content = []
if max_limit < len(os.listdir(os.path.join(base_path, category))):
raise ValueError(
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
)
limit_per_subreddit = max_limit // len(
os.listdir(os.path.join(base_path, category))
)
for data_file in os.listdir(os.path.join(base_path, category)):
# check if data_file is a .jsonl file
if not data_file.endswith(".jsonl"):
continue
all_content_curr_subreddit = []
with open(os.path.join(base_path, category, data_file), "rb") as f:
for i, line in enumerate(f):
# skip empty lines
if not line.strip():
continue
parsed_line = json.loads(line)
# select only lines that are from the date
post_date = datetime.utcfromtimestamp(
parsed_line["created_utc"]
).strftime("%Y-%m-%d")
if post_date != date:
continue
# if is company_news, check that the title or the content has the company's name (query) mentioned
if "company" in category and query:
search_terms = []
if "OR" in ticker_to_company[query]:
search_terms = ticker_to_company[query].split(" OR ")
else:
search_terms = [ticker_to_company[query]]
search_terms.append(query)
found = False
for term in search_terms:
if re.search(
term, parsed_line["title"], re.IGNORECASE
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
found = True
break
if not found:
continue
post = {
"title": parsed_line["title"],
"content": parsed_line["selftext"],
"url": parsed_line["url"],
"upvotes": parsed_line["ups"],
"posted_date": post_date,
}
all_content_curr_subreddit.append(post)
# sort all_content_curr_subreddit by upvote_ratio in descending order
all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True)
all_content.extend(all_content_curr_subreddit[:limit_per_subreddit])
return all_content

View File

@ -3,7 +3,7 @@ import yfinance as yf
from stockstats import wrap
from typing import Annotated
import os
from .config import get_config, DATA_DIR
from .config import get_config
class StockstatsUtils:
@ -17,63 +17,45 @@ class StockstatsUtils:
str, "curr date for retrieving stock price data, YYYY-mm-dd"
],
):
# Get config and set up data directory path
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
df = None
data = None
today_date = pd.Timestamp.today()
curr_date_dt = pd.to_datetime(curr_date)
if not online:
try:
data = pd.read_csv(
os.path.join(
DATA_DIR,
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
df = wrap(data)
except FileNotFoundError:
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
# Ensure cache directory exists
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
data["Date"] = pd.to_datetime(data["Date"])
else:
# Get today's date as YYYY-mm-dd to add to cache
today_date = pd.Timestamp.today()
curr_date = pd.to_datetime(curr_date)
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date = start_date.strftime("%Y-%m-%d")
end_date = end_date.strftime("%Y-%m-%d")
# Get config and ensure cache directory exists
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date}-{end_date}.csv",
data = yf.download(
symbol,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
data["Date"] = pd.to_datetime(data["Date"])
else:
data = yf.download(
symbol,
start=start_date,
end=end_date,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
df = wrap(data)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
curr_date = curr_date.strftime("%Y-%m-%d")
df = wrap(data)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
curr_date_str = curr_date_dt.strftime("%Y-%m-%d")
df[indicator] # trigger stockstats to calculate the indicator
matching_rows = df[df["Date"].str.startswith(curr_date)]
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
if not matching_rows.empty:
indicator_value = matching_rows[indicator].values[0]

View File

@ -293,6 +293,63 @@ def get_stockstats_indicator(
return str(indicator_value)
def get_fundamentals(
ticker: Annotated[str, "ticker symbol of the company"],
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
):
"""Get company fundamentals overview from yfinance."""
try:
ticker_obj = yf.Ticker(ticker.upper())
info = ticker_obj.info
if not info:
return f"No fundamentals data found for symbol '{ticker}'"
fields = [
("Name", info.get("longName")),
("Sector", info.get("sector")),
("Industry", info.get("industry")),
("Market Cap", info.get("marketCap")),
("PE Ratio (TTM)", info.get("trailingPE")),
("Forward PE", info.get("forwardPE")),
("PEG Ratio", info.get("pegRatio")),
("Price to Book", info.get("priceToBook")),
("EPS (TTM)", info.get("trailingEps")),
("Forward EPS", info.get("forwardEps")),
("Dividend Yield", info.get("dividendYield")),
("Beta", info.get("beta")),
("52 Week High", info.get("fiftyTwoWeekHigh")),
("52 Week Low", info.get("fiftyTwoWeekLow")),
("50 Day Average", info.get("fiftyDayAverage")),
("200 Day Average", info.get("twoHundredDayAverage")),
("Revenue (TTM)", info.get("totalRevenue")),
("Gross Profit", info.get("grossProfits")),
("EBITDA", info.get("ebitda")),
("Net Income", info.get("netIncomeToCommon")),
("Profit Margin", info.get("profitMargins")),
("Operating Margin", info.get("operatingMargins")),
("Return on Equity", info.get("returnOnEquity")),
("Return on Assets", info.get("returnOnAssets")),
("Debt to Equity", info.get("debtToEquity")),
("Current Ratio", info.get("currentRatio")),
("Book Value", info.get("bookValue")),
("Free Cash Flow", info.get("freeCashflow")),
]
lines = []
for label, value in fields:
if value is not None:
lines.append(f"{label}: {value}")
header = f"# Company Fundamentals for {ticker.upper()}\n"
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
return header + "\n".join(lines)
except Exception as e:
return f"Error retrieving fundamentals for {ticker}: {str(e)}"
def get_balance_sheet(
ticker: Annotated[str, "ticker symbol of the company"],
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",

View File

@ -1,117 +0,0 @@
# gets data/stats
import yfinance as yf
from typing import Annotated, Callable, Any, Optional
from pandas import DataFrame
import pandas as pd
from functools import wraps
from .utils import save_output, SavePathType, decorate_all_methods
def init_ticker(func: Callable) -> Callable:
"""Decorator to initialize yf.Ticker and pass it to the function."""
@wraps(func)
def wrapper(symbol: Annotated[str, "ticker symbol"], *args, **kwargs) -> Any:
ticker = yf.Ticker(symbol)
return func(ticker, *args, **kwargs)
return wrapper
@decorate_all_methods(init_ticker)
class YFinanceUtils:
def get_stock_data(
symbol: Annotated[str, "ticker symbol"],
start_date: Annotated[
str, "start date for retrieving stock price data, YYYY-mm-dd"
],
end_date: Annotated[
str, "end date for retrieving stock price data, YYYY-mm-dd"
],
save_path: SavePathType = None,
) -> DataFrame:
"""retrieve stock price data for designated ticker symbol"""
ticker = symbol
# add one day to the end_date so that the data range is inclusive
end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
end_date = end_date.strftime("%Y-%m-%d")
stock_data = ticker.history(start=start_date, end=end_date)
# save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
return stock_data
def get_stock_info(
symbol: Annotated[str, "ticker symbol"],
) -> dict:
"""Fetches and returns latest stock information."""
ticker = symbol
stock_info = ticker.info
return stock_info
def get_company_info(
symbol: Annotated[str, "ticker symbol"],
save_path: Optional[str] = None,
) -> DataFrame:
"""Fetches and returns company information as a DataFrame."""
ticker = symbol
info = ticker.info
company_info = {
"Company Name": info.get("shortName", "N/A"),
"Industry": info.get("industry", "N/A"),
"Sector": info.get("sector", "N/A"),
"Country": info.get("country", "N/A"),
"Website": info.get("website", "N/A"),
}
company_info_df = DataFrame([company_info])
if save_path:
company_info_df.to_csv(save_path)
print(f"Company info for {ticker.ticker} saved to {save_path}")
return company_info_df
def get_stock_dividends(
symbol: Annotated[str, "ticker symbol"],
save_path: Optional[str] = None,
) -> DataFrame:
"""Fetches and returns the latest dividends data as a DataFrame."""
ticker = symbol
dividends = ticker.dividends
if save_path:
dividends.to_csv(save_path)
print(f"Dividends for {ticker.ticker} saved to {save_path}")
return dividends
def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest income statement of the company as a DataFrame."""
ticker = symbol
income_stmt = ticker.financials
return income_stmt
def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest balance sheet of the company as a DataFrame."""
ticker = symbol
balance_sheet = ticker.balance_sheet
return balance_sheet
def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest cash flow statement of the company as a DataFrame."""
ticker = symbol
cash_flow = ticker.cashflow
return cash_flow
def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple:
"""Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
ticker = symbol
recommendations = ticker.recommendations
if recommendations.empty:
return None, 0 # No recommendations available
# Assuming 'period' column exists and needs to be excluded
row_0 = recommendations.iloc[0, 1:] # Exclude 'period' column if necessary
# Find the maximum voting result
max_votes = row_0.max()
majority_voting_result = row_0[row_0 == max_votes].index.tolist()
return majority_voting_result[0], max_votes

View File

@ -0,0 +1,190 @@
"""yfinance-based news data fetching functions."""
import yfinance as yf
from datetime import datetime
from dateutil.relativedelta import relativedelta
def _extract_article_data(article: dict) -> dict:
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
# Handle nested content structure
if "content" in article:
content = article["content"]
title = content.get("title", "No title")
summary = content.get("summary", "")
provider = content.get("provider", {})
publisher = provider.get("displayName", "Unknown")
# Get URL from canonicalUrl or clickThroughUrl
url_obj = content.get("canonicalUrl") or content.get("clickThroughUrl") or {}
link = url_obj.get("url", "")
# Get publish date
pub_date_str = content.get("pubDate", "")
pub_date = None
if pub_date_str:
try:
pub_date = datetime.fromisoformat(pub_date_str.replace("Z", "+00:00"))
except (ValueError, AttributeError):
pass
return {
"title": title,
"summary": summary,
"publisher": publisher,
"link": link,
"pub_date": pub_date,
}
else:
# Fallback for flat structure
return {
"title": article.get("title", "No title"),
"summary": article.get("summary", ""),
"publisher": article.get("publisher", "Unknown"),
"link": article.get("link", ""),
"pub_date": None,
}
def get_news_yfinance(
ticker: str,
start_date: str,
end_date: str,
) -> str:
"""
Retrieve news for a specific stock ticker using yfinance.
Args:
ticker: Stock ticker symbol (e.g., "AAPL")
start_date: Start date in yyyy-mm-dd format
end_date: End date in yyyy-mm-dd format
Returns:
Formatted string containing news articles
"""
try:
stock = yf.Ticker(ticker)
news = stock.get_news(count=20)
if not news:
return f"No news found for {ticker}"
# Parse date range for filtering
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
news_str = ""
filtered_count = 0
for article in news:
data = _extract_article_data(article)
# Filter by date if publish time is available
if data["pub_date"]:
pub_date_naive = data["pub_date"].replace(tzinfo=None)
if not (start_dt <= pub_date_naive <= end_dt + relativedelta(days=1)):
continue
news_str += f"### {data['title']} (source: {data['publisher']})\n"
if data["summary"]:
news_str += f"{data['summary']}\n"
if data["link"]:
news_str += f"Link: {data['link']}\n"
news_str += "\n"
filtered_count += 1
if filtered_count == 0:
return f"No news found for {ticker} between {start_date} and {end_date}"
return f"## {ticker} News, from {start_date} to {end_date}:\n\n{news_str}"
except Exception as e:
return f"Error fetching news for {ticker}: {str(e)}"
def get_global_news_yfinance(
curr_date: str,
look_back_days: int = 7,
limit: int = 10,
) -> str:
"""
Retrieve global/macro economic news using yfinance Search.
Args:
curr_date: Current date in yyyy-mm-dd format
look_back_days: Number of days to look back
limit: Maximum number of articles to return
Returns:
Formatted string containing global news articles
"""
# Search queries for macro/global news
search_queries = [
"stock market economy",
"Federal Reserve interest rates",
"inflation economic outlook",
"global markets trading",
]
all_news = []
seen_titles = set()
try:
for query in search_queries:
search = yf.Search(
query=query,
news_count=limit,
enable_fuzzy_query=True,
)
if search.news:
for article in search.news:
# Handle both flat and nested structures
if "content" in article:
data = _extract_article_data(article)
title = data["title"]
else:
title = article.get("title", "")
# Deduplicate by title
if title and title not in seen_titles:
seen_titles.add(title)
all_news.append(article)
if len(all_news) >= limit:
break
if not all_news:
return f"No global news found for {curr_date}"
# Calculate date range
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
start_dt = curr_dt - relativedelta(days=look_back_days)
start_date = start_dt.strftime("%Y-%m-%d")
news_str = ""
for article in all_news[:limit]:
# Handle both flat and nested structures
if "content" in article:
data = _extract_article_data(article)
title = data["title"]
publisher = data["publisher"]
link = data["link"]
summary = data["summary"]
else:
title = article.get("title", "No title")
publisher = article.get("publisher", "Unknown")
link = article.get("link", "")
summary = ""
news_str += f"### {title} (source: {publisher})\n"
if summary:
news_str += f"{summary}\n"
if link:
news_str += f"Link: {link}\n"
news_str += "\n"
return f"## Global Market News, from {start_date} to {curr_date}:\n\n{news_str}"
except Exception as e:
return f"Error fetching global news: {str(e)}"

View File

@ -3,16 +3,18 @@ import os
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"deep_think_llm": "gpt-5.2",
"quick_think_llm": "gpt-5-mini",
"backend_url": "https://api.openai.com/v1",
# Provider-specific thinking configuration
"google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low"
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
@ -20,14 +22,13 @@ DEFAULT_CONFIG = {
# Data vendor configuration
# Category-level configuration (default for all tools in category)
"data_vendors": {
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
"fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
"core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance
"technical_indicators": "yfinance", # Options: alpha_vantage, yfinance
"fundamental_data": "yfinance", # Options: alpha_vantage, yfinance
"news_data": "yfinance", # Options: alpha_vantage, yfinance
},
# Tool-level configuration (takes precedence over category-level)
"tool_vendors": {
# Example: "get_stock_data": "alpha_vantage", # Override category default
# Example: "get_news": "openai", # Override category default
},
}

View File

@ -60,8 +60,8 @@ class ConditionalLogic:
state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds
): # 3 rounds of back-and-forth between 3 agents
return "Risk Judge"
if state["risk_debate_state"]["latest_speaker"].startswith("Risky"):
return "Safe Analyst"
if state["risk_debate_state"]["latest_speaker"].startswith("Safe"):
if state["risk_debate_state"]["latest_speaker"].startswith("Aggressive"):
return "Conservative Analyst"
if state["risk_debate_state"]["latest_speaker"].startswith("Conservative"):
return "Neutral Analyst"
return "Risky Analyst"
return "Aggressive Analyst"

View File

@ -1,6 +1,6 @@
# TradingAgents/graph/propagation.py
from typing import Dict, Any
from typing import Dict, Any, List, Optional
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
@ -29,8 +29,8 @@ class Propagator:
"risk_debate_state": RiskDebateState(
{
"history": "",
"current_risky_response": "",
"current_safe_response": "",
"current_aggressive_response": "",
"current_conservative_response": "",
"current_neutral_response": "",
"count": 0,
}
@ -41,9 +41,17 @@ class Propagator:
"news_report": "",
}
def get_graph_args(self) -> Dict[str, Any]:
"""Get arguments for the graph invocation."""
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
"""Get arguments for the graph invocation.
Args:
callbacks: Optional list of callback handlers for tool execution tracking.
Note: LLM callbacks are handled separately via LLM constructor.
"""
config = {"recursion_limit": self.max_recur_limit}
if callbacks:
config["callbacks"] = callbacks
return {
"stream_mode": "values",
"config": {"recursion_limit": self.max_recur_limit},
"config": config,
}

View File

@ -98,9 +98,9 @@ class GraphSetup:
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
# Create risk analysis nodes
risky_analyst = create_risky_debator(self.quick_thinking_llm)
aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm)
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
safe_analyst = create_safe_debator(self.quick_thinking_llm)
conservative_analyst = create_conservative_debator(self.quick_thinking_llm)
risk_manager_node = create_risk_manager(
self.deep_thinking_llm, self.risk_manager_memory
)
@ -121,9 +121,9 @@ class GraphSetup:
workflow.add_node("Bear Researcher", bear_researcher_node)
workflow.add_node("Research Manager", research_manager_node)
workflow.add_node("Trader", trader_node)
workflow.add_node("Risky Analyst", risky_analyst)
workflow.add_node("Aggressive Analyst", aggressive_analyst)
workflow.add_node("Neutral Analyst", neutral_analyst)
workflow.add_node("Safe Analyst", safe_analyst)
workflow.add_node("Conservative Analyst", conservative_analyst)
workflow.add_node("Risk Judge", risk_manager_node)
# Define edges
@ -170,17 +170,17 @@ class GraphSetup:
},
)
workflow.add_edge("Research Manager", "Trader")
workflow.add_edge("Trader", "Risky Analyst")
workflow.add_edge("Trader", "Aggressive Analyst")
workflow.add_conditional_edges(
"Risky Analyst",
"Aggressive Analyst",
self.conditional_logic.should_continue_risk_analysis,
{
"Safe Analyst": "Safe Analyst",
"Conservative Analyst": "Conservative Analyst",
"Risk Judge": "Risk Judge",
},
)
workflow.add_conditional_edges(
"Safe Analyst",
"Conservative Analyst",
self.conditional_logic.should_continue_risk_analysis,
{
"Neutral Analyst": "Neutral Analyst",
@ -191,7 +191,7 @@ class GraphSetup:
"Neutral Analyst",
self.conditional_logic.should_continue_risk_analysis,
{
"Risky Analyst": "Risky Analyst",
"Aggressive Analyst": "Aggressive Analyst",
"Risk Judge": "Risk Judge",
},
)

View File

@ -6,12 +6,10 @@ import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import ToolNode
from tradingagents.llm_clients import create_llm_client
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
@ -31,7 +29,6 @@ from tradingagents.agents.utils.agent_utils import (
get_cashflow,
get_income_statement,
get_news,
get_insider_sentiment,
get_insider_transactions,
get_global_news
)
@ -51,6 +48,7 @@ class TradingAgentsGraph:
selected_analysts=["market", "social", "news", "fundamentals"],
debug=False,
config: Dict[str, Any] = None,
callbacks: Optional[List] = None,
):
"""Initialize the trading agents graph and components.
@ -58,9 +56,11 @@ class TradingAgentsGraph:
selected_analysts: List of analyst types to include
debug: Whether to run in debug mode
config: Configuration dictionary. If None, uses default config
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
"""
self.debug = debug
self.config = config or DEFAULT_CONFIG
self.callbacks = callbacks or []
# Update the interface's config
set_config(self.config)
@ -71,18 +71,28 @@ class TradingAgentsGraph:
exist_ok=True,
)
# Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic":
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "google":
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
else:
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
# Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs()
# Add callbacks to kwargs if provided (passed to LLM constructor)
if self.callbacks:
llm_kwargs["callbacks"] = self.callbacks
deep_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["deep_think_llm"],
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
quick_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["quick_think_llm"],
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
self.deep_thinking_llm = deep_client.get_llm()
self.quick_thinking_llm = quick_client.get_llm()
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
@ -120,6 +130,23 @@ class TradingAgentsGraph:
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
def _get_provider_kwargs(self) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
kwargs = {}
provider = self.config.get("llm_provider", "").lower()
if provider == "google":
thinking_level = self.config.get("google_thinking_level")
if thinking_level:
kwargs["thinking_level"] = thinking_level
elif provider == "openai":
reasoning_effort = self.config.get("openai_reasoning_effort")
if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
return kwargs
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
"""Create tool nodes for different data sources using abstract methods."""
return {
@ -142,7 +169,6 @@ class TradingAgentsGraph:
# News and insider information
get_news,
get_global_news,
get_insider_sentiment,
get_insider_transactions,
]
),
@ -214,8 +240,8 @@ class TradingAgentsGraph:
},
"trader_investment_decision": final_state["trader_investment_plan"],
"risk_debate_state": {
"risky_history": final_state["risk_debate_state"]["risky_history"],
"safe_history": final_state["risk_debate_state"]["safe_history"],
"aggressive_history": final_state["risk_debate_state"]["aggressive_history"],
"conservative_history": final_state["risk_debate_state"]["conservative_history"],
"neutral_history": final_state["risk_debate_state"]["neutral_history"],
"history": final_state["risk_debate_state"]["history"],
"judge_decision": final_state["risk_debate_state"]["judge_decision"],

View File

@ -0,0 +1,24 @@
# LLM Clients - Consistency Improvements
## Issues to Fix
### 1. `validate_model()` is never called
- Add validation call in `get_llm()` with warning (not error) for unknown models
### 2. Inconsistent parameter handling
| Client | API Key Param | Special Params |
|--------|---------------|----------------|
| OpenAI | `api_key` | `reasoning_effort` |
| Anthropic | `api_key` | `thinking_config``thinking` |
| Google | `google_api_key` | `thinking_budget` |
**Fix:** Standardize with unified `api_key` that maps to provider-specific keys
### 3. `base_url` accepted but ignored
- `AnthropicClient`: accepts `base_url` but never uses it
- `GoogleClient`: accepts `base_url` but never uses it (correct - Google doesn't support it)
**Fix:** Remove unused `base_url` from clients that don't support it
### 4. Update validators.py with models from CLI
- Sync `VALID_MODELS` dict with CLI model options after Feature 2 is complete

View File

@ -0,0 +1,4 @@
from .base_client import BaseLLMClient
from .factory import create_llm_client
__all__ = ["BaseLLMClient", "create_llm_client"]

View File

@ -0,0 +1,27 @@
from typing import Any, Optional
from langchain_anthropic import ChatAnthropic
from .base_client import BaseLLMClient
from .validators import validate_model
class AnthropicClient(BaseLLMClient):
"""Client for Anthropic Claude models."""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:
"""Return configured ChatAnthropic instance."""
llm_kwargs = {"model": self.model}
for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
return ChatAnthropic(**llm_kwargs)
def validate_model(self) -> bool:
"""Validate model for Anthropic."""
return validate_model("anthropic", self.model)

View File

@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
class BaseLLMClient(ABC):
"""Abstract base class for LLM clients."""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
self.model = model
self.base_url = base_url
self.kwargs = kwargs
@abstractmethod
def get_llm(self) -> Any:
"""Return the configured LLM instance."""
pass
@abstractmethod
def validate_model(self) -> bool:
"""Validate that the model is supported by this client."""
pass

View File

@ -0,0 +1,43 @@
from typing import Optional
from .base_client import BaseLLMClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
def create_llm_client(
provider: str,
model: str,
base_url: Optional[str] = None,
**kwargs,
) -> BaseLLMClient:
"""Create an LLM client for the specified provider.
Args:
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
model: Model name/identifier
base_url: Optional base URL for API endpoint
**kwargs: Additional provider-specific arguments
Returns:
Configured BaseLLMClient instance
Raises:
ValueError: If provider is not supported
"""
provider_lower = provider.lower()
if provider_lower in ("openai", "ollama", "openrouter"):
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
if provider_lower == "xai":
return OpenAIClient(model, base_url, provider="xai", **kwargs)
if provider_lower == "anthropic":
return AnthropicClient(model, base_url, **kwargs)
if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@ -0,0 +1,65 @@
from typing import Any, Optional
from langchain_google_genai import ChatGoogleGenerativeAI
from .base_client import BaseLLMClient
from .validators import validate_model
class NormalizedChatGoogleGenerativeAI(ChatGoogleGenerativeAI):
"""ChatGoogleGenerativeAI with normalized content output.
Gemini 3 models return content as list: [{'type': 'text', 'text': '...'}]
This normalizes to string for consistent downstream handling.
"""
def _normalize_content(self, response):
content = response.content
if isinstance(content, list):
texts = [
item.get("text", "") if isinstance(item, dict) and item.get("type") == "text"
else item if isinstance(item, str) else ""
for item in content
]
response.content = "\n".join(t for t in texts if t)
return response
def invoke(self, input, config=None, **kwargs):
return self._normalize_content(super().invoke(input, config, **kwargs))
class GoogleClient(BaseLLMClient):
"""Client for Google Gemini models."""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:
"""Return configured ChatGoogleGenerativeAI instance."""
llm_kwargs = {"model": self.model}
for key in ("timeout", "max_retries", "google_api_key", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
# Map thinking_level to appropriate API param based on model
# Gemini 3 Pro: low, high
# Gemini 3 Flash: minimal, low, medium, high
# Gemini 2.5: thinking_budget (0=disable, -1=dynamic)
thinking_level = self.kwargs.get("thinking_level")
if thinking_level:
model_lower = self.model.lower()
if "gemini-3" in model_lower:
# Gemini 3 Pro doesn't support "minimal", use "low" instead
if "pro" in model_lower and thinking_level == "minimal":
thinking_level = "low"
llm_kwargs["thinking_level"] = thinking_level
else:
# Gemini 2.5: map to thinking_budget
llm_kwargs["thinking_budget"] = -1 if thinking_level == "high" else 0
return NormalizedChatGoogleGenerativeAI(**llm_kwargs)
def validate_model(self) -> bool:
"""Validate model for Google."""
return validate_model("google", self.model)

View File

@ -0,0 +1,72 @@
import os
from typing import Any, Optional
from langchain_openai import ChatOpenAI
from .base_client import BaseLLMClient
from .validators import validate_model
class UnifiedChatOpenAI(ChatOpenAI):
"""ChatOpenAI subclass that strips incompatible params for certain models."""
def __init__(self, **kwargs):
model = kwargs.get("model", "")
if self._is_reasoning_model(model):
kwargs.pop("temperature", None)
kwargs.pop("top_p", None)
super().__init__(**kwargs)
@staticmethod
def _is_reasoning_model(model: str) -> bool:
"""Check if model is a reasoning model that doesn't support temperature."""
model_lower = model.lower()
return (
model_lower.startswith("o1")
or model_lower.startswith("o3")
or "gpt-5" in model_lower
)
class OpenAIClient(BaseLLMClient):
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers."""
def __init__(
self,
model: str,
base_url: Optional[str] = None,
provider: str = "openai",
**kwargs,
):
super().__init__(model, base_url, **kwargs)
self.provider = provider.lower()
def get_llm(self) -> Any:
"""Return configured ChatOpenAI instance."""
llm_kwargs = {"model": self.model}
if self.provider == "xai":
llm_kwargs["base_url"] = "https://api.x.ai/v1"
api_key = os.environ.get("XAI_API_KEY")
if api_key:
llm_kwargs["api_key"] = api_key
elif self.provider == "openrouter":
llm_kwargs["base_url"] = "https://openrouter.ai/api/v1"
api_key = os.environ.get("OPENROUTER_API_KEY")
if api_key:
llm_kwargs["api_key"] = api_key
elif self.provider == "ollama":
llm_kwargs["base_url"] = "http://localhost:11434/v1"
llm_kwargs["api_key"] = "ollama" # Ollama doesn't require auth
elif self.base_url:
llm_kwargs["base_url"] = self.base_url
for key in ("timeout", "max_retries", "reasoning_effort", "api_key", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
return UnifiedChatOpenAI(**llm_kwargs)
def validate_model(self) -> bool:
"""Validate model for the provider."""
return validate_model(self.provider, self.model)

View File

@ -0,0 +1,82 @@
"""Model name validators for each provider.
Only validates model names - does NOT enforce limits.
Let LLM providers use their own defaults for unspecified params.
"""
VALID_MODELS = {
"openai": [
# GPT-5 series (2025)
"gpt-5.2",
"gpt-5.1",
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
# GPT-4.1 series (2025)
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
# o-series reasoning models
"o4-mini",
"o3",
"o3-mini",
"o1",
"o1-preview",
# GPT-4o series (legacy but still supported)
"gpt-4o",
"gpt-4o-mini",
],
"anthropic": [
# Claude 4.5 series (2025)
"claude-opus-4-5",
"claude-sonnet-4-5",
"claude-haiku-4-5",
# Claude 4.x series
"claude-opus-4-1-20250805",
"claude-sonnet-4-20250514",
# Claude 3.7 series
"claude-3-7-sonnet-20250219",
# Claude 3.5 series (legacy)
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20241022",
],
"google": [
# Gemini 3 series (preview)
"gemini-3-pro-preview",
"gemini-3-flash-preview",
# Gemini 2.5 series
"gemini-2.5-pro",
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
# Gemini 2.0 series
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
],
"xai": [
# Grok 4.1 series
"grok-4-1-fast",
"grok-4-1-fast-reasoning",
"grok-4-1-fast-non-reasoning",
# Grok 4 series
"grok-4",
"grok-4-0709",
"grok-4-fast-reasoning",
"grok-4-fast-non-reasoning",
],
}
def validate_model(provider: str, model: str) -> bool:
"""Check if model name is valid for the given provider.
For ollama, openrouter - any model is accepted.
"""
provider_lower = provider.lower()
if provider_lower in ("ollama", "openrouter"):
return True
if provider_lower not in VALID_MODELS:
return True
return model in VALID_MODELS[provider_lower]

1430
uv.lock

File diff suppressed because it is too large Load Diff