style: correções automáticas pre-commit

This commit is contained in:
Bruno Natalicio 2026-03-11 12:51:35 -03:00
parent 020f51a146
commit cefe0a12b2
8 changed files with 434 additions and 199 deletions

37
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,37 @@
name: Python CI
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
jobs:
build-and-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v2
with:
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version-file: ".python-version"
- name: Install the project
run: uv sync --all-extras --dev
- name: Format with Black
run: uv run black --check .
- name: Lint with Ruff
run: uv run ruff check .
# - name: Run tests (Uncomment when tests exist)
# run: uv run pytest

20
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,20 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
language_version: python3.13
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: ruff
args: [ --fix ]

View File

@ -1,34 +1,39 @@
from typing import Optional
import datetime
import typer
import time
from pathlib import Path
from functools import wraps
from rich.console import Console
from collections import deque
import typer
from dotenv import load_dotenv
from rich import box
from rich.align import Align
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.rule import Rule
from rich.spinner import Spinner
from rich.table import Table
from rich.text import Text
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
from cli.announcements import display_announcements, fetch_announcements
from cli.stats_handler import StatsCallbackHandler
from cli.utils import (
ask_gemini_thinking_config,
ask_openai_reasoning_effort,
select_analysts,
select_deep_thinking_agent,
select_llm_provider,
select_research_depth,
select_shallow_thinking_agent,
)
# Load environment variables from .env file
load_dotenv()
from rich.panel import Panel
from rich.spinner import Spinner
from rich.live import Live
from rich.columns import Columns
from rich.markdown import Markdown
from rich.layout import Layout
from rich.text import Text
from rich.table import Table
from collections import deque
import time
from rich.tree import Tree
from rich import box
from rich.align import Align
from rich.rule import Rule
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.models import AnalystType
from cli.utils import *
from cli.announcements import fetch_announcements, display_announcements
from cli.stats_handler import StatsCallbackHandler
console = Console()
@ -45,7 +50,11 @@ class MessageBuffer:
FIXED_AGENTS = {
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
"Trading Team": ["Trader"],
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
"Risk Management": [
"Aggressive Analyst",
"Neutral Analyst",
"Conservative Analyst",
],
"Portfolio Management": ["Portfolio Manager"],
}
@ -165,7 +174,7 @@ class MessageBuffer:
if content is not None:
latest_section = section
latest_content = content
if latest_section and latest_content:
# Format the current section for display
section_titles = {
@ -188,7 +197,12 @@ class MessageBuffer:
report_parts = []
# Analyst Team Reports - use .get() to handle missing sections
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"]
analyst_sections = [
"market_report",
"sentiment_report",
"news_report",
"fundamentals_report",
]
if any(self.report_sections.get(section) for section in analyst_sections):
report_parts.append("## Analyst Team Reports")
if self.report_sections.get("market_report"):
@ -289,7 +303,11 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
],
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
"Trading Team": ["Trader"],
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
"Risk Management": [
"Aggressive Analyst",
"Neutral Analyst",
"Conservative Analyst",
],
"Portfolio Management": ["Portfolio Manager"],
}
@ -538,12 +556,10 @@ def get_user_selections():
# Step 5: OpenAI backend
console.print(
create_question_box(
"Step 5: OpenAI backend", "Select which service to talk to"
)
create_question_box("Step 5: OpenAI backend", "Select which service to talk to")
)
selected_llm_provider, backend_url = select_llm_provider()
# Step 6: Thinking agents
console.print(
create_question_box(
@ -561,16 +577,14 @@ def get_user_selections():
if provider_lower == "google":
console.print(
create_question_box(
"Step 7: Thinking Mode",
"Configure Gemini thinking mode"
"Step 7: Thinking Mode", "Configure Gemini thinking mode"
)
)
thinking_level = ask_gemini_thinking_config()
elif provider_lower == "openai":
console.print(
create_question_box(
"Step 7: Reasoning Effort",
"Configure OpenAI reasoning effort level"
"Step 7: Reasoning Effort", "Configure OpenAI reasoning effort level"
)
)
reasoning_effort = ask_openai_reasoning_effort()
@ -635,8 +649,12 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
analyst_parts.append(("News Analyst", final_state["news_report"]))
if final_state.get("fundamentals_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"])
analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
(analysts_dir / "fundamentals.md").write_text(
final_state["fundamentals_report"]
)
analyst_parts.append(
("Fundamentals Analyst", final_state["fundamentals_report"])
)
if analyst_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts)
sections.append(f"## I. Analyst Team Reports\n\n{content}")
@ -659,7 +677,9 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
(research_dir / "manager.md").write_text(debate["judge_decision"])
research_parts.append(("Research Manager", debate["judge_decision"]))
if research_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts)
content = "\n\n".join(
f"### {name}\n{text}" for name, text in research_parts
)
sections.append(f"## II. Research Team Decision\n\n{content}")
# 3. Trading
@ -667,7 +687,9 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
trading_dir = save_path / "3_trading"
trading_dir.mkdir(exist_ok=True)
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"])
sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}")
sections.append(
f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}"
)
# 4. Risk Management
if final_state.get("risk_debate_state"):
@ -695,7 +717,9 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
portfolio_dir = save_path / "5_portfolio"
portfolio_dir.mkdir(exist_ok=True)
(portfolio_dir / "decision.md").write_text(risk["judge_decision"])
sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}")
sections.append(
f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}"
)
# Write consolidated report
header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
@ -719,9 +743,15 @@ def display_complete_report(final_state):
if final_state.get("fundamentals_report"):
analysts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
if analysts:
console.print(Panel("[bold]I. Analyst Team Reports[/bold]", border_style="cyan"))
console.print(
Panel("[bold]I. Analyst Team Reports[/bold]", border_style="cyan")
)
for title, content in analysts:
console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2)))
console.print(
Panel(
Markdown(content), title=title, border_style="blue", padding=(1, 2)
)
)
# II. Research Team Reports
if final_state.get("investment_debate_state"):
@ -734,14 +764,32 @@ def display_complete_report(final_state):
if debate.get("judge_decision"):
research.append(("Research Manager", debate["judge_decision"]))
if research:
console.print(Panel("[bold]II. Research Team Decision[/bold]", border_style="magenta"))
console.print(
Panel("[bold]II. Research Team Decision[/bold]", border_style="magenta")
)
for title, content in research:
console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2)))
console.print(
Panel(
Markdown(content),
title=title,
border_style="blue",
padding=(1, 2),
)
)
# III. Trading Team
if final_state.get("trader_investment_plan"):
console.print(Panel("[bold]III. Trading Team Plan[/bold]", border_style="yellow"))
console.print(Panel(Markdown(final_state["trader_investment_plan"]), title="Trader", border_style="blue", padding=(1, 2)))
console.print(
Panel("[bold]III. Trading Team Plan[/bold]", border_style="yellow")
)
console.print(
Panel(
Markdown(final_state["trader_investment_plan"]),
title="Trader",
border_style="blue",
padding=(1, 2),
)
)
# IV. Risk Management Team
if final_state.get("risk_debate_state"):
@ -754,14 +802,36 @@ def display_complete_report(final_state):
if risk.get("neutral_history"):
risk_reports.append(("Neutral Analyst", risk["neutral_history"]))
if risk_reports:
console.print(Panel("[bold]IV. Risk Management Team Decision[/bold]", border_style="red"))
console.print(
Panel(
"[bold]IV. Risk Management Team Decision[/bold]", border_style="red"
)
)
for title, content in risk_reports:
console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2)))
console.print(
Panel(
Markdown(content),
title=title,
border_style="blue",
padding=(1, 2),
)
)
# V. Portfolio Manager Decision
if risk.get("judge_decision"):
console.print(Panel("[bold]V. Portfolio Manager Decision[/bold]", border_style="green"))
console.print(Panel(Markdown(risk["judge_decision"]), title="Portfolio Manager", border_style="blue", padding=(1, 2)))
console.print(
Panel(
"[bold]V. Portfolio Manager Decision[/bold]", border_style="green"
)
)
console.print(
Panel(
Markdown(risk["judge_decision"]),
title="Portfolio Manager",
border_style="blue",
padding=(1, 2),
)
)
def update_research_team_status(status):
@ -821,6 +891,7 @@ def update_analyst_statuses(message_buffer, chunk):
if message_buffer.agent_status.get("Bull Researcher") == "pending":
message_buffer.update_agent_status("Bull Researcher", "in_progress")
def extract_content_string(content):
"""Extract string content from various message formats.
Returns None if no meaningful text content is found.
@ -829,7 +900,7 @@ def extract_content_string(content):
def is_empty(val):
"""Check if value is empty using Python's truthiness."""
if val is None or val == '':
if val is None or val == "":
return True
if isinstance(val, str):
s = val.strip()
@ -848,16 +919,19 @@ def extract_content_string(content):
return content.strip()
if isinstance(content, dict):
text = content.get('text', '')
text = content.get("text", "")
return text.strip() if not is_empty(text) else None
if isinstance(content, list):
text_parts = [
item.get('text', '').strip() if isinstance(item, dict) and item.get('type') == 'text'
else (item.strip() if isinstance(item, str) else '')
(
item.get("text", "").strip()
if isinstance(item, dict) and item.get("type") == "text"
else (item.strip() if isinstance(item, str) else "")
)
for item in content
]
result = ' '.join(t for t in text_parts if t and not is_empty(t))
result = " ".join(t for t in text_parts if t and not is_empty(t))
return result if result else None
return str(content).strip() if not is_empty(content) else None
@ -872,7 +946,7 @@ def classify_message_type(message) -> tuple[str, str | None]:
"""
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
content = extract_content_string(getattr(message, 'content', None))
content = extract_content_string(getattr(message, "content", None))
if isinstance(message, HumanMessage):
if content and content.strip() == "Continue":
@ -893,9 +967,10 @@ def format_tool_args(args, max_length=80) -> str:
"""Format tool arguments for terminal display."""
result = str(args)
if len(result) > max_length:
return result[:max_length - 3] + "..."
return result[: max_length - 3] + "..."
return result
def run_analysis():
# First get all user selections
selections = get_user_selections()
@ -934,7 +1009,9 @@ def run_analysis():
start_time = time.time()
# Create result directory
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
results_dir = (
Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
)
results_dir.mkdir(parents=True, exist_ok=True)
report_dir = results_dir / "reports"
report_dir.mkdir(parents=True, exist_ok=True)
@ -943,6 +1020,7 @@ def run_analysis():
def save_message_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(*args, **kwargs):
func(*args, **kwargs)
@ -950,10 +1028,12 @@ def run_analysis():
content = content.replace("\n", " ") # Replace newlines with spaces
with open(log_file, "a") as f:
f.write(f"{timestamp} [{message_type}] {content}\n")
return wrapper
def save_tool_call_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(*args, **kwargs):
func(*args, **kwargs)
@ -961,29 +1041,39 @@ def run_analysis():
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
with open(log_file, "a") as f:
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
return wrapper
def save_report_section_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(section_name, content):
func(section_name, content)
if section_name in obj.report_sections and obj.report_sections[section_name] is not None:
if (
section_name in obj.report_sections
and obj.report_sections[section_name] is not None
):
content = obj.report_sections[section_name]
if content:
file_name = f"{section_name}.md"
with open(report_dir / file_name, "w") as f:
f.write(content)
return wrapper
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
message_buffer.add_tool_call = save_tool_call_decorator(
message_buffer, "add_tool_call"
)
message_buffer.update_report_section = save_report_section_decorator(
message_buffer, "update_report_section"
)
# Now start the display layout
layout = create_layout()
with Live(layout, refresh_per_second=4) as live:
with Live(layout, refresh_per_second=4):
# Initial display
update_display(layout, stats_handler=stats_handler, start_time=start_time)
@ -1007,7 +1097,9 @@ def run_analysis():
spinner_text = (
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
)
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
update_display(
layout, spinner_text, stats_handler=stats_handler, start_time=start_time
)
# Initialize state and get graph args with callbacks
init_agent_state = graph.propagator.create_initial_state(
@ -1041,7 +1133,9 @@ def run_analysis():
tool_call["name"], tool_call["args"]
)
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
message_buffer.add_tool_call(
tool_call.name, tool_call.args
)
# Update analyst statuses based on report state (runs on every chunk)
update_analyst_statuses(message_buffer, chunk)
@ -1078,7 +1172,9 @@ def run_analysis():
)
if message_buffer.agent_status.get("Trader") != "completed":
message_buffer.update_agent_status("Trader", "completed")
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
message_buffer.update_agent_status(
"Aggressive Analyst", "in_progress"
)
# Risk Management Team - Handle Risk Debate State
if chunk.get("risk_debate_state"):
@ -1089,33 +1185,65 @@ def run_analysis():
judge = risk_state.get("judge_decision", "").strip()
if agg_hist:
if message_buffer.agent_status.get("Aggressive Analyst") != "completed":
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
if (
message_buffer.agent_status.get("Aggressive Analyst")
!= "completed"
):
message_buffer.update_agent_status(
"Aggressive Analyst", "in_progress"
)
message_buffer.update_report_section(
"final_trade_decision", f"### Aggressive Analyst Analysis\n{agg_hist}"
"final_trade_decision",
f"### Aggressive Analyst Analysis\n{agg_hist}",
)
if con_hist:
if message_buffer.agent_status.get("Conservative Analyst") != "completed":
message_buffer.update_agent_status("Conservative Analyst", "in_progress")
if (
message_buffer.agent_status.get("Conservative Analyst")
!= "completed"
):
message_buffer.update_agent_status(
"Conservative Analyst", "in_progress"
)
message_buffer.update_report_section(
"final_trade_decision", f"### Conservative Analyst Analysis\n{con_hist}"
"final_trade_decision",
f"### Conservative Analyst Analysis\n{con_hist}",
)
if neu_hist:
if message_buffer.agent_status.get("Neutral Analyst") != "completed":
message_buffer.update_agent_status("Neutral Analyst", "in_progress")
if (
message_buffer.agent_status.get("Neutral Analyst")
!= "completed"
):
message_buffer.update_agent_status(
"Neutral Analyst", "in_progress"
)
message_buffer.update_report_section(
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
"final_trade_decision",
f"### Neutral Analyst Analysis\n{neu_hist}",
)
if judge:
if message_buffer.agent_status.get("Portfolio Manager") != "completed":
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
message_buffer.update_report_section(
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
if (
message_buffer.agent_status.get("Portfolio Manager")
!= "completed"
):
message_buffer.update_agent_status(
"Portfolio Manager", "in_progress"
)
message_buffer.update_report_section(
"final_trade_decision",
f"### Portfolio Manager Decision\n{judge}",
)
message_buffer.update_agent_status(
"Aggressive Analyst", "completed"
)
message_buffer.update_agent_status(
"Conservative Analyst", "completed"
)
message_buffer.update_agent_status(
"Neutral Analyst", "completed"
)
message_buffer.update_agent_status(
"Portfolio Manager", "completed"
)
message_buffer.update_agent_status("Aggressive Analyst", "completed")
message_buffer.update_agent_status("Conservative Analyst", "completed")
message_buffer.update_agent_status("Neutral Analyst", "completed")
message_buffer.update_agent_status("Portfolio Manager", "completed")
# Update the display
update_display(layout, stats_handler=stats_handler, start_time=start_time)
@ -1124,7 +1252,7 @@ def run_analysis():
# Get final state and decision
final_state = trace[-1]
decision = graph.process_signal(final_state["final_trade_decision"])
graph.process_signal(final_state["final_trade_decision"])
# Update all agent statuses to completed
for agent in message_buffer.agent_status:
@ -1150,19 +1278,22 @@ def run_analysis():
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
default_path = Path.cwd() / "reports" / f"{selections['ticker']}_{timestamp}"
save_path_str = typer.prompt(
"Save path (press Enter for default)",
default=str(default_path)
"Save path (press Enter for default)", default=str(default_path)
).strip()
save_path = Path(save_path_str)
try:
report_file = save_report_to_disk(final_state, selections["ticker"], save_path)
report_file = save_report_to_disk(
final_state, selections["ticker"], save_path
)
console.print(f"\n[green]✓ Report saved to:[/green] {save_path.resolve()}")
console.print(f" [dim]Complete report:[/dim] {report_file.name}")
except Exception as e:
console.print(f"[red]Error saving report: {e}[/red]")
# Prompt to display full report
display_choice = typer.prompt("\nDisplay full report on screen?", default="Y").strip().upper()
display_choice = (
typer.prompt("\nDisplay full report on screen?", default="Y").strip().upper()
)
if display_choice in ("Y", "YES", ""):
display_complete_report(final_state)

View File

@ -1,8 +1,10 @@
import questionary
from typing import List, Optional, Tuple, Dict
from typing import List
from rich.console import Console
from cli.models import AnalystType
console = Console()
ANALYST_ORDER = [
("Market Analyst", AnalystType.MARKET),
("Social Media Analyst", AnalystType.SOCIAL),
@ -146,13 +148,25 @@ def select_shallow_thinking_agent(provider) -> str:
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
],
"xai": [
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
(
"Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx",
"grok-4-1-fast-non-reasoning",
),
(
"Grok 4 Fast (Non-Reasoning) - Speed optimized",
"grok-4-fast-non-reasoning",
),
(
"Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx",
"grok-4-1-fast-reasoning",
),
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
],
"openrouter": [
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
(
"NVIDIA Nemotron 3 Nano 30B (free)",
"nvidia/nemotron-3-nano-30b-a3b:free",
),
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
],
"ollama": [
@ -213,15 +227,27 @@ def select_deep_thinking_agent(provider) -> str:
("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"),
],
"xai": [
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
(
"Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx",
"grok-4-1-fast-reasoning",
),
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
("Grok 4 - Flagship model", "grok-4-0709"),
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
(
"Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx",
"grok-4-1-fast-non-reasoning",
),
(
"Grok 4 Fast (Non-Reasoning) - Speed optimized",
"grok-4-fast-non-reasoning",
),
],
"openrouter": [
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
(
"NVIDIA Nemotron 3 Nano 30B (free)",
"nvidia/nemotron-3-nano-30b-a3b:free",
),
],
"ollama": [
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
@ -252,6 +278,7 @@ def select_deep_thinking_agent(provider) -> str:
return choice
def select_llm_provider() -> tuple[str, str]:
"""Select the OpenAI api url using interactive selection."""
# Define OpenAI api options with their corresponding endpoints
@ -263,7 +290,7 @@ def select_llm_provider() -> tuple[str, str]:
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
]
choice = questionary.select(
"Select your LLM Provider:",
choices=[
@ -279,11 +306,11 @@ def select_llm_provider() -> tuple[str, str]:
]
),
).ask()
if choice is None:
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
exit(1)
display_name, url = choice
print(f"You selected: {display_name}\tURL: {url}")
@ -300,11 +327,13 @@ def ask_openai_reasoning_effort() -> str:
return questionary.select(
"Select Reasoning Effort:",
choices=choices,
style=questionary.Style([
("selected", "fg:cyan noinherit"),
("highlighted", "fg:cyan noinherit"),
("pointer", "fg:cyan noinherit"),
]),
style=questionary.Style(
[
("selected", "fg:cyan noinherit"),
("highlighted", "fg:cyan noinherit"),
("pointer", "fg:cyan noinherit"),
]
),
).ask()
@ -320,9 +349,11 @@ def ask_gemini_thinking_config() -> str | None:
questionary.Choice("Enable Thinking (recommended)", "high"),
questionary.Choice("Minimal/Disable Thinking", "minimal"),
],
style=questionary.Style([
("selected", "fg:green noinherit"),
("highlighted", "fg:green noinherit"),
("pointer", "fg:green noinherit"),
]),
style=questionary.Style(
[
("selected", "fg:green noinherit"),
("highlighted", "fg:green noinherit"),
("pointer", "fg:green noinherit"),
]
),
).ask()

View File

@ -1,10 +1,6 @@
from typing import Annotated, Sequence
from datetime import date, timedelta, datetime
from typing_extensions import TypedDict, Optional
from langchain_openai import ChatOpenAI
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import MessagesState
# Researcher team state

View File

@ -2,9 +2,9 @@ from typing import Annotated
from datetime import datetime
from dateutil.relativedelta import relativedelta
import yfinance as yf
import os
from .stockstats_utils import StockstatsUtils
def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
@ -46,6 +46,7 @@ def get_YFin_data_online(
return header + csv_string
def get_stock_stats_indicators_window(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
@ -140,28 +141,28 @@ def get_stock_stats_indicators_window(
# Optimized: Get stock data once and calculate indicators for all dates
try:
indicator_data = _get_stock_stats_bulk(symbol, indicator, curr_date)
# Generate the date range we need
current_dt = curr_date_dt
date_values = []
while current_dt >= before:
date_str = current_dt.strftime('%Y-%m-%d')
date_str = current_dt.strftime("%Y-%m-%d")
# Look up the indicator value for this date
if date_str in indicator_data:
indicator_value = indicator_data[date_str]
else:
indicator_value = "N/A: Not a trading day (weekend or holiday)"
date_values.append((date_str, indicator_value))
current_dt = current_dt - relativedelta(days=1)
# Build the result string
ind_string = ""
for date_str, value in date_values:
ind_string += f"{date_str}: {value}\n"
except Exception as e:
print(f"Error getting bulk stockstats data: {e}")
# Fallback to original implementation if bulk method fails
@ -187,7 +188,7 @@ def get_stock_stats_indicators_window(
def _get_stock_stats_bulk(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to calculate"],
curr_date: Annotated[str, "current date for reference"]
curr_date: Annotated[str, "current date for reference"],
) -> dict:
"""
Optimized bulk calculation of stock stats indicators.
@ -195,13 +196,13 @@ def _get_stock_stats_bulk(
Returns dict mapping date strings to indicator values.
"""
from .config import get_config
import os
import pandas as pd
from stockstats import wrap
import os
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
if not online:
# Local data path
try:
@ -217,20 +218,20 @@ def _get_stock_stats_bulk(
else:
# Online data fetching with caching
today_date = pd.Timestamp.today()
curr_date_dt = pd.to_datetime(curr_date)
pd.to_datetime(curr_date)
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
data["Date"] = pd.to_datetime(data["Date"])
@ -245,25 +246,25 @@ def _get_stock_stats_bulk(
)
data = data.reset_index()
data.to_csv(data_file, index=False)
df = wrap(data)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
# Calculate the indicator for all rows at once
df[indicator] # This triggers stockstats to calculate the indicator
# Create a dictionary mapping date strings to indicator values
result_dict = {}
for _, row in df.iterrows():
date_str = row["Date"]
indicator_value = row[indicator]
# Handle NaN/None values
if pd.isna(indicator_value):
result_dict[date_str] = "N/A"
else:
result_dict[date_str] = str(indicator_value)
return result_dict
@ -295,7 +296,7 @@ def get_stockstats_indicator(
def get_fundamentals(
ticker: Annotated[str, "ticker symbol of the company"],
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
curr_date: Annotated[str, "current date (not used for yfinance)"] = None,
):
"""Get company fundamentals overview from yfinance."""
try:
@ -342,7 +343,9 @@ def get_fundamentals(
lines.append(f"{label}: {value}")
header = f"# Company Fundamentals for {ticker.upper()}\n"
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
header += (
f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
)
return header + "\n".join(lines)
@ -353,29 +356,31 @@ def get_fundamentals(
def get_balance_sheet(
ticker: Annotated[str, "ticker symbol of the company"],
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
curr_date: Annotated[str, "current date (not used for yfinance)"] = None,
):
"""Get balance sheet data from yfinance."""
try:
ticker_obj = yf.Ticker(ticker.upper())
if freq.lower() == "quarterly":
data = ticker_obj.quarterly_balance_sheet
else:
data = ticker_obj.balance_sheet
if data.empty:
return f"No balance sheet data found for symbol '{ticker}'"
# Convert to CSV string for consistency with other functions
csv_string = data.to_csv()
# Add header information
header = f"# Balance Sheet data for {ticker.upper()} ({freq})\n"
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
header += (
f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
)
return header + csv_string
except Exception as e:
return f"Error retrieving balance sheet for {ticker}: {str(e)}"
@ -383,29 +388,31 @@ def get_balance_sheet(
def get_cashflow(
ticker: Annotated[str, "ticker symbol of the company"],
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
curr_date: Annotated[str, "current date (not used for yfinance)"] = None,
):
"""Get cash flow data from yfinance."""
try:
ticker_obj = yf.Ticker(ticker.upper())
if freq.lower() == "quarterly":
data = ticker_obj.quarterly_cashflow
else:
data = ticker_obj.cashflow
if data.empty:
return f"No cash flow data found for symbol '{ticker}'"
# Convert to CSV string for consistency with other functions
csv_string = data.to_csv()
# Add header information
header = f"# Cash Flow data for {ticker.upper()} ({freq})\n"
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
header += (
f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
)
return header + csv_string
except Exception as e:
return f"Error retrieving cash flow for {ticker}: {str(e)}"
@ -413,52 +420,54 @@ def get_cashflow(
def get_income_statement(
ticker: Annotated[str, "ticker symbol of the company"],
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
curr_date: Annotated[str, "current date (not used for yfinance)"] = None,
):
"""Get income statement data from yfinance."""
try:
ticker_obj = yf.Ticker(ticker.upper())
if freq.lower() == "quarterly":
data = ticker_obj.quarterly_income_stmt
else:
data = ticker_obj.income_stmt
if data.empty:
return f"No income statement data found for symbol '{ticker}'"
# Convert to CSV string for consistency with other functions
csv_string = data.to_csv()
# Add header information
header = f"# Income Statement data for {ticker.upper()} ({freq})\n"
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
header += (
f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
)
return header + csv_string
except Exception as e:
return f"Error retrieving income statement for {ticker}: {str(e)}"
def get_insider_transactions(
ticker: Annotated[str, "ticker symbol of the company"]
):
def get_insider_transactions(ticker: Annotated[str, "ticker symbol of the company"]):
"""Get insider transactions data from yfinance."""
try:
ticker_obj = yf.Ticker(ticker.upper())
data = ticker_obj.insider_transactions
if data is None or data.empty:
return f"No insider transactions data found for symbol '{ticker}'"
# Convert to CSV string for consistency with other functions
csv_string = data.to_csv()
# Add header information
header = f"# Insider Transactions data for {ticker.upper()}\n"
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
header += (
f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
)
return header + csv_string
except Exception as e:
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
return f"Error retrieving insider transactions for {ticker}: {str(e)}"

View File

@ -1,11 +1,25 @@
# TradingAgents/graph/setup.py
from typing import Dict, Any
from typing import Dict
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.agents import (
create_msg_delete,
create_market_analyst,
create_social_media_analyst,
create_fundamentals_analyst,
create_news_analyst,
create_bull_researcher,
create_bear_researcher,
create_research_manager,
create_trader,
create_aggressive_debator,
create_conservative_debator,
create_neutral_debator,
create_risk_manager,
)
from tradingagents.agents.utils.agent_states import AgentState
from .conditional_logic import ConditionalLogic
@ -58,9 +72,7 @@ class GraphSetup:
tool_nodes = {}
if "market" in selected_analysts:
analyst_nodes["market"] = create_market_analyst(
self.quick_thinking_llm
)
analyst_nodes["market"] = create_market_analyst(self.quick_thinking_llm)
delete_nodes["market"] = create_msg_delete()
tool_nodes["market"] = self.tool_nodes["market"]
@ -72,9 +84,7 @@ class GraphSetup:
tool_nodes["social"] = self.tool_nodes["social"]
if "news" in selected_analysts:
analyst_nodes["news"] = create_news_analyst(
self.quick_thinking_llm
)
analyst_nodes["news"] = create_news_analyst(self.quick_thinking_llm)
delete_nodes["news"] = create_msg_delete()
tool_nodes["news"] = self.tool_nodes["news"]

View File

@ -3,21 +3,14 @@
import os
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
from typing import Dict, Any, List, Optional
from langgraph.prebuilt import ToolNode
from tradingagents.llm_clients import create_llm_client
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
RiskDebateState,
)
from tradingagents.dataflows.config import set_config
# Import the new abstract tool methods from agent_utils
@ -30,7 +23,7 @@ from tradingagents.agents.utils.agent_utils import (
get_income_statement,
get_news,
get_insider_transactions,
get_global_news
get_global_news,
)
from .conditional_logic import ConditionalLogic
@ -93,13 +86,17 @@ class TradingAgentsGraph:
self.deep_thinking_llm = deep_client.get_llm()
self.quick_thinking_llm = quick_client.get_llm()
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory(
"invest_judge_memory", self.config
)
self.risk_manager_memory = FinancialSituationMemory(
"risk_manager_memory", self.config
)
# Create tool nodes
self.tool_nodes = self._create_tool_nodes()
@ -240,8 +237,12 @@ class TradingAgentsGraph:
},
"trader_investment_decision": final_state["trader_investment_plan"],
"risk_debate_state": {
"aggressive_history": final_state["risk_debate_state"]["aggressive_history"],
"conservative_history": final_state["risk_debate_state"]["conservative_history"],
"aggressive_history": final_state["risk_debate_state"][
"aggressive_history"
],
"conservative_history": final_state["risk_debate_state"][
"conservative_history"
],
"neutral_history": final_state["risk_debate_state"]["neutral_history"],
"history": final_state["risk_debate_state"]["history"],
"judge_decision": final_state["risk_debate_state"]["judge_decision"],