164 lines
6.1 KiB
Python
164 lines
6.1 KiB
Python
from dataclasses import dataclass, field
|
|
from typing import List, Optional
|
|
from pydantic import BaseModel, Field as PydanticField
|
|
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_anthropic import ChatAnthropic
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
|
|
from tradingagents.default_config import DEFAULT_CONFIG
|
|
from tradingagents.agents.discovery.models import NewsArticle, EventCategory
|
|
|
|
|
|
BATCH_SIZE = 10
|
|
|
|
|
|
@dataclass
class EntityMention:
    """One company mention pulled out of a single news article.

    Holds the normalized extraction result: the company name as written,
    how confident the extractor is (0.0-1.0), the surrounding text, the
    source article ID, the classified event category, and a sentiment
    score in [-1.0, 1.0] (defaults to neutral).
    """

    company_name: str
    confidence: float
    context_snippet: str
    article_id: str
    event_type: EventCategory
    sentiment: float = 0.0
|
|
|
|
|
|
class ExtractedEntity(BaseModel):
    """Structured-output schema for one company mention returned by the LLM.

    The Field descriptions below are part of the schema handed to the model
    via with_structured_output(), so they double as extraction instructions.
    """

    # Company name exactly as referenced in the article text.
    company_name: str = PydanticField(description="The name of the publicly traded company mentioned")
    # NOTE(review): range is only enforced downstream (clamped in _extract_batch).
    confidence: float = PydanticField(description="Confidence score from 0.0 to 1.0 based on mention clarity")
    # Surrounding text; downstream truncates anything over 150 chars.
    context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention")
    # Free-text category; normalized against EventCategory, falling back to "other".
    event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other")
    # Clamped downstream to [-1.0, 1.0]; defaults to neutral.
    sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)")
    # Ties the mention back to the article IDs built by _format_articles_for_prompt.
    article_id: str = PydanticField(description="The article ID where this company was mentioned (e.g., article_0, article_1)")
|
|
|
|
|
|
class ExtractionResponse(BaseModel):
    """Top-level structured-output envelope: all entities found in one batch."""

    # Defaults to empty so a batch with no mentions parses cleanly.
    entities: List[ExtractedEntity] = PydanticField(default_factory=list, description="List of extracted company entities")
|
|
|
|
|
|
def _get_llm(config: Optional[dict] = None):
    """Build the chat-model client for the configured provider.

    Falls back to DEFAULT_CONFIG when *config* is None. Reads
    "llm_provider", "quick_think_llm", and "backend_url" from the config.

    Raises:
        ValueError: for providers other than openai/ollama/openrouter/
            anthropic/google.
    """
    settings = config or DEFAULT_CONFIG
    provider = settings.get("llm_provider", "openai").lower()
    model_name = settings.get("quick_think_llm", "gpt-4o-mini")
    base_url = settings.get("backend_url", "https://api.openai.com/v1")

    # Ollama and OpenRouter speak the OpenAI-compatible API, so all three
    # share the ChatOpenAI client pointed at the configured backend.
    if provider in ("openai", "ollama", "openrouter"):
        return ChatOpenAI(model=model_name, base_url=base_url)
    if provider == "anthropic":
        return ChatAnthropic(model=model_name, base_url=base_url)
    if provider == "google":
        return ChatGoogleGenerativeAI(model=model_name)
    raise ValueError(f"Unsupported LLM provider: {provider}")
