# tradingagents/agents/discovery/entity_extractor.py
from dataclasses import dataclass, field
from typing import List, Optional
from pydantic import BaseModel, Field as PydanticField
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.discovery.models import NewsArticle, EventCategory
# Number of articles sent to the LLM in a single extraction call; bounds
# prompt size and lets extract_entities process arbitrarily long lists.
BATCH_SIZE = 10
@dataclass
class EntityMention:
    """A single company mention extracted from one news article.

    Produced by ``_extract_batch`` after validating/clamping the raw LLM
    output, so ``confidence`` is in [0.0, 1.0] and ``sentiment`` in
    [-1.0, 1.0] by construction.
    """

    company_name: str          # company name as reported by the LLM
    confidence: float          # clamped mention-clarity score, 0.0..1.0
    context_snippet: str       # surrounding text (truncated to <=150 chars)
    article_id: str            # e.g. "article_0" — which article it came from
    event_type: EventCategory  # validated category (falls back to "other")
    sentiment: float = 0.0     # clamped sentiment, -1.0..1.0; neutral default
class ExtractedEntity(BaseModel):
    """Schema for one company mention as returned by the LLM.

    Used via ``llm.with_structured_output`` in ``_extract_batch``. Values
    here are raw model output; ``_extract_batch`` clamps the numeric fields
    and validates ``event_type`` before building an ``EntityMention``.
    """

    company_name: str = PydanticField(description="The name of the publicly traded company mentioned")
    confidence: float = PydanticField(description="Confidence score from 0.0 to 1.0 based on mention clarity")
    context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention")
    # Free-form string here (not EventCategory) — the LLM may return an
    # invalid label; _extract_batch maps unknown values to "other".
    event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other")
    sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)")
    article_id: str = PydanticField(description="The article ID where this company was mentioned (e.g., article_0, article_1)")
class ExtractionResponse(BaseModel):
    """Top-level structured-output schema: all entities found in one batch."""

    # Defaults to empty so a batch with no company mentions parses cleanly.
    entities: List[ExtractedEntity] = PydanticField(default_factory=list, description="List of extracted company entities")
