diff --git a/cli/main.py b/cli/main.py index 3f42f2e2..64616ee1 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,6 +1,8 @@ from typing import Optional import datetime import typer +from pathlib import Path +from functools import wraps from rich.console import Console from rich.panel import Panel from rich.spinner import Spinner @@ -747,6 +749,53 @@ def run_analysis(): [analyst.value for analyst in selections["analysts"]], config=config, debug=True ) + # Create result directory + results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] + results_dir.mkdir(parents=True, exist_ok=True) + report_dir = results_dir / "reports" + report_dir.mkdir(parents=True, exist_ok=True) + log_file = results_dir / "message_tool.log" + log_file.touch(exist_ok=True) + + def save_message_decorator(obj, func_name): + func = getattr(obj, func_name) + @wraps(func) + def wrapper(*args, **kwargs): + func(*args, **kwargs) + timestamp, message_type, content = obj.messages[-1] + content = content.replace("\n", " ") # Replace newlines with spaces + with open(log_file, "a") as f: + f.write(f"{timestamp} [{message_type}] {content}\n") + return wrapper + + def save_tool_call_decorator(obj, func_name): + func = getattr(obj, func_name) + @wraps(func) + def wrapper(*args, **kwargs): + func(*args, **kwargs) + timestamp, tool_name, args = obj.tool_calls[-1] + args_str = ", ".join(f"{k}={v}" for k, v in args.items()) + with open(log_file, "a") as f: + f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n") + return wrapper + + def save_report_section_decorator(obj, func_name): + func = getattr(obj, func_name) + @wraps(func) + def wrapper(section_name, content): + func(section_name, content) + if section_name in obj.report_sections and obj.report_sections[section_name] is not None: + content = obj.report_sections[section_name] + if content: + file_name = f"{section_name}.md" + with open(report_dir / file_name, "w") as f: + f.write(content) + return wrapper + + message_buffer.add_message = save_message_decorator(message_buffer, "add_message") + message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call") + message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section") + # Now start the display layout layout = create_layout() diff --git a/cli/utils.py b/cli/utils.py index d3873360..7b9682a6 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -150,6 +150,7 @@ def select_shallow_thinking_agent(provider) -> str: ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"), ], "ollama": [ + ("llama3.1 local", "llama3.1"), ("llama3.2 local", "llama3.2"), ] } @@ -211,6 +212,7 @@ def select_deep_thinking_agent(provider) -> str: ("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"), ], "ollama": [ + ("llama3.1 local", "llama3.1"), ("qwen3", "qwen3"), ] } diff --git a/run_trading_agents.sh b/run_trading_agents.sh new file mode 100755 index 00000000..abe223b8 --- /dev/null +++ b/run_trading_agents.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# TradingAgents with Alternative AI Provider +# Usage: ./run_trading_agents.sh + +export ANTHROPIC_API_KEY="your_key_here" # Update this +# export GOOGLE_API_KEY="your_key_here" # Or this for Google + +cd "$(dirname "$0")" +source venv/bin/activate.fish +python -c "from cli.main import app; app()" diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index b7313b71..0b07f044 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -124,7 +124,7 @@ class Toolkit: def get_YFin_data( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], ) -> str: """ Retrieve the stock price data for a given ticker symbol from Yahoo Finance. @@ -145,7 +145,7 @@ class Toolkit: def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], ) -> str: """ Retrieve the stock price data for a given ticker symbol from Yahoo Finance. diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index d51ae1b1..cc9ab352 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -15,7 +15,7 @@ class FinancialSituationMemory: self.client = None else: self.embedding = "text-embedding-3-small" - self.client = OpenAI() + self.client = OpenAI(base_url=config["backend_url"]) self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.situation_collection = self.chroma_client.create_collection(name=name) diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index a0952945..7fffbb4f 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -628,7 +628,7 @@ def get_YFin_data_window( def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], ): datetime.strptime(start_date, "%Y-%m-%d") @@ -670,7 +670,7 @@ def get_YFin_data_online( def get_YFin_data( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], ) -> str: # read in data data = pd.read_csv( @@ -704,10 +704,10 @@ def get_YFin_data( def get_stock_news_openai(ticker, curr_date): config = get_config() - client = OpenAI() + client = OpenAI(base_url=config["backend_url"]) response = client.responses.create( - model="gpt-4.1-mini", + model=config["quick_think_llm"], input=[ { "role": "system", @@ -739,10 +739,10 @@ def get_stock_news_openai(ticker, curr_date): def get_global_news_openai(curr_date): config = get_config() - client = OpenAI() + client = OpenAI(base_url=config["backend_url"]) response = client.responses.create( - model="gpt-4.1-mini", + model=config["quick_think_llm"], input=[ { "role": "system", @@ -774,10 +774,10 @@ def get_global_news_openai(curr_date): def get_fundamentals_openai(ticker, curr_date): config = get_config() - client = OpenAI() + client = OpenAI(base_url=config["backend_url"]) response = client.responses.create( - model="gpt-4.1-mini", + model=config["quick_think_llm"], input=[ { "role": "system", diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 2cf15b85..089e9c24 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -2,6 +2,7 @@ import os DEFAULT_CONFIG = { "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), + "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", "data_cache_dir": os.path.join( os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),