feat: add type hints to function signatures across codebase

Added return type hints and parameter type hints to functions in:
- tradingagents/graph/trading_graph.py
- tradingagents/graph/reflection.py
- tradingagents/dataflows/y_finance.py
- tradingagents/dataflows/local.py
- tradingagents/backtesting/engine.py
- cli/analysis.py, cli/discovery.py, cli/display.py, cli/state.py, cli/backtest_cmd.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 03:50:49 -05:00
parent 293df9c552
commit b8be981bc8
10 changed files with 72 additions and 64 deletions

View File

@ -33,11 +33,11 @@ from cli.utils import (
)
def get_ticker():
def get_ticker() -> str:
return typer.prompt("", default="SPY")
def get_analysis_date():
def get_analysis_date() -> str:
while True:
date_str = typer.prompt(
"", default=datetime.datetime.now().strftime("%Y-%m-%d")
@ -54,7 +54,7 @@ def get_analysis_date():
)
def get_user_selections():
def get_user_selections() -> dict:
with open("./cli/static/welcome.txt", "r") as f:
welcome_ascii = f.read()
@ -135,7 +135,7 @@ def get_user_selections():
}
def process_chunk_for_display(chunk, selected_analysts: List[AnalystType]):
def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) -> None:
if "market_report" in chunk and chunk["market_report"]:
message_buffer.update_report_section("market_report", chunk["market_report"])
message_buffer.update_agent_status("Market Analyst", "completed")
@ -253,7 +253,7 @@ def process_chunk_for_display(chunk, selected_analysts: List[AnalystType]):
message_buffer.update_agent_status("Portfolio Manager", "completed")
def setup_logging_decorators(report_dir, log_file):
def setup_logging_decorators(report_dir, log_file) -> tuple:
def save_message_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
@ -292,7 +292,7 @@ def setup_logging_decorators(report_dir, log_file):
return save_message_decorator, save_tool_call_decorator, save_report_section_decorator
def run_analysis_for_ticker(ticker: str, config: dict):
def run_analysis_for_ticker(ticker: str, config: dict) -> None:
analysis_date = datetime.datetime.now().strftime("%Y-%m-%d")
console.print(
@ -330,7 +330,7 @@ def run_analysis_for_ticker(ticker: str, config: dict):
_run_analysis_with_config(ticker, analysis_date, selected_analysts, config)
def run_analysis():
def run_analysis() -> None:
selections = get_user_selections()
config = get_config()
@ -349,7 +349,7 @@ def run_analysis():
)
def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts: List[AnalystType], config: dict):
def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts: List[AnalystType], config: dict) -> None:
with loading("Initializing trading agents...", show_elapsed=True):
graph = TradingAgentsGraph(
[analyst.value for analyst in selected_analysts], config=config, debug=True

View File

@ -18,7 +18,7 @@ from cli.utils import loading
console = Console()
def sma_buy(ticker, trading_date, ctx):
def sma_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool:
loader = ctx["data_loader"]
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
if len(ohlcv.bars) < 20:
@ -29,7 +29,7 @@ def sma_buy(ticker, trading_date, ctx):
return current > sma * 1.02
def sma_sell(ticker, trading_date, ctx):
def sma_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool:
loader = ctx["data_loader"]
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
if len(ohlcv.bars) < 20:
@ -40,7 +40,7 @@ def sma_sell(ticker, trading_date, ctx):
return current < sma * 0.98
def rsi_buy(ticker, trading_date, ctx):
def rsi_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool:
loader = ctx["data_loader"]
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
if len(ohlcv.bars) < 15:
@ -57,7 +57,7 @@ def rsi_buy(ticker, trading_date, ctx):
return rsi < 30
def rsi_sell(ticker, trading_date, ctx):
def rsi_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool:
loader = ctx["data_loader"]
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
if len(ohlcv.bars) < 15:
@ -74,11 +74,11 @@ def rsi_sell(ticker, trading_date, ctx):
return rsi > 70
def hold_buy(ticker, trading_date, ctx):
def hold_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool:
return ctx.get("day_index", 0) == 5
def hold_sell(ticker, trading_date, ctx):
def hold_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool:
return False
@ -95,7 +95,7 @@ def run_backtest(
end_date: str = None,
initial_cash: float = 100000.0,
strategy: str = "sma",
):
) -> None:
if not ticker:
console.print(create_question_box("Ticker Symbol", "Enter the ticker symbol to backtest", "AAPL"))
ticker = typer.prompt("", default="AAPL")