def _get_llm(config: Optional[dict] = None):
cfg = config or DEFAULT_CONFIG
provider = cfg.get("llm_provider", "openai").lower()
model = cfg.get("quick_think_llm", "gpt-4o-mini")
backend_url = cfg.get("backend_url", "https://api.openai.com/v1")
if provider in ("openai", "ollama", "openrouter"):
return ChatOpenAI(model=model, base_url=backend_url)
elif provider == "anthropic":
return ChatAnthropic(model=model, base_url=backend_url)
elif provider == "google":
return ChatGoogleGenerativeAI(model=model)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
# Prompt template for the structured-extraction call in _extract_batch.
# Single placeholder {articles_text} is filled with the batch rendered by
# _format_articles_for_prompt. The category labels and score ranges listed
# here must stay in sync with EventCategory and the clamping logic in
# _extract_batch.
EXTRACTION_PROMPT = """You are an expert at identifying publicly traded companies mentioned in news articles.
For each article provided, extract all mentions of publicly traded companies. For each company mention:
1. Extract the company name as it appears (e.g., "Apple Inc.", "Apple", "AAPL", "the iPhone maker")
2. Assign a confidence score from 0.0 to 1.0 based on how clearly the company is mentioned:
- 0.9-1.0: Direct company name or ticker symbol
- 0.7-0.9: Clear reference with context (e.g., "the Cupertino tech giant")
- 0.5-0.7: Indirect reference requiring inference
- Below 0.5: Uncertain or ambiguous reference
3. Extract 50-100 characters of surrounding context
4. Classify the event type:
- earnings: Quarterly/annual earnings reports, revenue announcements
- merger_acquisition: Mergers, acquisitions, buyouts, takeovers
- regulatory: SEC filings, government investigations, compliance issues
- product_launch: New products, services, or features
- executive_change: CEO/CFO changes, board appointments, departures
- other: Any other business news
5. Assign a sentiment score from -1.0 to 1.0:
- -1.0: Very negative news (lawsuits, crashes, major failures)
- -0.5: Moderately negative news
- 0.0: Neutral news
- 0.5: Moderately positive news
- 1.0: Very positive news (breakthroughs, record earnings)
6. Include the article_id (e.g., article_0, article_1) where the company was mentioned
Only extract companies that are publicly traded on major stock exchanges.
Handle name variations by providing the most complete company name found.
IMPORTANT: Each entity must include the article_id from which it was extracted.
Articles to analyze:
{articles_text}
Extract all company mentions from the articles above."""
def _format_articles_for_prompt(articles: List[NewsArticle], start_idx: int) -> str:
formatted = []
for i, article in enumerate(articles):
article_id = f"article_{start_idx + i}"
formatted.append(
f"[{article_id}]\n"
f"Title: {article.title}\n"
f"Source: {article.source}\n"
f"Content: {article.content_snippet}\n"
)
return "\n---\n".join(formatted)
def _extract_batch(
    articles: List[NewsArticle],
    start_idx: int,
    llm,
) -> List[EntityMention]:
    """Run one structured-extraction LLM call over a batch of articles.

    Args:
        articles: The batch to analyze (at most BATCH_SIZE items).
        start_idx: Global index of the first article; used both for prompt
            article IDs and as the fallback article_id when the LLM omits one.
        llm: A LangChain chat model supporting ``with_structured_output``.

    Returns:
        Validated ``EntityMention`` objects: confidence clamped to
        [0.0, 1.0], sentiment to [-1.0, 1.0], context truncated to 150
        chars, unknown event types mapped to "other". Empty list for an
        empty batch.
    """
    if not articles:
        return []

    articles_text = _format_articles_for_prompt(articles, start_idx)
    prompt = EXTRACTION_PROMPT.format(articles_text=articles_text)
    structured_llm = llm.with_structured_output(ExtractionResponse)
    response = structured_llm.invoke(prompt)

    # Fix: this set is loop-invariant — build it once instead of once per
    # entity (previously recomputed inside the loop).
    valid_event_types = {e.value for e in EventCategory}

    mentions: List[EntityMention] = []
    for entity in response.entities:
        # Normalize and validate the LLM's free-form category label.
        event_type_str = entity.event_type.lower().strip()
        if event_type_str not in valid_event_types:
            event_type_str = "other"
        # Clamp model-reported scores into their documented ranges.
        confidence = max(0.0, min(1.0, entity.confidence))
        sentiment = max(-1.0, min(1.0, entity.sentiment))
        context = entity.context_snippet
        if len(context) > 150:
            context = context[:147] + "..."
        # If the LLM dropped the article_id, attribute the mention to the
        # first article of this batch as a best-effort fallback.
        article_id = entity.article_id or f"article_{start_idx}"
        mentions.append(
            EntityMention(
                company_name=entity.company_name,
                confidence=confidence,
                context_snippet=context,
                article_id=article_id,
                event_type=EventCategory(event_type_str),
                sentiment=sentiment,
            )
        )
    return mentions
def extract_entities(
    articles: List[NewsArticle],
    config: Optional[dict] = None,
) -> List[EntityMention]:
    """Extract company mentions from articles, batching LLM calls.

    Args:
        articles: Articles to analyze; processed in BATCH_SIZE chunks so
            each prompt stays bounded.
        config: Optional LLM configuration forwarded to ``_get_llm``.

    Returns:
        All mentions from all batches, in input order; empty list for
        empty input (no LLM is created in that case).
    """
    if not articles:
        return []

    llm = _get_llm(config)
    collected: List[EntityMention] = []
    # Slicing clamps at the end of the list, so no explicit min() is needed.
    for offset in range(0, len(articles), BATCH_SIZE):
        chunk = articles[offset:offset + BATCH_SIZE]
        collected.extend(_extract_batch(chunk, offset, llm))
    return collected