fix: address Gemini Code Assist review comments

- Fix RSI calculation to handle NaN values when loss is 0
- Parallelize symbol scanning with ThreadPoolExecutor for better performance
- Improve module import structure for Streamlit app compatibility
- Move yfinance import to top of app.py per Python conventions
This commit is contained in:
阳虎 2026-03-17 22:12:11 +08:00
parent d78eddb682
commit 774b1ed3d3
2 changed files with 16 additions and 9 deletions

View File

@ -13,6 +13,7 @@ import pandas as pd
import numpy as np import numpy as np
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Dict, Optional, Tuple from typing import List, Dict, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
# Magnificent Seven stocks # Magnificent Seven stocks
MAGNIFICENT_SEVEN = ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA", "META", "TSLA"] MAGNIFICENT_SEVEN = ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA", "META", "TSLA"]
@ -57,8 +58,12 @@ class MomentumIndicator:
delta = data.diff() delta = data.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean() gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean() loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
# Handle division by zero and NaN values
loss = loss.replace(0, np.nan)
rs = gain / loss rs = gain / loss
return 100 - (100 / (1 + rs)) # When loss is 0 (no downtrends), RSI = 100
rsi = 100 - (100 / (1 + rs))
return rsi.fillna(100 if loss.isna().any() else 50)
class MomentumScanner: class MomentumScanner:
@ -156,11 +161,13 @@ class MomentumScanner:
return "HOLD" return "HOLD"
def scan_all(self) -> List[Dict]: def scan_all(self) -> List[Dict]:
"""Scan all symbols and return results""" """Scan all symbols and return results (parallelized)"""
results = [] results = []
for symbol in self.symbols: with ThreadPoolExecutor(max_workers=min(len(self.symbols), 8)) as executor:
result = self.analyze_symbol(symbol) future_to_symbol = {executor.submit(self.analyze_symbol, symbol): symbol
results.append(result) for symbol in self.symbols}
for future in as_completed(future_to_symbol):
results.append(future.result())
return results return results

View File

@ -11,11 +11,12 @@ from plotly.subplots import make_subplots
from datetime import datetime, timedelta from datetime import datetime, timedelta
import sys import sys
import os import os
import yfinance as yf
# Add parent directory to path # Add parent directory to path for proper module imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from __init__ import MomentumScanner, MomentumIndicator, MAGNIFICENT_SEVEN, get_top_momentum_stocks from tradingagents.dashboards.momentum import MomentumScanner, MomentumIndicator, MAGNIFICENT_SEVEN, get_top_momentum_stocks
# Page config # Page config
@ -151,7 +152,6 @@ def main():
if selected_symbol: if selected_symbol:
with st.spinner(f"Loading {selected_symbol} chart..."): with st.spinner(f"Loading {selected_symbol} chart..."):
# Fetch data for chart # Fetch data for chart
import yfinance as yf
ticker = yf.Ticker(selected_symbol) ticker = yf.Ticker(selected_symbol)
hist = ticker.history(period="3mo", interval=timeframe) hist = ticker.history(period="3mo", interval=timeframe)