View File

@ -248,7 +248,7 @@ def select_stock_for_detail(trending_stocks: List[TrendingStock]) -> Optional[Tr
return selected
def discover_trending_flow(run_analysis_callback=None):
def discover_trending_flow(run_analysis_callback=None) -> None:
console.print(Rule("[bold green]Discover Trending Stocks[/bold green]"))
console.print()

View File

@ -1,3 +1,5 @@
from typing import Optional, Dict, Any
from rich.console import Console
from rich.panel import Panel
from rich.spinner import Spinner
@ -13,7 +15,7 @@ from cli.state import message_buffer
console = Console()
def create_layout():
def create_layout() -> Layout:
layout = Layout()
layout.split_column(
Layout(name="header", size=3),
@ -29,7 +31,7 @@ def create_layout():
return layout
def update_display(layout, spinner_text=None):
def update_display(layout: Layout, spinner_text: Optional[str] = None) -> None:
layout["header"].update(
Panel(
"[bold green]Welcome to TradingAgents CLI[/bold green]\n"
@ -207,7 +209,7 @@ def update_display(layout, spinner_text=None):
layout["footer"].update(Panel(stats_table, border_style="grey50"))
def display_complete_report(final_state):
def display_complete_report(final_state: Dict[str, Any]) -> None:
console.print("\n[bold green]Complete Analysis Report[/bold green]\n")
analyst_reports = []
@ -381,13 +383,13 @@ def display_complete_report(final_state):
)
def update_research_team_status(status):
def update_research_team_status(status: str) -> None:
research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"]
for agent in research_team:
message_buffer.update_agent_status(agent, status)
def extract_content_string(content):
def extract_content_string(content: Any) -> str:
if isinstance(content, str):
return content
elif isinstance(content, list):
@ -405,7 +407,7 @@ def extract_content_string(content):
return str(content)
def create_question_box(title: str, prompt: str, default: str = None) -> Panel:
def create_question_box(title: str, prompt: str, default: Optional[str] = None) -> Panel:
box_content = f"[bold]{title}[/bold]\n"
box_content += f"[dim]{prompt}[/dim]"
if default:

View File

@ -1,12 +1,12 @@
import datetime
from collections import deque
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, Deque
class MessageBuffer:
def __init__(self, max_length=100):
self.messages = deque(maxlen=max_length)
self.tool_calls = deque(maxlen=max_length)
def __init__(self, max_length: int = 100) -> None:
self.messages: Deque = deque(maxlen=max_length)
self.tool_calls: Deque = deque(maxlen=max_length)
self.current_report = None
self.final_report = None
self.agent_status = {
@ -34,25 +34,25 @@ class MessageBuffer:
"final_trade_decision": None,
}
def add_message(self, message_type: str, content: str):
def add_message(self, message_type: str, content: str) -> None:
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.messages.append((timestamp, message_type, content))
def add_tool_call(self, tool_name: str, args: Dict[str, Any]):
def add_tool_call(self, tool_name: str, args: Dict[str, Any]) -> None:
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
self.tool_calls.append((timestamp, tool_name, args))
def update_agent_status(self, agent: str, status: str):
def update_agent_status(self, agent: str, status: str) -> None:
if agent in self.agent_status:
self.agent_status[agent] = status
self.current_agent = agent
def update_report_section(self, section_name: str, content: str):
def update_report_section(self, section_name: str, content: str) -> None:
if section_name in self.report_sections:
self.report_sections[section_name] = content
self._update_current_report()
def _update_current_report(self):
def _update_current_report(self) -> None:
latest_section = None
latest_content = None
@ -77,7 +77,7 @@ class MessageBuffer:
self._update_final_report()
def _update_final_report(self):
def _update_final_report(self) -> None:
report_parts = []
if any(
@ -121,7 +121,7 @@ class MessageBuffer:
self.final_report = "\n\n".join(report_parts) if report_parts else None
def reset(self):
def reset(self) -> None:
for agent in self.agent_status:
self.agent_status[agent] = "pending"
for section in self.report_sections:

View File

@ -88,7 +88,7 @@ class BacktestEngine:
error_message=str(e),
)
def _initialize(self):
def _initialize(self) -> None:
self.portfolio = PortfolioSnapshot(
cash=self.config.portfolio_config.initial_cash,
)
@ -98,7 +98,7 @@ class BacktestEngine:
self.decisions = []
self.open_trades = {}
def _preload_data(self):
def _preload_data(self) -> None:
logger.info("Preloading data for %s tickers", len(self.config.tickers))
for ticker in self.config.tickers:
self.data_loader.load_ohlcv(
@ -115,7 +115,7 @@ class BacktestEngine:
self.config.end_date,
)
def _process_day(self, trading_date: date, day_index: int):
def _process_day(self, trading_date: date, day_index: int) -> None:
prices = self.data_loader.get_prices_dict(self.config.tickers, trading_date)
if not prices:
@ -161,7 +161,7 @@ class BacktestEngine:
decision: TradingDecision,
price: Decimal,
trading_date: date,
):
) -> None:
ticker = decision.ticker
config = self.config.portfolio_config
position = self.portfolio.get_position(ticker)
@ -270,7 +270,7 @@ class BacktestEngine:
ticker, quantity, execution_price, trading_date
)
def _record_equity(self, trading_date: date, prices: dict[str, Decimal]):
def _record_equity(self, trading_date: date, prices: dict[str, Decimal]) -> None:
equity = self.portfolio.total_equity(prices)
positions_value = self.portfolio.positions_value(prices)
@ -288,7 +288,7 @@ class BacktestEngine:
daily_return = (equity - prev_equity) / prev_equity
self.daily_returns.append(daily_return)
def _close_all_positions(self, final_date: date):
def _close_all_positions(self, final_date: date) -> None:
prices = self.data_loader.get_prices_dict(self.config.tickers, final_date)
for ticker, trade in list(self.open_trades.items()):
@ -305,7 +305,7 @@ class BacktestEngine:
)
self._execute_decision(decision, prices[ticker], final_date)
def _empty_metrics(self):
def _empty_metrics(self) -> "BacktestMetrics":
from tradingagents.models.backtest import BacktestMetrics
return BacktestMetrics(
start_equity=self.config.portfolio_config.initial_cash,

View File

@ -78,7 +78,7 @@ def get_finnhub_news(
query: Annotated[str, "Search query or ticker symbol"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):
) -> str:
result = get_data_in_range(query, start_date, end_date, "news_data", DATA_DIR)
@ -101,7 +101,7 @@ def get_finnhub_news(
def get_finnhub_company_insider_sentiment(
ticker: Annotated[str, "ticker symbol for the company"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
) -> str:
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=15)
@ -130,7 +130,7 @@ def get_finnhub_company_insider_sentiment(
def get_finnhub_company_insider_transactions(
ticker: Annotated[str, "ticker symbol"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
) -> str:
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=15)
@ -156,7 +156,14 @@ def get_finnhub_company_insider_transactions(
+ "The change field reflects the variation in share count—here a negative number indicates a reduction in holdings—while share specifies the total number of shares involved. The transactionPrice denotes the per-share price at which the trade was executed, and transactionDate marks when the transaction occurred. The name field identifies the insider making the trade, and transactionCode (e.g., S for sale) clarifies the nature of the transaction. FilingDate records when the transaction was officially reported, and the unique id links to the specific SEC filing, as indicated by the source. Additionally, the symbol ties the transaction to a particular company, isDerivative flags whether the trade involves derivative securities, and currency notes the currency context of the transaction."
)
def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None):
def get_data_in_range(
ticker: str,
start_date: str,
end_date: str,
data_type: str,
data_dir: str,
period: str = None,
) -> dict:
if period:
data_path = os.path.join(
@ -186,7 +193,7 @@ def get_simfin_balance_sheet(
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
) -> str:
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
@ -227,7 +234,7 @@ def get_simfin_cashflow(
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
) -> str:
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
@ -268,7 +275,7 @@ def get_simfin_income_statements(
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
) -> str:
data_path = os.path.join(
DATA_DIR,
"fundamental_data",

View File

@ -13,7 +13,7 @@ def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):
) -> str:
symbol = validate_ticker(symbol)
start, end = validate_date_range(start_date, end_date)
start_date = start.strftime("%Y-%m-%d")
@ -248,7 +248,6 @@ def get_stockstats_indicator(
str, "The current trading date you are trading on, YYYY-mm-dd"
],
) -> str:
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
curr_date = curr_date_dt.strftime("%Y-%m-%d")
@ -272,7 +271,7 @@ def get_balance_sheet(
ticker: Annotated[str, "ticker symbol of the company"],
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
):
) -> str:
try:
ticker_obj = yf.Ticker(ticker.upper())
@ -299,7 +298,7 @@ def get_cashflow(
ticker: Annotated[str, "ticker symbol of the company"],
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
):
) -> str:
try:
ticker_obj = yf.Ticker(ticker.upper())
@ -326,7 +325,7 @@ def get_income_statement(
ticker: Annotated[str, "ticker symbol of the company"],
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
):
) -> str:
try:
ticker_obj = yf.Ticker(ticker.upper())
@ -351,7 +350,7 @@ def get_income_statement(
def get_insider_transactions(
ticker: Annotated[str, "ticker symbol of the company"]
):
) -> str:
try:
ticker_obj = yf.Ticker(ticker.upper())
data = ticker_obj.insider_transactions

