diff --git a/cli/analysis.py b/cli/analysis.py index 75ddde4a..56829a3d 100644 --- a/cli/analysis.py +++ b/cli/analysis.py @@ -33,11 +33,11 @@ from cli.utils import ( ) -def get_ticker(): +def get_ticker() -> str: return typer.prompt("", default="SPY") -def get_analysis_date(): +def get_analysis_date() -> str: while True: date_str = typer.prompt( "", default=datetime.datetime.now().strftime("%Y-%m-%d") @@ -54,7 +54,7 @@ def get_analysis_date(): ) -def get_user_selections(): +def get_user_selections() -> dict: with open("./cli/static/welcome.txt", "r") as f: welcome_ascii = f.read() @@ -135,7 +135,7 @@ def get_user_selections(): } -def process_chunk_for_display(chunk, selected_analysts: List[AnalystType]): +def process_chunk_for_display(chunk: dict, selected_analysts: List[AnalystType]) -> None: if "market_report" in chunk and chunk["market_report"]: message_buffer.update_report_section("market_report", chunk["market_report"]) message_buffer.update_agent_status("Market Analyst", "completed") @@ -253,7 +253,7 @@ def process_chunk_for_display(chunk, selected_analysts: List[AnalystType]): message_buffer.update_agent_status("Portfolio Manager", "completed") -def setup_logging_decorators(report_dir, log_file): +def setup_logging_decorators(report_dir, log_file) -> tuple: def save_message_decorator(obj, func_name): func = getattr(obj, func_name) @wraps(func) @@ -292,7 +292,7 @@ def setup_logging_decorators(report_dir, log_file): return save_message_decorator, save_tool_call_decorator, save_report_section_decorator -def run_analysis_for_ticker(ticker: str, config: dict): +def run_analysis_for_ticker(ticker: str, config: dict) -> None: analysis_date = datetime.datetime.now().strftime("%Y-%m-%d") console.print( @@ -330,7 +330,7 @@ def run_analysis_for_ticker(ticker: str, config: dict): _run_analysis_with_config(ticker, analysis_date, selected_analysts, config) -def run_analysis(): +def run_analysis() -> None: selections = get_user_selections() config = get_config() @@ -349,7 +349,7 @@ def run_analysis(): ) -def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts: List[AnalystType], config: dict): +def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts: List[AnalystType], config: dict) -> None: with loading("Initializing trading agents...", show_elapsed=True): graph = TradingAgentsGraph( [analyst.value for analyst in selected_analysts], config=config, debug=True diff --git a/cli/backtest_cmd.py b/cli/backtest_cmd.py index 0e9558a4..34124498 100644 --- a/cli/backtest_cmd.py +++ b/cli/backtest_cmd.py @@ -18,7 +18,7 @@ from cli.utils import loading console = Console() -def sma_buy(ticker, trading_date, ctx): +def sma_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool: loader = ctx["data_loader"] ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date) if len(ohlcv.bars) < 20: @@ -29,7 +29,7 @@ def sma_buy(ticker, trading_date, ctx): return current > sma * 1.02 -def sma_sell(ticker, trading_date, ctx): +def sma_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool: loader = ctx["data_loader"] ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date) if len(ohlcv.bars) < 20: @@ -40,7 +40,7 @@ def sma_sell(ticker, trading_date, ctx): return current < sma * 0.98 -def rsi_buy(ticker, trading_date, ctx): +def rsi_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool: loader = ctx["data_loader"] ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date) if len(ohlcv.bars) < 15: @@ -57,7 +57,7 @@ def rsi_buy(ticker, trading_date, ctx): return rsi < 30 -def rsi_sell(ticker, trading_date, ctx): +def rsi_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool: loader = ctx["data_loader"] ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date) if len(ohlcv.bars) < 15: @@ -74,11 +74,11 @@ def rsi_sell(ticker, trading_date, ctx): return rsi > 70 -def hold_buy(ticker, trading_date, ctx): +def hold_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool: return ctx.get("day_index", 0) == 5 -def hold_sell(ticker, trading_date, ctx): +def hold_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool: return False @@ -95,7 +95,7 @@ def run_backtest( end_date: str = None, initial_cash: float = 100000.0, strategy: str = "sma", -): +) -> None: if not ticker: console.print(create_question_box("Ticker Symbol", "Enter the ticker symbol to backtest", "AAPL")) ticker = typer.prompt("", default="AAPL") diff --git a/cli/discovery.py b/cli/discovery.py index ce26b253..6bc0356a 100644 --- a/cli/discovery.py +++ b/cli/discovery.py @@ -248,7 +248,7 @@ def select_stock_for_detail(trending_stocks: List[TrendingStock]) -> Optional[Tr return selected -def discover_trending_flow(run_analysis_callback=None): +def discover_trending_flow(run_analysis_callback=None) -> None: console.print(Rule("[bold green]Discover Trending Stocks[/bold green]")) console.print() diff --git a/cli/display.py b/cli/display.py index bde7266e..449ee41b 100644 --- a/cli/display.py +++ b/cli/display.py @@ -1,3 +1,5 @@ +from typing import Optional, Dict, Any + from rich.console import Console from rich.panel import Panel from rich.spinner import Spinner @@ -13,7 +15,7 @@ from cli.state import message_buffer console = Console() -def create_layout(): +def create_layout() -> Layout: layout = Layout() layout.split_column( Layout(name="header", size=3), @@ -29,7 +31,7 @@ def create_layout(): return layout -def update_display(layout, spinner_text=None): +def update_display(layout: Layout, spinner_text: Optional[str] = None) -> None: layout["header"].update( Panel( "[bold green]Welcome to TradingAgents CLI[/bold green]\n" @@ -207,7 +209,7 @@ def update_display(layout, spinner_text=None): layout["footer"].update(Panel(stats_table, border_style="grey50")) -def display_complete_report(final_state): +def display_complete_report(final_state: Dict[str, Any]) -> None: console.print("\n[bold green]Complete Analysis Report[/bold green]\n") analyst_reports = [] @@ -381,13 +383,13 @@ def display_complete_report(final_state): ) -def update_research_team_status(status): +def update_research_team_status(status: str) -> None: research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"] for agent in research_team: message_buffer.update_agent_status(agent, status) -def extract_content_string(content): +def extract_content_string(content: Any) -> str: if isinstance(content, str): return content elif isinstance(content, list): @@ -405,7 +407,7 @@ def extract_content_string(content): return str(content) -def create_question_box(title: str, prompt: str, default: str = None) -> Panel: +def create_question_box(title: str, prompt: str, default: Optional[str] = None) -> Panel: box_content = f"[bold]{title}[/bold]\n" box_content += f"[dim]{prompt}[/dim]" if default: diff --git a/cli/state.py b/cli/state.py index 6988123a..b9d64651 100644 --- a/cli/state.py +++ b/cli/state.py @@ -1,12 +1,12 @@ import datetime from collections import deque -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Deque class MessageBuffer: - def __init__(self, max_length=100): - self.messages = deque(maxlen=max_length) - self.tool_calls = deque(maxlen=max_length) + def __init__(self, max_length: int = 100) -> None: + self.messages: Deque = deque(maxlen=max_length) + self.tool_calls: Deque = deque(maxlen=max_length) self.current_report = None self.final_report = None self.agent_status = { @@ -34,25 +34,25 @@ class MessageBuffer: "final_trade_decision": None, } - def add_message(self, message_type: str, content: str): + def add_message(self, message_type: str, content: str) -> None: timestamp = datetime.datetime.now().strftime("%H:%M:%S") self.messages.append((timestamp, message_type, content)) - def add_tool_call(self, tool_name: str, args: Dict[str, Any]): + def add_tool_call(self, tool_name: str, args: Dict[str, Any]) -> None: timestamp = datetime.datetime.now().strftime("%H:%M:%S") self.tool_calls.append((timestamp, tool_name, args)) - def update_agent_status(self, agent: str, status: str): + def update_agent_status(self, agent: str, status: str) -> None: if agent in self.agent_status: self.agent_status[agent] = status self.current_agent = agent - def update_report_section(self, section_name: str, content: str): + def update_report_section(self, section_name: str, content: str) -> None: if section_name in self.report_sections: self.report_sections[section_name] = content self._update_current_report() - def _update_current_report(self): + def _update_current_report(self) -> None: latest_section = None latest_content = None @@ -77,7 +77,7 @@ class MessageBuffer: self._update_final_report() - def _update_final_report(self): + def _update_final_report(self) -> None: report_parts = [] if any( @@ -121,7 +121,7 @@ class MessageBuffer: self.final_report = "\n\n".join(report_parts) if report_parts else None - def reset(self): + def reset(self) -> None: for agent in self.agent_status: self.agent_status[agent] = "pending" for section in self.report_sections: diff --git a/tradingagents/backtesting/engine.py b/tradingagents/backtesting/engine.py index ae9e75d8..589290c6 100644 --- a/tradingagents/backtesting/engine.py +++ b/tradingagents/backtesting/engine.py @@ -88,7 +88,7 @@ class BacktestEngine: error_message=str(e), ) - def _initialize(self): + def _initialize(self) -> None: self.portfolio = PortfolioSnapshot( cash=self.config.portfolio_config.initial_cash, ) @@ -98,7 +98,7 @@ class BacktestEngine: self.decisions = [] self.open_trades = {} - def _preload_data(self): + def _preload_data(self) -> None: logger.info("Preloading data for %s tickers", len(self.config.tickers)) for ticker in self.config.tickers: self.data_loader.load_ohlcv( @@ -115,7 +115,7 @@ class BacktestEngine: self.config.end_date, ) - def _process_day(self, trading_date: date, day_index: int): + def _process_day(self, trading_date: date, day_index: int) -> None: prices = self.data_loader.get_prices_dict(self.config.tickers, trading_date) if not prices: @@ -161,7 +161,7 @@ class BacktestEngine: decision: TradingDecision, price: Decimal, trading_date: date, - ): + ) -> None: ticker = decision.ticker config = self.config.portfolio_config position = self.portfolio.get_position(ticker) @@ -270,7 +270,7 @@ class BacktestEngine: ticker, quantity, execution_price, trading_date ) - def _record_equity(self, trading_date: date, prices: dict[str, Decimal]): + def _record_equity(self, trading_date: date, prices: dict[str, Decimal]) -> None: equity = self.portfolio.total_equity(prices) positions_value = self.portfolio.positions_value(prices) @@ -288,7 +288,7 @@ class BacktestEngine: daily_return = (equity - prev_equity) / prev_equity self.daily_returns.append(daily_return) - def _close_all_positions(self, final_date: date): + def _close_all_positions(self, final_date: date) -> None: prices = self.data_loader.get_prices_dict(self.config.tickers, final_date) for ticker, trade in list(self.open_trades.items()): @@ -305,7 +305,7 @@ class BacktestEngine: ) self._execute_decision(decision, prices[ticker], final_date) - def _empty_metrics(self): + def _empty_metrics(self) -> "BacktestMetrics": from tradingagents.models.backtest import BacktestMetrics return BacktestMetrics( start_equity=self.config.portfolio_config.initial_cash, diff --git a/tradingagents/dataflows/local.py b/tradingagents/dataflows/local.py index 94022bd7..2d77e195 100644 --- a/tradingagents/dataflows/local.py +++ b/tradingagents/dataflows/local.py @@ -78,7 +78,7 @@ def get_finnhub_news( query: Annotated[str, "Search query or ticker symbol"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"], -): +) -> str: result = get_data_in_range(query, start_date, end_date, "news_data", DATA_DIR) @@ -101,7 +101,7 @@ def get_finnhub_news( def get_finnhub_company_insider_sentiment( ticker: Annotated[str, "ticker symbol for the company"], curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -): +) -> str: date_obj = datetime.strptime(curr_date, "%Y-%m-%d") before = date_obj - relativedelta(days=15) @@ -130,7 +130,7 @@ def get_finnhub_company_insider_sentiment( def get_finnhub_company_insider_transactions( ticker: Annotated[str, "ticker symbol"], curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -): +) -> str: date_obj = datetime.strptime(curr_date, "%Y-%m-%d") before = date_obj - relativedelta(days=15) @@ -156,7 +156,14 @@ def get_finnhub_company_insider_transactions( + "The change field reflects the variation in share count—here a negative number indicates a reduction in holdings—while share specifies the total number of shares involved. The transactionPrice denotes the per-share price at which the trade was executed, and transactionDate marks when the transaction occurred. The name field identifies the insider making the trade, and transactionCode (e.g., S for sale) clarifies the nature of the transaction. FilingDate records when the transaction was officially reported, and the unique id links to the specific SEC filing, as indicated by the source. Additionally, the symbol ties the transaction to a particular company, isDerivative flags whether the trade involves derivative securities, and currency notes the currency context of the transaction." ) -def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None): +def get_data_in_range( + ticker: str, + start_date: str, + end_date: str, + data_type: str, + data_dir: str, + period: str = None, +) -> dict: if period: data_path = os.path.join( @@ -186,7 +193,7 @@ def get_simfin_balance_sheet( "reporting frequency of the company's financial history: annual / quarterly", ], curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -): +) -> str: data_path = os.path.join( DATA_DIR, "fundamental_data", @@ -227,7 +234,7 @@ def get_simfin_cashflow( "reporting frequency of the company's financial history: annual / quarterly", ], curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -): +) -> str: data_path = os.path.join( DATA_DIR, "fundamental_data", @@ -268,7 +275,7 @@ def get_simfin_income_statements( "reporting frequency of the company's financial history: annual / quarterly", ], curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -): +) -> str: data_path = os.path.join( DATA_DIR, "fundamental_data", diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index 6b4bec7e..bb670fb5 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -13,7 +13,7 @@ def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"], -): +) -> str: symbol = validate_ticker(symbol) start, end = validate_date_range(start_date, end_date) start_date = start.strftime("%Y-%m-%d") @@ -248,7 +248,6 @@ def get_stockstats_indicator( str, "The current trading date you are trading on, YYYY-mm-dd" ], ) -> str: - curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") curr_date = curr_date_dt.strftime("%Y-%m-%d") @@ -272,7 +271,7 @@ def get_balance_sheet( ticker: Annotated[str, "ticker symbol of the company"], freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", curr_date: Annotated[str, "current date (not used for yfinance)"] = None -): +) -> str: try: ticker_obj = yf.Ticker(ticker.upper()) @@ -299,7 +298,7 @@ def get_cashflow( ticker: Annotated[str, "ticker symbol of the company"], freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", curr_date: Annotated[str, "current date (not used for yfinance)"] = None -): +) -> str: try: ticker_obj = yf.Ticker(ticker.upper()) @@ -326,7 +325,7 @@ def get_income_statement( ticker: Annotated[str, "ticker symbol of the company"], freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", curr_date: Annotated[str, "current date (not used for yfinance)"] = None -): +) -> str: try: ticker_obj = yf.Ticker(ticker.upper()) @@ -351,7 +350,7 @@ def get_income_statement( def get_insider_transactions( ticker: Annotated[str, "ticker symbol of the company"] -): +) -> str: try: ticker_obj = yf.Ticker(ticker.upper()) data = ticker_obj.insider_transactions diff --git a/tradingagents/graph/reflection.py b/tradingagents/graph/reflection.py index 33303231..f99b274c 100644 --- a/tradingagents/graph/reflection.py +++ b/tradingagents/graph/reflection.py @@ -56,7 +56,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}" def _reflect_on_component( - self, component_type: str, report: str, situation: str, returns_losses + self, component_type: str, report: str, situation: str, returns_losses: Any ) -> str: """Generate reflection for a component.""" messages = [ @@ -70,7 +70,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur result = self.quick_thinking_llm.invoke(messages).content return result - def reflect_bull_researcher(self, current_state, returns_losses, bull_memory): + def reflect_bull_researcher(self, current_state: Dict[str, Any], returns_losses: Any, bull_memory) -> None: """Reflect on bull researcher's analysis and update memory.""" situation = self._extract_current_situation(current_state) bull_debate_history = current_state["investment_debate_state"]["bull_history"] @@ -80,7 +80,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur ) bull_memory.add_situations([(situation, result)]) - def reflect_bear_researcher(self, current_state, returns_losses, bear_memory): + def reflect_bear_researcher(self, current_state: Dict[str, Any], returns_losses: Any, bear_memory) -> None: """Reflect on bear researcher's analysis and update memory.""" situation = self._extract_current_situation(current_state) bear_debate_history = current_state["investment_debate_state"]["bear_history"] @@ -90,7 +90,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur ) bear_memory.add_situations([(situation, result)]) - def reflect_trader(self, current_state, returns_losses, trader_memory): + def reflect_trader(self, current_state: Dict[str, Any], returns_losses: Any, trader_memory) -> None: """Reflect on trader's decision and update memory.""" situation = self._extract_current_situation(current_state) trader_decision = current_state["trader_investment_plan"] @@ -100,7 +100,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur ) trader_memory.add_situations([(situation, result)]) - def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory): + def reflect_invest_judge(self, current_state: Dict[str, Any], returns_losses: Any, invest_judge_memory) -> None: """Reflect on investment judge's decision and update memory.""" situation = self._extract_current_situation(current_state) judge_decision = current_state["investment_debate_state"]["judge_decision"] @@ -110,7 +110,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur ) invest_judge_memory.add_situations([(situation, result)]) - def reflect_risk_manager(self, current_state, returns_losses, risk_manager_memory): + def reflect_risk_manager(self, current_state: Dict[str, Any], returns_losses: Any, risk_manager_memory) -> None: """Reflect on risk manager's decision and update memory.""" situation = self._extract_current_situation(current_state) judge_decision = current_state["risk_debate_state"]["judge_decision"] diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 3caee5af..4637f716 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -63,7 +63,7 @@ class DiscoveryTimeoutException(Exception): pass -def _timeout_handler(signum, frame): +def _timeout_handler(signum, frame) -> None: raise DiscoveryTimeoutException("Discovery operation timed out") @@ -155,7 +155,7 @@ class TradingAgentsGraph: ), } - def propagate(self, company_name, trade_date): + def propagate(self, company_name: str, trade_date) -> Tuple[Dict[str, Any], str]: company_name = validate_ticker(company_name) validated_date = validate_date(trade_date, allow_future=False) if isinstance(trade_date, str): @@ -185,7 +185,7 @@ class TradingAgentsGraph: return final_state, self.process_signal(final_state["final_trade_decision"]) - def _log_state(self, trade_date, final_state): + def _log_state(self, trade_date, final_state: Dict[str, Any]) -> None: self.log_states_dict[str(trade_date)] = { "company_of_interest": final_state["company_of_interest"], "trade_date": final_state["trade_date"], @@ -225,7 +225,7 @@ class TradingAgentsGraph: ) as f: json.dump(self.log_states_dict, f, indent=4) - def reflect_and_remember(self, returns_losses): + def reflect_and_remember(self, returns_losses) -> None: self.reflector.reflect_bull_researcher( self.curr_state, returns_losses, self.bull_memory ) @@ -242,7 +242,7 @@ class TradingAgentsGraph: self.curr_state, returns_losses, self.risk_manager_memory ) - def process_signal(self, full_signal): + def process_signal(self, full_signal: str) -> str: return self.signal_processor.process_signal(full_signal) def discover_trending(