Improve CLI with loading indicators and enhance discovery pipeline
- Add MultiStageLoader and LoadingIndicator classes to cli/utils.py - Improve CLI main.py with explicit imports and UI enhancements - Update entity_extractor.py with improved LLM provider handling - Add Ollama embedding model configuration to memory.py - Add get_bulk_news_alpha_vantage function to alpha_vantage_news.py - Add discovery imports to trading_graph.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
24e955d29b
commit
3c85b21e0b
79
cli/main.py
79
cli/main.py
|
|
@ -35,7 +35,20 @@ from tradingagents.agents.discovery.models import (
|
|||
)
|
||||
from tradingagents.agents.discovery.persistence import save_discovery_result
|
||||
from cli.models import AnalystType
|
||||
from cli.utils import *
|
||||
from cli.utils import (
|
||||
ANALYST_ORDER,
|
||||
get_ticker,
|
||||
get_analysis_date,
|
||||
select_analysts,
|
||||
select_research_depth,
|
||||
select_shallow_thinking_agent,
|
||||
select_deep_thinking_agent,
|
||||
select_llm_provider,
|
||||
loading,
|
||||
with_loading,
|
||||
MultiStageLoader,
|
||||
LoadingIndicator,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -466,33 +479,32 @@ def discover_trending_flow():
|
|||
)
|
||||
|
||||
discovery_stages = [
|
||||
"Fetching news...",
|
||||
"Extracting entities...",
|
||||
"Resolving tickers...",
|
||||
"Calculating scores...",
|
||||
"Initializing analysis engine",
|
||||
"Fetching news sources",
|
||||
"Extracting stock entities",
|
||||
"Resolving ticker symbols",
|
||||
"Calculating trending scores",
|
||||
]
|
||||
|
||||
result = None
|
||||
with Live(console=console, refresh_per_second=4) as live:
|
||||
for i, stage in enumerate(discovery_stages):
|
||||
progress_panel = Panel(
|
||||
f"[bold cyan]{stage}[/bold cyan]\n\n"
|
||||
f"[dim]Stage {i+1} of {len(discovery_stages)}[/dim]",
|
||||
title="Discovery Progress",
|
||||
border_style="cyan",
|
||||
padding=(2, 4),
|
||||
)
|
||||
live.update(Align.center(progress_panel))
|
||||
|
||||
if i == 0:
|
||||
try:
|
||||
graph = TradingAgentsGraph(config=config, debug=False)
|
||||
result = graph.discover_trending(request)
|
||||
except Exception as e:
|
||||
console.print(f"\n[red]Error during discovery: {e}[/red]")
|
||||
return
|
||||
with MultiStageLoader(discovery_stages, title="Discovery Progress") as loader:
|
||||
try:
|
||||
loader.next_stage()
|
||||
graph = TradingAgentsGraph(config=config, debug=False)
|
||||
|
||||
time.sleep(0.5)
|
||||
loader.next_stage()
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
loader.next_stage()
|
||||
time.sleep(0.3)
|
||||
|
||||
loader.next_stage()
|
||||
time.sleep(0.3)
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"\n[red]Error during discovery: {e}[/red]")
|
||||
return
|
||||
|
||||
if result is None:
|
||||
console.print("\n[red]Discovery failed. Please try again.[/red]")
|
||||
|
|
@ -504,7 +516,8 @@ def discover_trending_flow():
|
|||
|
||||
if result.status == DiscoveryStatus.COMPLETED:
|
||||
try:
|
||||
save_path = save_discovery_result(result)
|
||||
with loading("Saving discovery results..."):
|
||||
save_path = save_discovery_result(result)
|
||||
console.print(f"\n[dim]Results saved to: {save_path}[/dim]")
|
||||
except Exception as e:
|
||||
console.print(f"\n[yellow]Warning: Could not save results: {e}[/yellow]")
|
||||
|
|
@ -546,7 +559,9 @@ def discover_trending_flow():
|
|||
).ask()
|
||||
|
||||
if analyze_choice:
|
||||
console.print(f"\n[green]Starting analysis for {selected_stock.ticker}...[/green]\n")
|
||||
console.print()
|
||||
with loading(f"Preparing analysis for {selected_stock.ticker}...", spinner_style="loading"):
|
||||
time.sleep(0.5)
|
||||
run_analysis_for_ticker(selected_stock.ticker, config)
|
||||
break
|
||||
|
||||
|
|
@ -586,9 +601,10 @@ def run_analysis_for_ticker(ticker: str, config: dict):
|
|||
config["max_risk_discuss_rounds"] = selected_research_depth
|
||||
config["deep_think_llm"] = selected_deep_thinker
|
||||
|
||||
graph = TradingAgentsGraph(
|
||||
[analyst.value for analyst in selected_analysts], config=config, debug=True
|
||||
)
|
||||
with loading("Initializing trading agents...", show_elapsed=True):
|
||||
graph = TradingAgentsGraph(
|
||||
[analyst.value for analyst in selected_analysts], config=config, debug=True
|
||||
)
|
||||
|
||||
results_dir = Path(config["results_dir"]) / ticker / analysis_date
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -1332,9 +1348,10 @@ def run_analysis():
|
|||
config["backend_url"] = selections["backend_url"]
|
||||
config["llm_provider"] = selections["llm_provider"].lower()
|
||||
|
||||
graph = TradingAgentsGraph(
|
||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
||||
)
|
||||
with loading("Initializing trading agents...", show_elapsed=True):
|
||||
graph = TradingAgentsGraph(
|
||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
||||
)
|
||||
|
||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
|||
193
cli/utils.py
193
cli/utils.py
|
|
@ -1,8 +1,199 @@
|
|||
import questionary
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
from typing import List, Optional, Tuple, Dict, Callable, Any
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
import threading
|
||||
import time
|
||||
|
||||
from rich.console import Console
|
||||
from rich.spinner import Spinner
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
from rich.align import Align
|
||||
|
||||
from cli.models import AnalystType
|
||||
|
||||
console = Console()
|
||||
|
||||
SPINNER_STYLES = {
|
||||
"default": "dots",
|
||||
"fast": "dots2",
|
||||
"bounce": "bouncingBall",
|
||||
"pulse": "point",
|
||||
"arrow": "arrow3",
|
||||
"loading": "dots12",
|
||||
}
|
||||
|
||||
|
||||
class LoadingIndicator:
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Working...",
|
||||
spinner_style: str = "default",
|
||||
show_elapsed: bool = False,
|
||||
border_style: str = "cyan",
|
||||
):
|
||||
self.message = message
|
||||
self.spinner_name = SPINNER_STYLES.get(spinner_style, spinner_style)
|
||||
self.show_elapsed = show_elapsed
|
||||
self.border_style = border_style
|
||||
self._live = None
|
||||
self._start_time = None
|
||||
self._stop_event = threading.Event()
|
||||
self._update_thread = None
|
||||
|
||||
def _create_display(self) -> Panel:
|
||||
elapsed_text = ""
|
||||
if self.show_elapsed and self._start_time:
|
||||
elapsed = time.time() - self._start_time
|
||||
elapsed_text = f" [{elapsed:.1f}s]"
|
||||
|
||||
spinner = Spinner(self.spinner_name, text=f" {self.message}{elapsed_text}")
|
||||
return Panel(
|
||||
Align.center(spinner),
|
||||
border_style=self.border_style,
|
||||
padding=(0, 2),
|
||||
)
|
||||
|
||||
def _update_loop(self):
|
||||
while not self._stop_event.is_set():
|
||||
if self._live and self.show_elapsed:
|
||||
self._live.update(self._create_display())
|
||||
time.sleep(0.1)
|
||||
|
||||
def start(self):
|
||||
self._start_time = time.time()
|
||||
self._stop_event.clear()
|
||||
self._live = Live(
|
||||
self._create_display(),
|
||||
console=console,
|
||||
refresh_per_second=10,
|
||||
transient=True,
|
||||
)
|
||||
self._live.start()
|
||||
if self.show_elapsed:
|
||||
self._update_thread = threading.Thread(target=self._update_loop, daemon=True)
|
||||
self._update_thread.start()
|
||||
|
||||
def stop(self):
|
||||
self._stop_event.set()
|
||||
if self._update_thread:
|
||||
self._update_thread.join(timeout=0.5)
|
||||
if self._live:
|
||||
self._live.stop()
|
||||
|
||||
def update_message(self, message: str):
|
||||
self.message = message
|
||||
if self._live:
|
||||
self._live.update(self._create_display())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def loading(
|
||||
message: str = "Working...",
|
||||
spinner_style: str = "default",
|
||||
show_elapsed: bool = False,
|
||||
success_message: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
):
|
||||
indicator = LoadingIndicator(
|
||||
message=message,
|
||||
spinner_style=spinner_style,
|
||||
show_elapsed=show_elapsed,
|
||||
)
|
||||
try:
|
||||
indicator.start()
|
||||
yield indicator
|
||||
if success_message:
|
||||
console.print(f"[green]{success_message}[/green]")
|
||||
except Exception as e:
|
||||
if error_message:
|
||||
console.print(f"[red]{error_message}: {e}[/red]")
|
||||
raise
|
||||
finally:
|
||||
indicator.stop()
|
||||
|
||||
|
||||
def with_loading(
|
||||
message: str = "Working...",
|
||||
spinner_style: str = "default",
|
||||
show_elapsed: bool = False,
|
||||
success_message: Optional[str] = None,
|
||||
):
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
with loading(
|
||||
message=message,
|
||||
spinner_style=spinner_style,
|
||||
show_elapsed=show_elapsed,
|
||||
success_message=success_message,
|
||||
):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
class MultiStageLoader:
|
||||
def __init__(self, stages: List[str], title: str = "Progress"):
|
||||
self.stages = stages
|
||||
self.title = title
|
||||
self.current_stage = 0
|
||||
self._live = None
|
||||
self._start_time = None
|
||||
|
||||
def _create_display(self) -> Panel:
|
||||
lines = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
if i < self.current_stage:
|
||||
lines.append(Text(f" [done] {stage}", style="green"))
|
||||
elif i == self.current_stage:
|
||||
spinner = Spinner("dots", text=f" {stage}")
|
||||
lines.append(spinner)
|
||||
else:
|
||||
lines.append(Text(f" [ -- ] {stage}", style="dim"))
|
||||
|
||||
from rich.console import Group
|
||||
content = Group(*lines)
|
||||
|
||||
elapsed = ""
|
||||
if self._start_time:
|
||||
elapsed = f" [{time.time() - self._start_time:.1f}s]"
|
||||
|
||||
return Panel(
|
||||
content,
|
||||
title=f"{self.title}{elapsed}",
|
||||
border_style="cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
||||
def start(self):
|
||||
self._start_time = time.time()
|
||||
self._live = Live(
|
||||
self._create_display(),
|
||||
console=console,
|
||||
refresh_per_second=10,
|
||||
)
|
||||
self._live.start()
|
||||
|
||||
def next_stage(self):
|
||||
self.current_stage += 1
|
||||
if self._live:
|
||||
self._live.update(self._create_display())
|
||||
|
||||
def stop(self):
|
||||
if self._live:
|
||||
self._live.stop()
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
return False
|
||||
|
||||
ANALYST_ORDER = [
|
||||
("Market Analyst", AnalystType.MARKET),
|
||||
("Social Media Analyst", AnalystType.SOCIAL),
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ class ExtractedEntity(BaseModel):
|
|||
context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention")
|
||||
event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other")
|
||||
sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)")
|
||||
article_id: str = PydanticField(description="The article ID where this company was mentioned (e.g., article_0, article_1)")
|
||||
|
||||
|
||||
class ExtractionResponse(BaseModel):
|
||||
|
|
@ -75,9 +76,11 @@ For each article provided, extract all mentions of publicly traded companies. Fo
|
|||
- 0.0: Neutral news
|
||||
- 0.5: Moderately positive news
|
||||
- 1.0: Very positive news (breakthroughs, record earnings)
|
||||
6. Include the article_id (e.g., article_0, article_1) where the company was mentioned
|
||||
|
||||
Only extract companies that are publicly traded on major stock exchanges.
|
||||
Handle name variations by providing the most complete company name found.
|
||||
IMPORTANT: Each entity must include the article_id from which it was extracted.
|
||||
|
||||
Articles to analyze:
|
||||
{articles_text}
|
||||
|
|
@ -126,11 +129,12 @@ def _extract_batch(
|
|||
if len(context) > 150:
|
||||
context = context[:147] + "..."
|
||||
|
||||
article_id = entity.article_id if entity.article_id else f"article_{start_idx}"
|
||||
mention = EntityMention(
|
||||
company_name=entity.company_name,
|
||||
confidence=confidence,
|
||||
context_snippet=context,
|
||||
article_id=f"article_{start_idx}",
|
||||
article_id=article_id,
|
||||
event_type=EventCategory(event_type_str),
|
||||
sentiment=sentiment,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class FinancialSituationMemory:
|
|||
self.embedding = "text-embedding-3-small"
|
||||
self.client = OpenAI(base_url=config["backend_url"])
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
self.situation_collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
|
||||
def get_embedding(self, text):
|
||||
"""Get OpenAI embedding for a text"""
|
||||
|
|
|
|||
|
|
@ -40,12 +40,19 @@ def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]:
|
|||
try:
|
||||
response = json.loads(response)
|
||||
except json.JSONDecodeError:
|
||||
print(f"DEBUG: Alpha Vantage JSON decode failed")
|
||||
return []
|
||||
|
||||
if not isinstance(response, dict):
|
||||
print(f"DEBUG: Alpha Vantage response not a dict: {type(response)}")
|
||||
return []
|
||||
|
||||
if "Information" in response:
|
||||
print(f"DEBUG: Alpha Vantage info message: {response.get('Information')}")
|
||||
|
||||
feed = response.get("feed", [])
|
||||
if not feed:
|
||||
print(f"DEBUG: Alpha Vantage feed empty. Keys in response: {list(response.keys())}")
|
||||
|
||||
articles = []
|
||||
for item in feed:
|
||||
|
|
|
|||
|
|
@ -265,6 +265,8 @@ class TradingAgentsGraph:
|
|||
mentions = extract_entities(articles, self.config)
|
||||
|
||||
min_mentions = self.config.get("discovery_min_mentions", 2)
|
||||
if len(articles) < 10:
|
||||
min_mentions = 1
|
||||
max_results = request.max_results or self.config.get("discovery_max_results", 20)
|
||||
|
||||
trending_stocks = calculate_trending_scores(
|
||||
|
|
|
|||
Loading…
Reference in New Issue