View File

@ -56,7 +56,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
def _reflect_on_component(
self, component_type: str, report: str, situation: str, returns_losses
self, component_type: str, report: str, situation: str, returns_losses: Any
) -> str:
"""Generate reflection for a component."""
messages = [
@ -70,7 +70,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
result = self.quick_thinking_llm.invoke(messages).content
return result
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
def reflect_bull_researcher(self, current_state: Dict[str, Any], returns_losses: Any, bull_memory) -> None:
"""Reflect on bull researcher's analysis and update memory."""
situation = self._extract_current_situation(current_state)
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
@ -80,7 +80,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
)
bull_memory.add_situations([(situation, result)])
def reflect_bear_researcher(self, current_state, returns_losses, bear_memory):
def reflect_bear_researcher(self, current_state: Dict[str, Any], returns_losses: Any, bear_memory) -> None:
"""Reflect on bear researcher's analysis and update memory."""
situation = self._extract_current_situation(current_state)
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
@ -90,7 +90,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
)
bear_memory.add_situations([(situation, result)])
def reflect_trader(self, current_state, returns_losses, trader_memory):
def reflect_trader(self, current_state: Dict[str, Any], returns_losses: Any, trader_memory) -> None:
"""Reflect on trader's decision and update memory."""
situation = self._extract_current_situation(current_state)
trader_decision = current_state["trader_investment_plan"]
@ -100,7 +100,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
)
trader_memory.add_situations([(situation, result)])
def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory):
def reflect_invest_judge(self, current_state: Dict[str, Any], returns_losses: Any, invest_judge_memory) -> None:
"""Reflect on investment judge's decision and update memory."""
situation = self._extract_current_situation(current_state)
judge_decision = current_state["investment_debate_state"]["judge_decision"]
@ -110,7 +110,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
)
invest_judge_memory.add_situations([(situation, result)])
def reflect_risk_manager(self, current_state, returns_losses, risk_manager_memory):
def reflect_risk_manager(self, current_state: Dict[str, Any], returns_losses: Any, risk_manager_memory) -> None:
"""Reflect on risk manager's decision and update memory."""
situation = self._extract_current_situation(current_state)
judge_decision = current_state["risk_debate_state"]["judge_decision"]

