Merge branch 'main' of origin into main

Resolved conflicts in:
- cli/main.py
- cli/utils.py
- tradingagents/agents/utils/agent_utils.py
- tradingagents/agents/utils/memory.py
- tradingagents/dataflows/interface.py
- tradingagents/default_config.py

Added run_trading_agents.sh script
This commit is contained in:
Ming Jia 2025-06-25 21:54:09 -07:00
commit 5152d475d9
7 changed files with 73 additions and 11 deletions

View File

@ -1,6 +1,8 @@
from typing import Optional from typing import Optional
import datetime import datetime
import typer import typer
from pathlib import Path
from functools import wraps
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
from rich.spinner import Spinner from rich.spinner import Spinner
@ -747,6 +749,53 @@ def run_analysis():
[analyst.value for analyst in selections["analysts"]], config=config, debug=True [analyst.value for analyst in selections["analysts"]], config=config, debug=True
) )
# Create result directory
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
results_dir.mkdir(parents=True, exist_ok=True)
report_dir = results_dir / "reports"
report_dir.mkdir(parents=True, exist_ok=True)
log_file = results_dir / "message_tool.log"
log_file.touch(exist_ok=True)
def save_message_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(*args, **kwargs):
func(*args, **kwargs)
timestamp, message_type, content = obj.messages[-1]
content = content.replace("\n", " ") # Replace newlines with spaces
with open(log_file, "a") as f:
f.write(f"{timestamp} [{message_type}] {content}\n")
return wrapper
def save_tool_call_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(*args, **kwargs):
func(*args, **kwargs)
timestamp, tool_name, args = obj.tool_calls[-1]
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
with open(log_file, "a") as f:
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
return wrapper
def save_report_section_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(section_name, content):
func(section_name, content)
if section_name in obj.report_sections and obj.report_sections[section_name] is not None:
content = obj.report_sections[section_name]
if content:
file_name = f"{section_name}.md"
with open(report_dir / file_name, "w") as f:
f.write(content)
return wrapper
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
# Now start the display layout # Now start the display layout
layout = create_layout() layout = create_layout()

View File

@ -150,6 +150,7 @@ def select_shallow_thinking_agent(provider) -> str:
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"), ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
], ],
"ollama": [ "ollama": [
("llama3.1 local", "llama3.1"),
("llama3.2 local", "llama3.2"), ("llama3.2 local", "llama3.2"),
] ]
} }
@ -211,6 +212,7 @@ def select_deep_thinking_agent(provider) -> str:
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"), ("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
], ],
"ollama": [ "ollama": [
("llama3.1 local", "llama3.1"),
("qwen3", "qwen3"), ("qwen3", "qwen3"),
] ]
} }

10
run_trading_agents.sh Executable file
View File

@ -0,0 +1,10 @@
#!/bin/bash
# TradingAgents with Alternative AI Provider
# Usage: ./run_trading_agents.sh
export ANTHROPIC_API_KEY="your_key_here" # Update this
# export GOOGLE_API_KEY="your_key_here" # Or this for Google
cd "$(dirname "$0")"
source venv/bin/activate.fish
python -c "from cli.main import app; app()"

View File

@ -124,7 +124,7 @@ class Toolkit:
def get_YFin_data( def get_YFin_data(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str: ) -> str:
""" """
Retrieve the stock price data for a given ticker symbol from Yahoo Finance. Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
@ -145,7 +145,7 @@ class Toolkit:
def get_YFin_data_online( def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str: ) -> str:
""" """
Retrieve the stock price data for a given ticker symbol from Yahoo Finance. Retrieve the stock price data for a given ticker symbol from Yahoo Finance.

View File

@ -15,7 +15,7 @@ class FinancialSituationMemory:
self.client = None self.client = None
else: else:
self.embedding = "text-embedding-3-small" self.embedding = "text-embedding-3-small"
self.client = OpenAI() self.client = OpenAI(base_url=config["backend_url"])
self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name) self.situation_collection = self.chroma_client.create_collection(name=name)

View File

@ -628,7 +628,7 @@ def get_YFin_data_window(
def get_YFin_data_online( def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
): ):
datetime.strptime(start_date, "%Y-%m-%d") datetime.strptime(start_date, "%Y-%m-%d")
@ -670,7 +670,7 @@ def get_YFin_data_online(
def get_YFin_data( def get_YFin_data(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str: ) -> str:
# read in data # read in data
data = pd.read_csv( data = pd.read_csv(
@ -704,10 +704,10 @@ def get_YFin_data(
def get_stock_news_openai(ticker, curr_date): def get_stock_news_openai(ticker, curr_date):
config = get_config() config = get_config()
client = OpenAI() client = OpenAI(base_url=config["backend_url"])
response = client.responses.create( response = client.responses.create(
model="gpt-4.1-mini", model=config["quick_think_llm"],
input=[ input=[
{ {
"role": "system", "role": "system",
@ -739,10 +739,10 @@ def get_stock_news_openai(ticker, curr_date):
def get_global_news_openai(curr_date): def get_global_news_openai(curr_date):
config = get_config() config = get_config()
client = OpenAI() client = OpenAI(base_url=config["backend_url"])
response = client.responses.create( response = client.responses.create(
model="gpt-4.1-mini", model=config["quick_think_llm"],
input=[ input=[
{ {
"role": "system", "role": "system",
@ -774,10 +774,10 @@ def get_global_news_openai(curr_date):
def get_fundamentals_openai(ticker, curr_date): def get_fundamentals_openai(ticker, curr_date):
config = get_config() config = get_config()
client = OpenAI() client = OpenAI(base_url=config["backend_url"])
response = client.responses.create( response = client.responses.create(
model="gpt-4.1-mini", model=config["quick_think_llm"],
input=[ input=[
{ {
"role": "system", "role": "system",

View File

@ -2,6 +2,7 @@ import os
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
"data_cache_dir": os.path.join( "data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),