|
|
|
|
|
|
EXTRACTION_PROMPT = """You are an expert at identifying publicly traded companies mentioned in news articles.
|
|
|
|
For each article provided, extract all mentions of publicly traded companies. For each company mention:
|
|
|
|
1. Extract the company name as it appears (e.g., "Apple Inc.", "Apple", "AAPL", "the iPhone maker")
|
|
2. Assign a confidence score from 0.0 to 1.0 based on how clearly the company is mentioned:
|
|
- 0.9-1.0: Direct company name or ticker symbol
|
|
- 0.7-0.9: Clear reference with context (e.g., "the Cupertino tech giant")
|
|
- 0.5-0.7: Indirect reference requiring inference
|
|
- Below 0.5: Uncertain or ambiguous reference
|
|
3. Extract 50-100 characters of surrounding context
|
|
4. Classify the event type:
|
|
- earnings: Quarterly/annual earnings reports, revenue announcements
|
|
- merger_acquisition: Mergers, acquisitions, buyouts, takeovers
|
|
- regulatory: SEC filings, government investigations, compliance issues
|
|
- product_launch: New products, services, or features
|
|
- executive_change: CEO/CFO changes, board appointments, departures
|
|
- other: Any other business news
|
|
5. Assign a sentiment score from -1.0 to 1.0:
|
|
- -1.0: Very negative news (lawsuits, crashes, major failures)
|
|
- -0.5: Moderately negative news
|
|
- 0.0: Neutral news
|
|
- 0.5: Moderately positive news
|
|
- 1.0: Very positive news (breakthroughs, record earnings)
|
|
6. Include the article_id (e.g., article_0, article_1) where the company was mentioned
|
|
|
|
Only extract companies that are publicly traded on major stock exchanges.
|
|
Handle name variations by providing the most complete company name found.
|
|
IMPORTANT: Each entity must include the article_id from which it was extracted.
|
|
|
|
Articles to analyze:
|
|
{articles_text}
|
|
|
|
Extract all company mentions from the articles above."""
|
|
|
|
|
|
def _format_articles_for_prompt(articles: List[NewsArticle], start_idx: int) -> str:
    """Render a batch of articles as "---"-separated [article_N] sections.

    IDs continue from *start_idx* so they remain globally unique across
    batches and match the article_id values the LLM is asked to return.
    """
    sections = [
        f"[article_{idx}]\n"
        f"Title: {article.title}\n"
        f"Source: {article.source}\n"
        f"Content: {article.content_snippet}\n"
        for idx, article in enumerate(articles, start=start_idx)
    ]
    return "\n---\n".join(sections)
|
|
|
|
|
|
def _extract_batch(
    articles: List[NewsArticle],
    start_idx: int,
    llm,
) -> List[EntityMention]:
    """Run one LLM extraction pass over a batch of articles.

    Args:
        articles: The batch to scan (at most BATCH_SIZE entries).
        start_idx: Global index of the batch's first article, used so the
            article_N IDs in the prompt match those in the response.
        llm: A LangChain chat model supporting with_structured_output().

    Returns:
        Normalized EntityMention objects: confidence clamped to [0.0, 1.0],
        sentiment clamped to [-1.0, 1.0], unknown event types mapped to
        "other", and context snippets truncated to at most 150 characters.
    """
    if not articles:
        return []

    articles_text = _format_articles_for_prompt(articles, start_idx)
    prompt = EXTRACTION_PROMPT.format(articles_text=articles_text)

    structured_llm = llm.with_structured_output(ExtractionResponse)
    response = structured_llm.invoke(prompt)

    # Fix: this set is loop-invariant; the original rebuilt it once per
    # entity. Compute it a single time before iterating.
    valid_event_types = {e.value for e in EventCategory}

    mentions = []
    for entity in response.entities:
        # The model returns event_type as free text; normalize it and fall
        # back to "other" for anything outside the known categories.
        event_type_str = entity.event_type.lower().strip()
        if event_type_str not in valid_event_types:
            event_type_str = "other"

        # Clamp scores to their documented ranges in case the model drifts.
        confidence = max(0.0, min(1.0, entity.confidence))
        sentiment = max(-1.0, min(1.0, entity.sentiment))

        # Keep context bounded; truncate with an ellipsis at 150 chars.
        context = entity.context_snippet
        if len(context) > 150:
            context = context[:147] + "..."

        # Fall back to the batch's first article ID if the model omitted it.
        article_id = entity.article_id if entity.article_id else f"article_{start_idx}"

        mentions.append(
            EntityMention(
                company_name=entity.company_name,
                confidence=confidence,
                context_snippet=context,
                article_id=article_id,
                event_type=EventCategory(event_type_str),
                sentiment=sentiment,
            )
        )

    return mentions
|
|
|
|
|
|
def extract_entities(
    articles: List[NewsArticle],
    config: Optional[dict] = None,
) -> List[EntityMention]:
    """Extract company mentions from *articles*, batching LLM calls.

    Articles are processed in BATCH_SIZE chunks against a single LLM client
    built from *config* (DEFAULT_CONFIG when omitted). Returns an empty
    list for empty input without creating a client.
    """
    if not articles:
        return []

    llm = _get_llm(config)
    collected: List[EntityMention] = []

    # Slicing past the end of the list is safe, so no explicit upper bound
    # is needed for the final (possibly short) batch.
    for offset in range(0, len(articles), BATCH_SIZE):
        chunk = articles[offset:offset + BATCH_SIZE]
        collected.extend(_extract_batch(chunk, offset, llm))

    return collected
|