View File

@ -63,7 +63,7 @@ class DiscoveryTimeoutException(Exception):
pass
def _timeout_handler(signum, frame):
def _timeout_handler(signum, frame) -> None:
raise DiscoveryTimeoutException("Discovery operation timed out")
@ -155,7 +155,7 @@ class TradingAgentsGraph:
),
}
def propagate(self, company_name, trade_date):
def propagate(self, company_name: str, trade_date) -> Tuple[Dict[str, Any], str]:
company_name = validate_ticker(company_name)
validated_date = validate_date(trade_date, allow_future=False)
if isinstance(trade_date, str):
@ -185,7 +185,7 @@ class TradingAgentsGraph:
return final_state, self.process_signal(final_state["final_trade_decision"])
def _log_state(self, trade_date, final_state):
def _log_state(self, trade_date, final_state: Dict[str, Any]) -> None:
self.log_states_dict[str(trade_date)] = {
"company_of_interest": final_state["company_of_interest"],
"trade_date": final_state["trade_date"],
@ -225,7 +225,7 @@ class TradingAgentsGraph:
) as f:
json.dump(self.log_states_dict, f, indent=4)
def reflect_and_remember(self, returns_losses):
def reflect_and_remember(self, returns_losses) -> None:
self.reflector.reflect_bull_researcher(
self.curr_state, returns_losses, self.bull_memory
)
@ -242,7 +242,7 @@ class TradingAgentsGraph:
self.curr_state, returns_losses, self.risk_manager_memory
)
def process_signal(self, full_signal):
def process_signal(self, full_signal: str) -> str:
return self.signal_processor.process_signal(full_signal)
def discover_trending(