diff --git a/cli/main.py b/cli/main.py index 00a8f43a..ac22b40c 100644 --- a/cli/main.py +++ b/cli/main.py @@ -35,7 +35,20 @@ from tradingagents.agents.discovery.models import ( ) from tradingagents.agents.discovery.persistence import save_discovery_result from cli.models import AnalystType -from cli.utils import * +from cli.utils import ( + ANALYST_ORDER, + get_ticker, + get_analysis_date, + select_analysts, + select_research_depth, + select_shallow_thinking_agent, + select_deep_thinking_agent, + select_llm_provider, + loading, + with_loading, + MultiStageLoader, + LoadingIndicator, +) console = Console() @@ -466,33 +479,32 @@ def discover_trending_flow(): ) discovery_stages = [ - "Fetching news...", - "Extracting entities...", - "Resolving tickers...", - "Calculating scores...", + "Initializing analysis engine", + "Fetching news sources", + "Extracting stock entities", + "Resolving ticker symbols", + "Calculating trending scores", ] result = None - with Live(console=console, refresh_per_second=4) as live: - for i, stage in enumerate(discovery_stages): - progress_panel = Panel( - f"[bold cyan]{stage}[/bold cyan]\n\n" - f"[dim]Stage {i+1} of {len(discovery_stages)}[/dim]", - title="Discovery Progress", - border_style="cyan", - padding=(2, 4), - ) - live.update(Align.center(progress_panel)) - if i == 0: - try: - graph = TradingAgentsGraph(config=config, debug=False) - result = graph.discover_trending(request) - except Exception as e: - console.print(f"\n[red]Error during discovery: {e}[/red]") - return + with MultiStageLoader(discovery_stages, title="Discovery Progress") as loader: + try: + loader.next_stage() + graph = TradingAgentsGraph(config=config, debug=False) - time.sleep(0.5) + loader.next_stage() + result = graph.discover_trending(request) + + loader.next_stage() + time.sleep(0.3) + + loader.next_stage() + time.sleep(0.3) + + except Exception as e: + console.print(f"\n[red]Error during discovery: {e}[/red]") + return if result is None: console.print("\n[red]Discovery failed. Please try again.[/red]") @@ -504,7 +516,8 @@ def discover_trending_flow(): if result.status == DiscoveryStatus.COMPLETED: try: - save_path = save_discovery_result(result) + with loading("Saving discovery results..."): + save_path = save_discovery_result(result) console.print(f"\n[dim]Results saved to: {save_path}[/dim]") except Exception as e: console.print(f"\n[yellow]Warning: Could not save results: {e}[/yellow]") @@ -546,7 +559,9 @@ def discover_trending_flow(): ).ask() if analyze_choice: - console.print(f"\n[green]Starting analysis for {selected_stock.ticker}...[/green]\n") + console.print() + with loading(f"Preparing analysis for {selected_stock.ticker}...", spinner_style="loading"): + time.sleep(0.5) run_analysis_for_ticker(selected_stock.ticker, config) break @@ -586,9 +601,10 @@ def run_analysis_for_ticker(ticker: str, config: dict): config["max_risk_discuss_rounds"] = selected_research_depth config["deep_think_llm"] = selected_deep_thinker - graph = TradingAgentsGraph( - [analyst.value for analyst in selected_analysts], config=config, debug=True - ) + with loading("Initializing trading agents...", show_elapsed=True): + graph = TradingAgentsGraph( + [analyst.value for analyst in selected_analysts], config=config, debug=True + ) results_dir = Path(config["results_dir"]) / ticker / analysis_date results_dir.mkdir(parents=True, exist_ok=True) @@ -1332,9 +1348,10 @@ def run_analysis(): config["backend_url"] = selections["backend_url"] config["llm_provider"] = selections["llm_provider"].lower() - graph = TradingAgentsGraph( - [analyst.value for analyst in selections["analysts"]], config=config, debug=True - ) + with loading("Initializing trading agents...", show_elapsed=True): + graph = TradingAgentsGraph( + [analyst.value for analyst in selections["analysts"]], config=config, debug=True + ) results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] results_dir.mkdir(parents=True, exist_ok=True) diff --git a/cli/utils.py b/cli/utils.py index 7b9682a6..b837ca8d 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,8 +1,199 @@ import questionary -from typing import List, Optional, Tuple, Dict +from typing import List, Optional, Tuple, Dict, Callable, Any +from contextlib import contextmanager +from functools import wraps +import threading +import time + +from rich.console import Console +from rich.spinner import Spinner +from rich.live import Live +from rich.panel import Panel +from rich.text import Text +from rich.align import Align from cli.models import AnalystType +console = Console() + +SPINNER_STYLES = { + "default": "dots", + "fast": "dots2", + "bounce": "bouncingBall", + "pulse": "point", + "arrow": "arrow3", + "loading": "dots12", +} + + +class LoadingIndicator: + def __init__( + self, + message: str = "Working...", + spinner_style: str = "default", + show_elapsed: bool = False, + border_style: str = "cyan", + ): + self.message = message + self.spinner_name = SPINNER_STYLES.get(spinner_style, spinner_style) + self.show_elapsed = show_elapsed + self.border_style = border_style + self._live = None + self._start_time = None + self._stop_event = threading.Event() + self._update_thread = None + + def _create_display(self) -> Panel: + elapsed_text = "" + if self.show_elapsed and self._start_time: + elapsed = time.time() - self._start_time + elapsed_text = f" [{elapsed:.1f}s]" + + spinner = Spinner(self.spinner_name, text=f" {self.message}{elapsed_text}") + return Panel( + Align.center(spinner), + border_style=self.border_style, + padding=(0, 2), + ) + + def _update_loop(self): + while not self._stop_event.is_set(): + if self._live and self.show_elapsed: + self._live.update(self._create_display()) + time.sleep(0.1) + + def start(self): + self._start_time = time.time() + self._stop_event.clear() + self._live = Live( + self._create_display(), + console=console, + refresh_per_second=10, + transient=True, + ) + self._live.start() + if self.show_elapsed: + self._update_thread = threading.Thread(target=self._update_loop, daemon=True) + self._update_thread.start() + + def stop(self): + self._stop_event.set() + if self._update_thread: + self._update_thread.join(timeout=0.5) + if self._live: + self._live.stop() + + def update_message(self, message: str): + self.message = message + if self._live: + self._live.update(self._create_display()) + + +@contextmanager +def loading( + message: str = "Working...", + spinner_style: str = "default", + show_elapsed: bool = False, + success_message: Optional[str] = None, + error_message: Optional[str] = None, +): + indicator = LoadingIndicator( + message=message, + spinner_style=spinner_style, + show_elapsed=show_elapsed, + ) + try: + indicator.start() + yield indicator + if success_message: + console.print(f"[green]{success_message}[/green]") + except Exception as e: + if error_message: + console.print(f"[red]{error_message}: {e}[/red]") + raise + finally: + indicator.stop() + + +def with_loading( + message: str = "Working...", + spinner_style: str = "default", + show_elapsed: bool = False, + success_message: Optional[str] = None, +): + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + with loading( + message=message, + spinner_style=spinner_style, + show_elapsed=show_elapsed, + success_message=success_message, + ): + return func(*args, **kwargs) + return wrapper + return decorator + + +class MultiStageLoader: + def __init__(self, stages: List[str], title: str = "Progress"): + self.stages = stages + self.title = title + self.current_stage = 0 + self._live = None + self._start_time = None + + def _create_display(self) -> Panel: + lines = [] + for i, stage in enumerate(self.stages): + if i < self.current_stage: + lines.append(Text(f" [done] {stage}", style="green")) + elif i == self.current_stage: + spinner = Spinner("dots", text=f" {stage}") + lines.append(spinner) + else: + lines.append(Text(f" [ -- ] {stage}", style="dim")) + + from rich.console import Group + content = Group(*lines) + + elapsed = "" + if self._start_time: + elapsed = f" [{time.time() - self._start_time:.1f}s]" + + return Panel( + content, + title=f"{self.title}{elapsed}", + border_style="cyan", + padding=(1, 2), + ) + + def start(self): + self._start_time = time.time() + self._live = Live( + self._create_display(), + console=console, + refresh_per_second=10, + ) + self._live.start() + + def next_stage(self): + self.current_stage += 1 + if self._live: + self._live.update(self._create_display()) + + def stop(self): + if self._live: + self._live.stop() + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + return False + ANALYST_ORDER = [ ("Market Analyst", AnalystType.MARKET), ("Social Media Analyst", AnalystType.SOCIAL), diff --git a/tradingagents/agents/discovery/entity_extractor.py b/tradingagents/agents/discovery/entity_extractor.py index 5bad3242..d4efd7c3 100644 --- a/tradingagents/agents/discovery/entity_extractor.py +++ b/tradingagents/agents/discovery/entity_extractor.py @@ -29,6 +29,7 @@ class ExtractedEntity(BaseModel): context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention") event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other") sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)") + article_id: str = PydanticField(description="The article ID where this company was mentioned (e.g., article_0, article_1)") class ExtractionResponse(BaseModel): @@ -75,9 +76,11 @@ For each article provided, extract all mentions of publicly traded companies. Fo - 0.0: Neutral news - 0.5: Moderately positive news - 1.0: Very positive news (breakthroughs, record earnings) +6. Include the article_id (e.g., article_0, article_1) where the company was mentioned Only extract companies that are publicly traded on major stock exchanges. Handle name variations by providing the most complete company name found. +IMPORTANT: Each entity must include the article_id from which it was extracted. Articles to analyze: {articles_text} @@ -126,11 +129,12 @@ def _extract_batch( if len(context) > 150: context = context[:147] + "..." + article_id = entity.article_id if entity.article_id else f"article_{start_idx}" mention = EntityMention( company_name=entity.company_name, confidence=confidence, context_snippet=context, - article_id=f"article_{start_idx}", + article_id=article_id, event_type=EventCategory(event_type_str), sentiment=sentiment, ) diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 9146313e..9a410183 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -11,7 +11,7 @@ class FinancialSituationMemory: self.embedding = "text-embedding-3-small" self.client = OpenAI(base_url=config["backend_url"]) self.chroma_client = chromadb.Client(Settings(allow_reset=True)) - self.situation_collection = self.chroma_client.create_collection(name=name) + self.situation_collection = self.chroma_client.get_or_create_collection(name=name) def get_embedding(self, text): """Get OpenAI embedding for a text""" diff --git a/tradingagents/dataflows/alpha_vantage_news.py b/tradingagents/dataflows/alpha_vantage_news.py index 74744b76..8968f06d 100644 --- a/tradingagents/dataflows/alpha_vantage_news.py +++ b/tradingagents/dataflows/alpha_vantage_news.py @@ -40,12 +40,19 @@ def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]: try: response = json.loads(response) except json.JSONDecodeError: + print(f"DEBUG: Alpha Vantage JSON decode failed") return [] if not isinstance(response, dict): + print(f"DEBUG: Alpha Vantage response not a dict: {type(response)}") return [] + if "Information" in response: + print(f"DEBUG: Alpha Vantage info message: {response.get('Information')}") + feed = response.get("feed", []) + if not feed: + print(f"DEBUG: Alpha Vantage feed empty. Keys in response: {list(response.keys())}") articles = [] for item in feed: diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index ba4c092c..435ad168 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -265,6 +265,8 @@ class TradingAgentsGraph: mentions = extract_entities(articles, self.config) min_mentions = self.config.get("discovery_min_mentions", 2) + if len(articles) < 10: + min_mentions = 1 max_results = request.max_results or self.config.get("discovery_max_results", 20) trending_stocks = calculate_trending_scores(