diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..5504caef --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,37 @@ +name: Python CI + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build-and-test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version-file: ".python-version" + + - name: Install the project + run: uv sync --all-extras --dev + + - name: Format with Black + run: uv run black --check . + + - name: Lint with Ruff + run: uv run ruff check . + + # - name: Run tests (Uncomment when tests exist) + # run: uv run pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..2b8d7c36 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + + - repo: https://github.com/psf/black + rev: 24.2.0 + hooks: + - id: black + language_version: python3.13 + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.0 + hooks: + - id: ruff + args: [ --fix ] diff --git a/cli/main.py b/cli/main.py index fb97d189..ec5e9e73 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,34 +1,39 @@ -from typing import Optional import datetime -import typer +import time from pathlib import Path from functools import wraps -from rich.console import Console +from collections import deque + +import typer from dotenv import load_dotenv +from rich import box +from rich.align import Align +from rich.console import Console +from rich.layout import Layout +from rich.live import Live +from rich.markdown import Markdown +from rich.panel import Panel +from rich.rule import Rule +from rich.spinner import Spinner +from rich.table import Table +from rich.text import Text + +from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.graph.trading_graph import TradingAgentsGraph +from cli.announcements import display_announcements, fetch_announcements +from cli.stats_handler import StatsCallbackHandler +from cli.utils import ( + ask_gemini_thinking_config, + ask_openai_reasoning_effort, + select_analysts, + select_deep_thinking_agent, + select_llm_provider, + select_research_depth, + select_shallow_thinking_agent, +) # Load environment variables from .env file load_dotenv() -from rich.panel import Panel -from rich.spinner import Spinner -from rich.live import Live -from rich.columns import Columns -from rich.markdown import Markdown -from rich.layout import Layout -from rich.text import Text -from rich.table import Table -from collections import deque -import time -from rich.tree import Tree -from rich import box -from rich.align import Align -from rich.rule import Rule - -from tradingagents.graph.trading_graph import TradingAgentsGraph -from tradingagents.default_config import DEFAULT_CONFIG -from cli.models import AnalystType -from cli.utils import * -from cli.announcements import fetch_announcements, display_announcements -from cli.stats_handler import StatsCallbackHandler console = Console() @@ -45,7 +50,11 @@ class MessageBuffer: FIXED_AGENTS = { "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], "Trading Team": ["Trader"], - "Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"], + "Risk Management": [ + "Aggressive Analyst", + "Neutral Analyst", + "Conservative Analyst", + ], "Portfolio Management": ["Portfolio Manager"], } @@ -165,7 +174,7 @@ class MessageBuffer: if content is not None: latest_section = section latest_content = content - + if latest_section and latest_content: # Format the current section for display section_titles = { @@ -188,7 +197,12 @@ class MessageBuffer: report_parts = [] # Analyst Team Reports - use .get() to handle missing sections - analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"] + analyst_sections = [ + "market_report", + "sentiment_report", + "news_report", + "fundamentals_report", + ] if any(self.report_sections.get(section) for section in analyst_sections): report_parts.append("## Analyst Team Reports") if self.report_sections.get("market_report"): @@ -289,7 +303,11 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non ], "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], "Trading Team": ["Trader"], - "Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"], + "Risk Management": [ + "Aggressive Analyst", + "Neutral Analyst", + "Conservative Analyst", + ], "Portfolio Management": ["Portfolio Manager"], } @@ -538,12 +556,10 @@ def get_user_selections(): # Step 5: OpenAI backend console.print( - create_question_box( - "Step 5: OpenAI backend", "Select which service to talk to" - ) + create_question_box("Step 5: OpenAI backend", "Select which service to talk to") ) selected_llm_provider, backend_url = select_llm_provider() - + # Step 6: Thinking agents console.print( create_question_box( @@ -561,16 +577,14 @@ def get_user_selections(): if provider_lower == "google": console.print( create_question_box( - "Step 7: Thinking Mode", - "Configure Gemini thinking mode" + "Step 7: Thinking Mode", "Configure Gemini thinking mode" ) ) thinking_level = ask_gemini_thinking_config() elif provider_lower == "openai": console.print( create_question_box( - "Step 7: Reasoning Effort", - "Configure OpenAI reasoning effort level" + "Step 7: Reasoning Effort", "Configure OpenAI reasoning effort level" ) ) reasoning_effort = ask_openai_reasoning_effort() @@ -635,8 +649,12 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): analyst_parts.append(("News Analyst", final_state["news_report"])) if final_state.get("fundamentals_report"): analysts_dir.mkdir(exist_ok=True) - (analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"]) - analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"])) + (analysts_dir / "fundamentals.md").write_text( + final_state["fundamentals_report"] + ) + analyst_parts.append( + ("Fundamentals Analyst", final_state["fundamentals_report"]) + ) if analyst_parts: content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts) sections.append(f"## I. Analyst Team Reports\n\n{content}") @@ -659,7 +677,9 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): (research_dir / "manager.md").write_text(debate["judge_decision"]) research_parts.append(("Research Manager", debate["judge_decision"])) if research_parts: - content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts) + content = "\n\n".join( + f"### {name}\n{text}" for name, text in research_parts + ) sections.append(f"## II. Research Team Decision\n\n{content}") # 3. Trading @@ -667,7 +687,9 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): trading_dir = save_path / "3_trading" trading_dir.mkdir(exist_ok=True) (trading_dir / "trader.md").write_text(final_state["trader_investment_plan"]) - sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}") + sections.append( + f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}" + ) # 4. Risk Management if final_state.get("risk_debate_state"): @@ -695,7 +717,9 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): portfolio_dir = save_path / "5_portfolio" portfolio_dir.mkdir(exist_ok=True) (portfolio_dir / "decision.md").write_text(risk["judge_decision"]) - sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}") + sections.append( + f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}" + ) # Write consolidated report header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" @@ -719,9 +743,15 @@ def display_complete_report(final_state): if final_state.get("fundamentals_report"): analysts.append(("Fundamentals Analyst", final_state["fundamentals_report"])) if analysts: - console.print(Panel("[bold]I. Analyst Team Reports[/bold]", border_style="cyan")) + console.print( + Panel("[bold]I. Analyst Team Reports[/bold]", border_style="cyan") + ) for title, content in analysts: - console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2))) + console.print( + Panel( + Markdown(content), title=title, border_style="blue", padding=(1, 2) + ) + ) # II. Research Team Reports if final_state.get("investment_debate_state"): @@ -734,14 +764,32 @@ def display_complete_report(final_state): if debate.get("judge_decision"): research.append(("Research Manager", debate["judge_decision"])) if research: - console.print(Panel("[bold]II. Research Team Decision[/bold]", border_style="magenta")) + console.print( + Panel("[bold]II. Research Team Decision[/bold]", border_style="magenta") + ) for title, content in research: - console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2))) + console.print( + Panel( + Markdown(content), + title=title, + border_style="blue", + padding=(1, 2), + ) + ) # III. Trading Team if final_state.get("trader_investment_plan"): - console.print(Panel("[bold]III. Trading Team Plan[/bold]", border_style="yellow")) - console.print(Panel(Markdown(final_state["trader_investment_plan"]), title="Trader", border_style="blue", padding=(1, 2))) + console.print( + Panel("[bold]III. Trading Team Plan[/bold]", border_style="yellow") + ) + console.print( + Panel( + Markdown(final_state["trader_investment_plan"]), + title="Trader", + border_style="blue", + padding=(1, 2), + ) + ) # IV. Risk Management Team if final_state.get("risk_debate_state"): @@ -754,14 +802,36 @@ def display_complete_report(final_state): if risk.get("neutral_history"): risk_reports.append(("Neutral Analyst", risk["neutral_history"])) if risk_reports: - console.print(Panel("[bold]IV. Risk Management Team Decision[/bold]", border_style="red")) + console.print( + Panel( + "[bold]IV. Risk Management Team Decision[/bold]", border_style="red" + ) + ) for title, content in risk_reports: - console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2))) + console.print( + Panel( + Markdown(content), + title=title, + border_style="blue", + padding=(1, 2), + ) + ) # V. Portfolio Manager Decision if risk.get("judge_decision"): - console.print(Panel("[bold]V. Portfolio Manager Decision[/bold]", border_style="green")) - console.print(Panel(Markdown(risk["judge_decision"]), title="Portfolio Manager", border_style="blue", padding=(1, 2))) + console.print( + Panel( + "[bold]V. Portfolio Manager Decision[/bold]", border_style="green" + ) + ) + console.print( + Panel( + Markdown(risk["judge_decision"]), + title="Portfolio Manager", + border_style="blue", + padding=(1, 2), + ) + ) def update_research_team_status(status): @@ -821,6 +891,7 @@ def update_analyst_statuses(message_buffer, chunk): if message_buffer.agent_status.get("Bull Researcher") == "pending": message_buffer.update_agent_status("Bull Researcher", "in_progress") + def extract_content_string(content): """Extract string content from various message formats. Returns None if no meaningful text content is found. @@ -829,7 +900,7 @@ def extract_content_string(content): def is_empty(val): """Check if value is empty using Python's truthiness.""" - if val is None or val == '': + if val is None or val == "": return True if isinstance(val, str): s = val.strip() @@ -848,16 +919,19 @@ def extract_content_string(content): return content.strip() if isinstance(content, dict): - text = content.get('text', '') + text = content.get("text", "") return text.strip() if not is_empty(text) else None if isinstance(content, list): text_parts = [ - item.get('text', '').strip() if isinstance(item, dict) and item.get('type') == 'text' - else (item.strip() if isinstance(item, str) else '') + ( + item.get("text", "").strip() + if isinstance(item, dict) and item.get("type") == "text" + else (item.strip() if isinstance(item, str) else "") + ) for item in content ] - result = ' '.join(t for t in text_parts if t and not is_empty(t)) + result = " ".join(t for t in text_parts if t and not is_empty(t)) return result if result else None return str(content).strip() if not is_empty(content) else None @@ -872,7 +946,7 @@ def classify_message_type(message) -> tuple[str, str | None]: """ from langchain_core.messages import AIMessage, HumanMessage, ToolMessage - content = extract_content_string(getattr(message, 'content', None)) + content = extract_content_string(getattr(message, "content", None)) if isinstance(message, HumanMessage): if content and content.strip() == "Continue": @@ -893,9 +967,10 @@ def format_tool_args(args, max_length=80) -> str: """Format tool arguments for terminal display.""" result = str(args) if len(result) > max_length: - return result[:max_length - 3] + "..." + return result[: max_length - 3] + "..." return result + def run_analysis(): # First get all user selections selections = get_user_selections() @@ -934,7 +1009,9 @@ def run_analysis(): start_time = time.time() # Create result directory - results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] + results_dir = ( + Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] + ) results_dir.mkdir(parents=True, exist_ok=True) report_dir = results_dir / "reports" report_dir.mkdir(parents=True, exist_ok=True) @@ -943,6 +1020,7 @@ def run_analysis(): def save_message_decorator(obj, func_name): func = getattr(obj, func_name) + @wraps(func) def wrapper(*args, **kwargs): func(*args, **kwargs) @@ -950,10 +1028,12 @@ def run_analysis(): content = content.replace("\n", " ") # Replace newlines with spaces with open(log_file, "a") as f: f.write(f"{timestamp} [{message_type}] {content}\n") + return wrapper - + def save_tool_call_decorator(obj, func_name): func = getattr(obj, func_name) + @wraps(func) def wrapper(*args, **kwargs): func(*args, **kwargs) @@ -961,29 +1041,39 @@ def run_analysis(): args_str = ", ".join(f"{k}={v}" for k, v in args.items()) with open(log_file, "a") as f: f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n") + return wrapper def save_report_section_decorator(obj, func_name): func = getattr(obj, func_name) + @wraps(func) def wrapper(section_name, content): func(section_name, content) - if section_name in obj.report_sections and obj.report_sections[section_name] is not None: + if ( + section_name in obj.report_sections + and obj.report_sections[section_name] is not None + ): content = obj.report_sections[section_name] if content: file_name = f"{section_name}.md" with open(report_dir / file_name, "w") as f: f.write(content) + return wrapper message_buffer.add_message = save_message_decorator(message_buffer, "add_message") - message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call") - message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section") + message_buffer.add_tool_call = save_tool_call_decorator( + message_buffer, "add_tool_call" + ) + message_buffer.update_report_section = save_report_section_decorator( + message_buffer, "update_report_section" + ) # Now start the display layout layout = create_layout() - with Live(layout, refresh_per_second=4) as live: + with Live(layout, refresh_per_second=4): # Initial display update_display(layout, stats_handler=stats_handler, start_time=start_time) @@ -1007,7 +1097,9 @@ def run_analysis(): spinner_text = ( f"Analyzing {selections['ticker']} on {selections['analysis_date']}..." ) - update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time) + update_display( + layout, spinner_text, stats_handler=stats_handler, start_time=start_time + ) # Initialize state and get graph args with callbacks init_agent_state = graph.propagator.create_initial_state( @@ -1041,7 +1133,9 @@ def run_analysis(): tool_call["name"], tool_call["args"] ) else: - message_buffer.add_tool_call(tool_call.name, tool_call.args) + message_buffer.add_tool_call( + tool_call.name, tool_call.args + ) # Update analyst statuses based on report state (runs on every chunk) update_analyst_statuses(message_buffer, chunk) @@ -1078,7 +1172,9 @@ def run_analysis(): ) if message_buffer.agent_status.get("Trader") != "completed": message_buffer.update_agent_status("Trader", "completed") - message_buffer.update_agent_status("Aggressive Analyst", "in_progress") + message_buffer.update_agent_status( + "Aggressive Analyst", "in_progress" + ) # Risk Management Team - Handle Risk Debate State if chunk.get("risk_debate_state"): @@ -1089,33 +1185,65 @@ def run_analysis(): judge = risk_state.get("judge_decision", "").strip() if agg_hist: - if message_buffer.agent_status.get("Aggressive Analyst") != "completed": - message_buffer.update_agent_status("Aggressive Analyst", "in_progress") + if ( + message_buffer.agent_status.get("Aggressive Analyst") + != "completed" + ): + message_buffer.update_agent_status( + "Aggressive Analyst", "in_progress" + ) message_buffer.update_report_section( - "final_trade_decision", f"### Aggressive Analyst Analysis\n{agg_hist}" + "final_trade_decision", + f"### Aggressive Analyst Analysis\n{agg_hist}", ) if con_hist: - if message_buffer.agent_status.get("Conservative Analyst") != "completed": - message_buffer.update_agent_status("Conservative Analyst", "in_progress") + if ( + message_buffer.agent_status.get("Conservative Analyst") + != "completed" + ): + message_buffer.update_agent_status( + "Conservative Analyst", "in_progress" + ) message_buffer.update_report_section( - "final_trade_decision", f"### Conservative Analyst Analysis\n{con_hist}" + "final_trade_decision", + f"### Conservative Analyst Analysis\n{con_hist}", ) if neu_hist: - if message_buffer.agent_status.get("Neutral Analyst") != "completed": - message_buffer.update_agent_status("Neutral Analyst", "in_progress") + if ( + message_buffer.agent_status.get("Neutral Analyst") + != "completed" + ): + message_buffer.update_agent_status( + "Neutral Analyst", "in_progress" + ) message_buffer.update_report_section( - "final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}" + "final_trade_decision", + f"### Neutral Analyst Analysis\n{neu_hist}", ) if judge: - if message_buffer.agent_status.get("Portfolio Manager") != "completed": - message_buffer.update_agent_status("Portfolio Manager", "in_progress") - message_buffer.update_report_section( - "final_trade_decision", f"### Portfolio Manager Decision\n{judge}" + if ( + message_buffer.agent_status.get("Portfolio Manager") + != "completed" + ): + message_buffer.update_agent_status( + "Portfolio Manager", "in_progress" + ) + message_buffer.update_report_section( + "final_trade_decision", + f"### Portfolio Manager Decision\n{judge}", + ) + message_buffer.update_agent_status( + "Aggressive Analyst", "completed" + ) + message_buffer.update_agent_status( + "Conservative Analyst", "completed" + ) + message_buffer.update_agent_status( + "Neutral Analyst", "completed" + ) + message_buffer.update_agent_status( + "Portfolio Manager", "completed" ) - message_buffer.update_agent_status("Aggressive Analyst", "completed") - message_buffer.update_agent_status("Conservative Analyst", "completed") - message_buffer.update_agent_status("Neutral Analyst", "completed") - message_buffer.update_agent_status("Portfolio Manager", "completed") # Update the display update_display(layout, stats_handler=stats_handler, start_time=start_time) @@ -1124,7 +1252,7 @@ def run_analysis(): # Get final state and decision final_state = trace[-1] - decision = graph.process_signal(final_state["final_trade_decision"]) + graph.process_signal(final_state["final_trade_decision"]) # Update all agent statuses to completed for agent in message_buffer.agent_status: @@ -1150,19 +1278,22 @@ def run_analysis(): timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") default_path = Path.cwd() / "reports" / f"{selections['ticker']}_{timestamp}" save_path_str = typer.prompt( - "Save path (press Enter for default)", - default=str(default_path) + "Save path (press Enter for default)", default=str(default_path) ).strip() save_path = Path(save_path_str) try: - report_file = save_report_to_disk(final_state, selections["ticker"], save_path) + report_file = save_report_to_disk( + final_state, selections["ticker"], save_path + ) console.print(f"\n[green]✓ Report saved to:[/green] {save_path.resolve()}") console.print(f" [dim]Complete report:[/dim] {report_file.name}") except Exception as e: console.print(f"[red]Error saving report: {e}[/red]") # Prompt to display full report - display_choice = typer.prompt("\nDisplay full report on screen?", default="Y").strip().upper() + display_choice = ( + typer.prompt("\nDisplay full report on screen?", default="Y").strip().upper() + ) if display_choice in ("Y", "YES", ""): display_complete_report(final_state) diff --git a/cli/utils.py b/cli/utils.py index aa097fb5..5758b27d 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,8 +1,10 @@ import questionary -from typing import List, Optional, Tuple, Dict - +from typing import List +from rich.console import Console from cli.models import AnalystType +console = Console() + ANALYST_ORDER = [ ("Market Analyst", AnalystType.MARKET), ("Social Media Analyst", AnalystType.SOCIAL), @@ -146,13 +148,25 @@ def select_shallow_thinking_agent(provider) -> str: ("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"), ], "xai": [ - ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), - ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), - ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ( + "Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", + "grok-4-1-fast-non-reasoning", + ), + ( + "Grok 4 Fast (Non-Reasoning) - Speed optimized", + "grok-4-fast-non-reasoning", + ), + ( + "Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", + "grok-4-1-fast-reasoning", + ), ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), ], "openrouter": [ - ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), + ( + "NVIDIA Nemotron 3 Nano 30B (free)", + "nvidia/nemotron-3-nano-30b-a3b:free", + ), ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), ], "ollama": [ @@ -213,15 +227,27 @@ def select_deep_thinking_agent(provider) -> str: ("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"), ], "xai": [ - ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ( + "Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", + "grok-4-1-fast-reasoning", + ), ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), ("Grok 4 - Flagship model", "grok-4-0709"), - ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), - ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), + ( + "Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", + "grok-4-1-fast-non-reasoning", + ), + ( + "Grok 4 Fast (Non-Reasoning) - Speed optimized", + "grok-4-fast-non-reasoning", + ), ], "openrouter": [ ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), - ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), + ( + "NVIDIA Nemotron 3 Nano 30B (free)", + "nvidia/nemotron-3-nano-30b-a3b:free", + ), ], "ollama": [ ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), @@ -252,6 +278,7 @@ def select_deep_thinking_agent(provider) -> str: return choice + def select_llm_provider() -> tuple[str, str]: """Select the OpenAI api url using interactive selection.""" # Define OpenAI api options with their corresponding endpoints @@ -263,7 +290,7 @@ def select_llm_provider() -> tuple[str, str]: ("Openrouter", "https://openrouter.ai/api/v1"), ("Ollama", "http://localhost:11434/v1"), ] - + choice = questionary.select( "Select your LLM Provider:", choices=[ @@ -279,11 +306,11 @@ def select_llm_provider() -> tuple[str, str]: ] ), ).ask() - + if choice is None: console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") exit(1) - + display_name, url = choice print(f"You selected: {display_name}\tURL: {url}") @@ -300,11 +327,13 @@ def ask_openai_reasoning_effort() -> str: return questionary.select( "Select Reasoning Effort:", choices=choices, - style=questionary.Style([ - ("selected", "fg:cyan noinherit"), - ("highlighted", "fg:cyan noinherit"), - ("pointer", "fg:cyan noinherit"), - ]), + style=questionary.Style( + [ + ("selected", "fg:cyan noinherit"), + ("highlighted", "fg:cyan noinherit"), + ("pointer", "fg:cyan noinherit"), + ] + ), ).ask() @@ -320,9 +349,11 @@ def ask_gemini_thinking_config() -> str | None: questionary.Choice("Enable Thinking (recommended)", "high"), questionary.Choice("Minimal/Disable Thinking", "minimal"), ], - style=questionary.Style([ - ("selected", "fg:green noinherit"), - ("highlighted", "fg:green noinherit"), - ("pointer", "fg:green noinherit"), - ]), + style=questionary.Style( + [ + ("selected", "fg:green noinherit"), + ("highlighted", "fg:green noinherit"), + ("pointer", "fg:green noinherit"), + ] + ), ).ask() diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 813b00ee..6423b936 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -1,10 +1,6 @@ -from typing import Annotated, Sequence -from datetime import date, timedelta, datetime -from typing_extensions import TypedDict, Optional -from langchain_openai import ChatOpenAI -from tradingagents.agents import * -from langgraph.prebuilt import ToolNode -from langgraph.graph import END, StateGraph, START, MessagesState +from typing import Annotated +from typing_extensions import TypedDict +from langgraph.graph import MessagesState # Researcher team state diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index bc78d8b3..f606fbba 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -2,9 +2,9 @@ from typing import Annotated from datetime import datetime from dateutil.relativedelta import relativedelta import yfinance as yf -import os from .stockstats_utils import StockstatsUtils + def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], @@ -46,6 +46,7 @@ def get_YFin_data_online( return header + csv_string + def get_stock_stats_indicators_window( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[str, "technical indicator to get the analysis and report of"], @@ -140,28 +141,28 @@ def get_stock_stats_indicators_window( # Optimized: Get stock data once and calculate indicators for all dates try: indicator_data = _get_stock_stats_bulk(symbol, indicator, curr_date) - + # Generate the date range we need current_dt = curr_date_dt date_values = [] - + while current_dt >= before: - date_str = current_dt.strftime('%Y-%m-%d') - + date_str = current_dt.strftime("%Y-%m-%d") + # Look up the indicator value for this date if date_str in indicator_data: indicator_value = indicator_data[date_str] else: indicator_value = "N/A: Not a trading day (weekend or holiday)" - + date_values.append((date_str, indicator_value)) current_dt = current_dt - relativedelta(days=1) - + # Build the result string ind_string = "" for date_str, value in date_values: ind_string += f"{date_str}: {value}\n" - + except Exception as e: print(f"Error getting bulk stockstats data: {e}") # Fallback to original implementation if bulk method fails @@ -187,7 +188,7 @@ def get_stock_stats_indicators_window( def _get_stock_stats_bulk( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[str, "technical indicator to calculate"], - curr_date: Annotated[str, "current date for reference"] + curr_date: Annotated[str, "current date for reference"], ) -> dict: """ Optimized bulk calculation of stock stats indicators. @@ -195,13 +196,13 @@ def _get_stock_stats_bulk( Returns dict mapping date strings to indicator values. """ from .config import get_config + import os import pandas as pd from stockstats import wrap - import os - + config = get_config() online = config["data_vendors"]["technical_indicators"] != "local" - + if not online: # Local data path try: @@ -217,20 +218,20 @@ def _get_stock_stats_bulk( else: # Online data fetching with caching today_date = pd.Timestamp.today() - curr_date_dt = pd.to_datetime(curr_date) - + pd.to_datetime(curr_date) + end_date = today_date start_date = today_date - pd.DateOffset(years=15) start_date_str = start_date.strftime("%Y-%m-%d") end_date_str = end_date.strftime("%Y-%m-%d") - + os.makedirs(config["data_cache_dir"], exist_ok=True) - + data_file = os.path.join( config["data_cache_dir"], f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", ) - + if os.path.exists(data_file): data = pd.read_csv(data_file) data["Date"] = pd.to_datetime(data["Date"]) @@ -245,25 +246,25 @@ def _get_stock_stats_bulk( ) data = data.reset_index() data.to_csv(data_file, index=False) - + df = wrap(data) df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") - + # Calculate the indicator for all rows at once df[indicator] # This triggers stockstats to calculate the indicator - + # Create a dictionary mapping date strings to indicator values result_dict = {} for _, row in df.iterrows(): date_str = row["Date"] indicator_value = row[indicator] - + # Handle NaN/None values if pd.isna(indicator_value): result_dict[date_str] = "N/A" else: result_dict[date_str] = str(indicator_value) - + return result_dict @@ -295,7 +296,7 @@ def get_stockstats_indicator( def get_fundamentals( ticker: Annotated[str, "ticker symbol of the company"], - curr_date: Annotated[str, "current date (not used for yfinance)"] = None + curr_date: Annotated[str, "current date (not used for yfinance)"] = None, ): """Get company fundamentals overview from yfinance.""" try: @@ -342,7 +343,9 @@ def get_fundamentals( lines.append(f"{label}: {value}") header = f"# Company Fundamentals for {ticker.upper()}\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + header += ( + f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + ) return header + "\n".join(lines) @@ -353,29 +356,31 @@ def get_fundamentals( def get_balance_sheet( ticker: Annotated[str, "ticker symbol of the company"], freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", - curr_date: Annotated[str, "current date (not used for yfinance)"] = None + curr_date: Annotated[str, "current date (not used for yfinance)"] = None, ): """Get balance sheet data from yfinance.""" try: ticker_obj = yf.Ticker(ticker.upper()) - + if freq.lower() == "quarterly": data = ticker_obj.quarterly_balance_sheet else: data = ticker_obj.balance_sheet - + if data.empty: return f"No balance sheet data found for symbol '{ticker}'" - + # Convert to CSV string for consistency with other functions csv_string = data.to_csv() - + # Add header information header = f"# Balance Sheet data for {ticker.upper()} ({freq})\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - + header += ( + f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + ) + return header + csv_string - + except Exception as e: return f"Error retrieving balance sheet for {ticker}: {str(e)}" @@ -383,29 +388,31 @@ def get_balance_sheet( def get_cashflow( ticker: Annotated[str, "ticker symbol of the company"], freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", - curr_date: Annotated[str, "current date (not used for yfinance)"] = None + curr_date: Annotated[str, "current date (not used for yfinance)"] = None, ): """Get cash flow data from yfinance.""" try: ticker_obj = yf.Ticker(ticker.upper()) - + if freq.lower() == "quarterly": data = ticker_obj.quarterly_cashflow else: data = ticker_obj.cashflow - + if data.empty: return f"No cash flow data found for symbol '{ticker}'" - + # Convert to CSV string for consistency with other functions csv_string = data.to_csv() - + # Add header information header = f"# Cash Flow data for {ticker.upper()} ({freq})\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - + header += ( + f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + ) + return header + csv_string - + except Exception as e: return f"Error retrieving cash flow for {ticker}: {str(e)}" @@ -413,52 +420,54 @@ def get_cashflow( def get_income_statement( ticker: Annotated[str, "ticker symbol of the company"], freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", - curr_date: Annotated[str, "current date (not used for yfinance)"] = None + curr_date: Annotated[str, "current date (not used for yfinance)"] = None, ): """Get income statement data from yfinance.""" try: ticker_obj = yf.Ticker(ticker.upper()) - + if freq.lower() == "quarterly": data = ticker_obj.quarterly_income_stmt else: data = ticker_obj.income_stmt - + if data.empty: return f"No income statement data found for symbol '{ticker}'" - + # Convert to CSV string for consistency with other functions csv_string = data.to_csv() - + # Add header information header = f"# Income Statement data for {ticker.upper()} ({freq})\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - + header += ( + f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + ) + return header + csv_string - + except Exception as e: return f"Error retrieving income statement for {ticker}: {str(e)}" -def get_insider_transactions( - ticker: Annotated[str, "ticker symbol of the company"] -): +def get_insider_transactions(ticker: Annotated[str, "ticker symbol of the company"]): """Get insider transactions data from yfinance.""" try: ticker_obj = yf.Ticker(ticker.upper()) data = ticker_obj.insider_transactions - + if data is None or data.empty: return f"No insider transactions data found for symbol '{ticker}'" - + # Convert to CSV string for consistency with other functions csv_string = data.to_csv() - + # Add header information header = f"# Insider Transactions data for {ticker.upper()}\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - + header += ( + f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + ) + return header + csv_string - + except Exception as e: - return f"Error retrieving insider transactions for {ticker}: {str(e)}" \ No newline at end of file + return f"Error retrieving insider transactions for {ticker}: {str(e)}" diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 772efe7f..6c9c6300 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -1,11 +1,25 @@ # TradingAgents/graph/setup.py -from typing import Dict, Any +from typing import Dict from langchain_openai import ChatOpenAI from langgraph.graph import END, StateGraph, START from langgraph.prebuilt import ToolNode -from tradingagents.agents import * +from tradingagents.agents import ( + create_msg_delete, + create_market_analyst, + create_social_media_analyst, + create_fundamentals_analyst, + create_news_analyst, + create_bull_researcher, + create_bear_researcher, + create_research_manager, + create_trader, + create_aggressive_debator, + create_conservative_debator, + create_neutral_debator, + create_risk_manager, +) from tradingagents.agents.utils.agent_states import AgentState from .conditional_logic import ConditionalLogic @@ -58,9 +72,7 @@ class GraphSetup: tool_nodes = {} if "market" in selected_analysts: - analyst_nodes["market"] = create_market_analyst( - self.quick_thinking_llm - ) + analyst_nodes["market"] = create_market_analyst(self.quick_thinking_llm) delete_nodes["market"] = create_msg_delete() tool_nodes["market"] = self.tool_nodes["market"] @@ -72,9 +84,7 @@ class GraphSetup: tool_nodes["social"] = self.tool_nodes["social"] if "news" in selected_analysts: - analyst_nodes["news"] = create_news_analyst( - self.quick_thinking_llm - ) + analyst_nodes["news"] = create_news_analyst(self.quick_thinking_llm) delete_nodes["news"] = create_msg_delete() tool_nodes["news"] = self.tool_nodes["news"] diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 44ecca0c..8f0c4b21 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -3,21 +3,14 @@ import os from pathlib import Path import json -from datetime import date -from typing import Dict, Any, Tuple, List, Optional +from typing import Dict, Any, List, Optional from langgraph.prebuilt import ToolNode from tradingagents.llm_clients import create_llm_client -from tradingagents.agents import * from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.agents.utils.memory import FinancialSituationMemory -from tradingagents.agents.utils.agent_states import ( - AgentState, - InvestDebateState, - RiskDebateState, -) from tradingagents.dataflows.config import set_config # Import the new abstract tool methods from agent_utils @@ -30,7 +23,7 @@ from tradingagents.agents.utils.agent_utils import ( get_income_statement, get_news, get_insider_transactions, - get_global_news + get_global_news, ) from .conditional_logic import ConditionalLogic @@ -93,13 +86,17 @@ class TradingAgentsGraph: self.deep_thinking_llm = deep_client.get_llm() self.quick_thinking_llm = quick_client.get_llm() - + # Initialize memories self.bull_memory = FinancialSituationMemory("bull_memory", self.config) self.bear_memory = FinancialSituationMemory("bear_memory", self.config) self.trader_memory = FinancialSituationMemory("trader_memory", self.config) - self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config) - self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config) + self.invest_judge_memory = FinancialSituationMemory( + "invest_judge_memory", self.config + ) + self.risk_manager_memory = FinancialSituationMemory( + "risk_manager_memory", self.config + ) # Create tool nodes self.tool_nodes = self._create_tool_nodes() @@ -240,8 +237,12 @@ class TradingAgentsGraph: }, "trader_investment_decision": final_state["trader_investment_plan"], "risk_debate_state": { - "aggressive_history": final_state["risk_debate_state"]["aggressive_history"], - "conservative_history": final_state["risk_debate_state"]["conservative_history"], + "aggressive_history": final_state["risk_debate_state"][ + "aggressive_history" + ], + "conservative_history": final_state["risk_debate_state"][ + "conservative_history" + ], "neutral_history": final_state["risk_debate_state"]["neutral_history"], "history": final_state["risk_debate_state"]["history"], "judge_decision": final_state["risk_debate_state"]["judge_decision"],