Improve CLI with loading indicators and enhance discovery pipeline
- Add MultiStageLoader and LoadingIndicator classes to cli/utils.py - Improve CLI main.py with explicit imports and UI enhancements - Update entity_extractor.py with improved LLM provider handling - Add Ollama embedding model configuration to memory.py - Add get_bulk_news_alpha_vantage function to alpha_vantage_news.py - Add discovery imports to trading_graph.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
24e955d29b
commit
3c85b21e0b
79
cli/main.py
79
cli/main.py
|
|
@ -35,7 +35,20 @@ from tradingagents.agents.discovery.models import (
|
||||||
)
|
)
|
||||||
from tradingagents.agents.discovery.persistence import save_discovery_result
|
from tradingagents.agents.discovery.persistence import save_discovery_result
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
from cli.utils import *
|
from cli.utils import (
|
||||||
|
ANALYST_ORDER,
|
||||||
|
get_ticker,
|
||||||
|
get_analysis_date,
|
||||||
|
select_analysts,
|
||||||
|
select_research_depth,
|
||||||
|
select_shallow_thinking_agent,
|
||||||
|
select_deep_thinking_agent,
|
||||||
|
select_llm_provider,
|
||||||
|
loading,
|
||||||
|
with_loading,
|
||||||
|
MultiStageLoader,
|
||||||
|
LoadingIndicator,
|
||||||
|
)
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
@ -466,33 +479,32 @@ def discover_trending_flow():
|
||||||
)
|
)
|
||||||
|
|
||||||
discovery_stages = [
|
discovery_stages = [
|
||||||
"Fetching news...",
|
"Initializing analysis engine",
|
||||||
"Extracting entities...",
|
"Fetching news sources",
|
||||||
"Resolving tickers...",
|
"Extracting stock entities",
|
||||||
"Calculating scores...",
|
"Resolving ticker symbols",
|
||||||
|
"Calculating trending scores",
|
||||||
]
|
]
|
||||||
|
|
||||||
result = None
|
result = None
|
||||||
with Live(console=console, refresh_per_second=4) as live:
|
|
||||||
for i, stage in enumerate(discovery_stages):
|
|
||||||
progress_panel = Panel(
|
|
||||||
f"[bold cyan]{stage}[/bold cyan]\n\n"
|
|
||||||
f"[dim]Stage {i+1} of {len(discovery_stages)}[/dim]",
|
|
||||||
title="Discovery Progress",
|
|
||||||
border_style="cyan",
|
|
||||||
padding=(2, 4),
|
|
||||||
)
|
|
||||||
live.update(Align.center(progress_panel))
|
|
||||||
|
|
||||||
if i == 0:
|
with MultiStageLoader(discovery_stages, title="Discovery Progress") as loader:
|
||||||
try:
|
try:
|
||||||
graph = TradingAgentsGraph(config=config, debug=False)
|
loader.next_stage()
|
||||||
result = graph.discover_trending(request)
|
graph = TradingAgentsGraph(config=config, debug=False)
|
||||||
except Exception as e:
|
|
||||||
console.print(f"\n[red]Error during discovery: {e}[/red]")
|
|
||||||
return
|
|
||||||
|
|
||||||
time.sleep(0.5)
|
loader.next_stage()
|
||||||
|
result = graph.discover_trending(request)
|
||||||
|
|
||||||
|
loader.next_stage()
|
||||||
|
time.sleep(0.3)
|
||||||
|
|
||||||
|
loader.next_stage()
|
||||||
|
time.sleep(0.3)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"\n[red]Error during discovery: {e}[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
console.print("\n[red]Discovery failed. Please try again.[/red]")
|
console.print("\n[red]Discovery failed. Please try again.[/red]")
|
||||||
|
|
@ -504,7 +516,8 @@ def discover_trending_flow():
|
||||||
|
|
||||||
if result.status == DiscoveryStatus.COMPLETED:
|
if result.status == DiscoveryStatus.COMPLETED:
|
||||||
try:
|
try:
|
||||||
save_path = save_discovery_result(result)
|
with loading("Saving discovery results..."):
|
||||||
|
save_path = save_discovery_result(result)
|
||||||
console.print(f"\n[dim]Results saved to: {save_path}[/dim]")
|
console.print(f"\n[dim]Results saved to: {save_path}[/dim]")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"\n[yellow]Warning: Could not save results: {e}[/yellow]")
|
console.print(f"\n[yellow]Warning: Could not save results: {e}[/yellow]")
|
||||||
|
|
@ -546,7 +559,9 @@ def discover_trending_flow():
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if analyze_choice:
|
if analyze_choice:
|
||||||
console.print(f"\n[green]Starting analysis for {selected_stock.ticker}...[/green]\n")
|
console.print()
|
||||||
|
with loading(f"Preparing analysis for {selected_stock.ticker}...", spinner_style="loading"):
|
||||||
|
time.sleep(0.5)
|
||||||
run_analysis_for_ticker(selected_stock.ticker, config)
|
run_analysis_for_ticker(selected_stock.ticker, config)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -586,9 +601,10 @@ def run_analysis_for_ticker(ticker: str, config: dict):
|
||||||
config["max_risk_discuss_rounds"] = selected_research_depth
|
config["max_risk_discuss_rounds"] = selected_research_depth
|
||||||
config["deep_think_llm"] = selected_deep_thinker
|
config["deep_think_llm"] = selected_deep_thinker
|
||||||
|
|
||||||
graph = TradingAgentsGraph(
|
with loading("Initializing trading agents...", show_elapsed=True):
|
||||||
[analyst.value for analyst in selected_analysts], config=config, debug=True
|
graph = TradingAgentsGraph(
|
||||||
)
|
[analyst.value for analyst in selected_analysts], config=config, debug=True
|
||||||
|
)
|
||||||
|
|
||||||
results_dir = Path(config["results_dir"]) / ticker / analysis_date
|
results_dir = Path(config["results_dir"]) / ticker / analysis_date
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -1332,9 +1348,10 @@ def run_analysis():
|
||||||
config["backend_url"] = selections["backend_url"]
|
config["backend_url"] = selections["backend_url"]
|
||||||
config["llm_provider"] = selections["llm_provider"].lower()
|
config["llm_provider"] = selections["llm_provider"].lower()
|
||||||
|
|
||||||
graph = TradingAgentsGraph(
|
with loading("Initializing trading agents...", show_elapsed=True):
|
||||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
graph = TradingAgentsGraph(
|
||||||
)
|
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
||||||
|
)
|
||||||
|
|
||||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
|
||||||
193
cli/utils.py
193
cli/utils.py
|
|
@ -1,8 +1,199 @@
|
||||||
import questionary
|
import questionary
|
||||||
from typing import List, Optional, Tuple, Dict
|
from typing import List, Optional, Tuple, Dict, Callable, Any
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import wraps
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.spinner import Spinner
|
||||||
|
from rich.live import Live
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.text import Text
|
||||||
|
from rich.align import Align
|
||||||
|
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
SPINNER_STYLES = {
|
||||||
|
"default": "dots",
|
||||||
|
"fast": "dots2",
|
||||||
|
"bounce": "bouncingBall",
|
||||||
|
"pulse": "point",
|
||||||
|
"arrow": "arrow3",
|
||||||
|
"loading": "dots12",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LoadingIndicator:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Working...",
|
||||||
|
spinner_style: str = "default",
|
||||||
|
show_elapsed: bool = False,
|
||||||
|
border_style: str = "cyan",
|
||||||
|
):
|
||||||
|
self.message = message
|
||||||
|
self.spinner_name = SPINNER_STYLES.get(spinner_style, spinner_style)
|
||||||
|
self.show_elapsed = show_elapsed
|
||||||
|
self.border_style = border_style
|
||||||
|
self._live = None
|
||||||
|
self._start_time = None
|
||||||
|
self._stop_event = threading.Event()
|
||||||
|
self._update_thread = None
|
||||||
|
|
||||||
|
def _create_display(self) -> Panel:
|
||||||
|
elapsed_text = ""
|
||||||
|
if self.show_elapsed and self._start_time:
|
||||||
|
elapsed = time.time() - self._start_time
|
||||||
|
elapsed_text = f" [{elapsed:.1f}s]"
|
||||||
|
|
||||||
|
spinner = Spinner(self.spinner_name, text=f" {self.message}{elapsed_text}")
|
||||||
|
return Panel(
|
||||||
|
Align.center(spinner),
|
||||||
|
border_style=self.border_style,
|
||||||
|
padding=(0, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_loop(self):
|
||||||
|
while not self._stop_event.is_set():
|
||||||
|
if self._live and self.show_elapsed:
|
||||||
|
self._live.update(self._create_display())
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self._start_time = time.time()
|
||||||
|
self._stop_event.clear()
|
||||||
|
self._live = Live(
|
||||||
|
self._create_display(),
|
||||||
|
console=console,
|
||||||
|
refresh_per_second=10,
|
||||||
|
transient=True,
|
||||||
|
)
|
||||||
|
self._live.start()
|
||||||
|
if self.show_elapsed:
|
||||||
|
self._update_thread = threading.Thread(target=self._update_loop, daemon=True)
|
||||||
|
self._update_thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._stop_event.set()
|
||||||
|
if self._update_thread:
|
||||||
|
self._update_thread.join(timeout=0.5)
|
||||||
|
if self._live:
|
||||||
|
self._live.stop()
|
||||||
|
|
||||||
|
def update_message(self, message: str):
|
||||||
|
self.message = message
|
||||||
|
if self._live:
|
||||||
|
self._live.update(self._create_display())
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def loading(
|
||||||
|
message: str = "Working...",
|
||||||
|
spinner_style: str = "default",
|
||||||
|
show_elapsed: bool = False,
|
||||||
|
success_message: Optional[str] = None,
|
||||||
|
error_message: Optional[str] = None,
|
||||||
|
):
|
||||||
|
indicator = LoadingIndicator(
|
||||||
|
message=message,
|
||||||
|
spinner_style=spinner_style,
|
||||||
|
show_elapsed=show_elapsed,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
indicator.start()
|
||||||
|
yield indicator
|
||||||
|
if success_message:
|
||||||
|
console.print(f"[green]{success_message}[/green]")
|
||||||
|
except Exception as e:
|
||||||
|
if error_message:
|
||||||
|
console.print(f"[red]{error_message}: {e}[/red]")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
indicator.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def with_loading(
|
||||||
|
message: str = "Working...",
|
||||||
|
spinner_style: str = "default",
|
||||||
|
show_elapsed: bool = False,
|
||||||
|
success_message: Optional[str] = None,
|
||||||
|
):
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs) -> Any:
|
||||||
|
with loading(
|
||||||
|
message=message,
|
||||||
|
spinner_style=spinner_style,
|
||||||
|
show_elapsed=show_elapsed,
|
||||||
|
success_message=success_message,
|
||||||
|
):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class MultiStageLoader:
|
||||||
|
def __init__(self, stages: List[str], title: str = "Progress"):
|
||||||
|
self.stages = stages
|
||||||
|
self.title = title
|
||||||
|
self.current_stage = 0
|
||||||
|
self._live = None
|
||||||
|
self._start_time = None
|
||||||
|
|
||||||
|
def _create_display(self) -> Panel:
|
||||||
|
lines = []
|
||||||
|
for i, stage in enumerate(self.stages):
|
||||||
|
if i < self.current_stage:
|
||||||
|
lines.append(Text(f" [done] {stage}", style="green"))
|
||||||
|
elif i == self.current_stage:
|
||||||
|
spinner = Spinner("dots", text=f" {stage}")
|
||||||
|
lines.append(spinner)
|
||||||
|
else:
|
||||||
|
lines.append(Text(f" [ -- ] {stage}", style="dim"))
|
||||||
|
|
||||||
|
from rich.console import Group
|
||||||
|
content = Group(*lines)
|
||||||
|
|
||||||
|
elapsed = ""
|
||||||
|
if self._start_time:
|
||||||
|
elapsed = f" [{time.time() - self._start_time:.1f}s]"
|
||||||
|
|
||||||
|
return Panel(
|
||||||
|
content,
|
||||||
|
title=f"{self.title}{elapsed}",
|
||||||
|
border_style="cyan",
|
||||||
|
padding=(1, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self._start_time = time.time()
|
||||||
|
self._live = Live(
|
||||||
|
self._create_display(),
|
||||||
|
console=console,
|
||||||
|
refresh_per_second=10,
|
||||||
|
)
|
||||||
|
self._live.start()
|
||||||
|
|
||||||
|
def next_stage(self):
|
||||||
|
self.current_stage += 1
|
||||||
|
if self._live:
|
||||||
|
self._live.update(self._create_display())
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
if self._live:
|
||||||
|
self._live.stop()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.stop()
|
||||||
|
return False
|
||||||
|
|
||||||
ANALYST_ORDER = [
|
ANALYST_ORDER = [
|
||||||
("Market Analyst", AnalystType.MARKET),
|
("Market Analyst", AnalystType.MARKET),
|
||||||
("Social Media Analyst", AnalystType.SOCIAL),
|
("Social Media Analyst", AnalystType.SOCIAL),
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ class ExtractedEntity(BaseModel):
|
||||||
context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention")
|
context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention")
|
||||||
event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other")
|
event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other")
|
||||||
sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)")
|
sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)")
|
||||||
|
article_id: str = PydanticField(description="The article ID where this company was mentioned (e.g., article_0, article_1)")
|
||||||
|
|
||||||
|
|
||||||
class ExtractionResponse(BaseModel):
|
class ExtractionResponse(BaseModel):
|
||||||
|
|
@ -75,9 +76,11 @@ For each article provided, extract all mentions of publicly traded companies. Fo
|
||||||
- 0.0: Neutral news
|
- 0.0: Neutral news
|
||||||
- 0.5: Moderately positive news
|
- 0.5: Moderately positive news
|
||||||
- 1.0: Very positive news (breakthroughs, record earnings)
|
- 1.0: Very positive news (breakthroughs, record earnings)
|
||||||
|
6. Include the article_id (e.g., article_0, article_1) where the company was mentioned
|
||||||
|
|
||||||
Only extract companies that are publicly traded on major stock exchanges.
|
Only extract companies that are publicly traded on major stock exchanges.
|
||||||
Handle name variations by providing the most complete company name found.
|
Handle name variations by providing the most complete company name found.
|
||||||
|
IMPORTANT: Each entity must include the article_id from which it was extracted.
|
||||||
|
|
||||||
Articles to analyze:
|
Articles to analyze:
|
||||||
{articles_text}
|
{articles_text}
|
||||||
|
|
@ -126,11 +129,12 @@ def _extract_batch(
|
||||||
if len(context) > 150:
|
if len(context) > 150:
|
||||||
context = context[:147] + "..."
|
context = context[:147] + "..."
|
||||||
|
|
||||||
|
article_id = entity.article_id if entity.article_id else f"article_{start_idx}"
|
||||||
mention = EntityMention(
|
mention = EntityMention(
|
||||||
company_name=entity.company_name,
|
company_name=entity.company_name,
|
||||||
confidence=confidence,
|
confidence=confidence,
|
||||||
context_snippet=context,
|
context_snippet=context,
|
||||||
article_id=f"article_{start_idx}",
|
article_id=article_id,
|
||||||
event_type=EventCategory(event_type_str),
|
event_type=EventCategory(event_type_str),
|
||||||
sentiment=sentiment,
|
sentiment=sentiment,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ class FinancialSituationMemory:
|
||||||
self.embedding = "text-embedding-3-small"
|
self.embedding = "text-embedding-3-small"
|
||||||
self.client = OpenAI(base_url=config["backend_url"])
|
self.client = OpenAI(base_url=config["backend_url"])
|
||||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
self.situation_collection = self.chroma_client.get_or_create_collection(name=name)
|
||||||
|
|
||||||
def get_embedding(self, text):
|
def get_embedding(self, text):
|
||||||
"""Get OpenAI embedding for a text"""
|
"""Get OpenAI embedding for a text"""
|
||||||
|
|
|
||||||
|
|
@ -40,12 +40,19 @@ def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
response = json.loads(response)
|
response = json.loads(response)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
print(f"DEBUG: Alpha Vantage JSON decode failed")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if not isinstance(response, dict):
|
if not isinstance(response, dict):
|
||||||
|
print(f"DEBUG: Alpha Vantage response not a dict: {type(response)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
if "Information" in response:
|
||||||
|
print(f"DEBUG: Alpha Vantage info message: {response.get('Information')}")
|
||||||
|
|
||||||
feed = response.get("feed", [])
|
feed = response.get("feed", [])
|
||||||
|
if not feed:
|
||||||
|
print(f"DEBUG: Alpha Vantage feed empty. Keys in response: {list(response.keys())}")
|
||||||
|
|
||||||
articles = []
|
articles = []
|
||||||
for item in feed:
|
for item in feed:
|
||||||
|
|
|
||||||
|
|
@ -265,6 +265,8 @@ class TradingAgentsGraph:
|
||||||
mentions = extract_entities(articles, self.config)
|
mentions = extract_entities(articles, self.config)
|
||||||
|
|
||||||
min_mentions = self.config.get("discovery_min_mentions", 2)
|
min_mentions = self.config.get("discovery_min_mentions", 2)
|
||||||
|
if len(articles) < 10:
|
||||||
|
min_mentions = 1
|
||||||
max_results = request.max_results or self.config.get("discovery_max_results", 20)
|
max_results = request.max_results or self.config.get("discovery_max_results", 20)
|
||||||
|
|
||||||
trending_stocks = calculate_trending_scores(
|
trending_stocks = calculate_trending_scores(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue