Improve decision schema and KRX data routing

This commit is contained in:
nornen0202 2026-04-08 03:46:25 +09:00
parent 69da5f0ed1
commit 682f865d13
46 changed files with 2397 additions and 607 deletions

33
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
# Continuous-integration workflow: run the project's unit-test suite on every
# pull request and on pushes to main or any codex/* working branch.
name: CI
on:
  pull_request:
  push:
    branches:
      - main
      - "codex/**"
# Least-privilege token: the test job only needs read access to the repo.
permissions:
  contents: read
jobs:
  test:
    runs-on: ubuntu-latest
    timeout-minutes: 30
    steps:
      - name: Check out repository
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.13"
      # Editable install so the package under tradingagents/ is importable by tests.
      - name: Install TradingAgents
        run: |
          python -m pip install --upgrade pip
          python -m pip install -e .
      - name: Run unit tests
        run: python -m unittest discover -s tests

View File

@@ -43,6 +43,10 @@ jobs:
TRADINGAGENTS_ARCHIVE_DIR: ${{ vars.TRADINGAGENTS_ARCHIVE_DIR }}
CODEX_BINARY: ${{ vars.CODEX_BINARY }}
CODEX_HOME: ${{ vars.CODEX_HOME }}
ALPHA_VANTAGE_API_KEY: ${{ secrets.ALPHA_VANTAGE_API_KEY }}
NAVER_CLIENT_ID: ${{ secrets.NAVER_CLIENT_ID }}
NAVER_CLIENT_SECRET: ${{ secrets.NAVER_CLIENT_SECRET }}
OPENDART_API_KEY: ${{ secrets.OPENDART_API_KEY }}
steps:
- name: Check out repository
uses: actions/checkout@v4

View File

@@ -0,0 +1,31 @@
import unittest
from copy import deepcopy
from unittest.mock import Mock, patch
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
class _DummyClient:
    """Minimal LLM-client stand-in that always hands back one shared mock."""

    def __init__(self):
        # A single Mock created up front so repeated get_llm() calls agree.
        self._llm = Mock()

    def get_llm(self):
        """Return the mock LLM instance created at construction time."""
        return self._llm
class GraphConfigurationTests(unittest.TestCase):
    """Verify that graph-level configuration values reach the propagator."""

    @patch("tradingagents.graph.trading_graph.GraphSetup.setup_graph", return_value=Mock())
    @patch("tradingagents.graph.trading_graph.create_llm_client", return_value=_DummyClient())
    def test_max_recur_limit_propagates_to_graph_args(self, *_mocks):
        # Work on a private copy so the shared default config is never mutated.
        config = deepcopy(DEFAULT_CONFIG)
        config["max_recur_limit"] = 321
        graph = TradingAgentsGraph(config=config, selected_analysts=["market"])
        # The limit must surface both as an attribute and inside the graph args.
        self.assertEqual(graph.propagator.max_recur_limit, 321)
        graph_args = graph.propagator.get_graph_args()
        self.assertEqual(graph_args["config"]["recursion_limit"], 321)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,42 @@
import unittest
from tradingagents.agents.utils.instrument_resolver import resolve_instrument
from tradingagents.graph.propagation import Propagator
class InstrumentResolverTests(unittest.TestCase):
    """Exercise symbol and name resolution for US and KRX instruments."""

    def _assert_primary(self, query, expected_symbol):
        # Shared helper: resolve a query and compare only the primary symbol.
        self.assertEqual(resolve_instrument(query).primary_symbol, expected_symbol)

    def test_resolves_us_symbol(self):
        profile = resolve_instrument("AAPL")
        self.assertEqual(profile.primary_symbol, "AAPL")
        self.assertEqual(profile.country, "US")

    def test_resolves_exchange_qualified_krx_symbol(self):
        profile = resolve_instrument("005930.KS")
        self.assertEqual(profile.primary_symbol, "005930.KS")
        self.assertEqual(profile.country, "KR")

    def test_resolves_numeric_krx_code(self):
        self._assert_primary("005930", "005930.KS")

    def test_resolves_korean_company_name(self):
        self._assert_primary("삼성전자", "005930.KS")

    def test_resolves_known_krx_english_name(self):
        self._assert_primary("NAVER", "035420.KS")

    def test_resolves_known_krx_numeric_code(self):
        self._assert_primary("035420", "035420.KS")

    def test_propagator_normalizes_instrument_into_state(self):
        # The propagator should both normalize the symbol and preserve the
        # original user input alongside the resolved instrument profile.
        state = Propagator().create_initial_state("삼성전자", "2026-01-15")
        self.assertEqual(state["company_of_interest"], "005930.KS")
        self.assertEqual(state["input_instrument"], "삼성전자")
        self.assertEqual(state["instrument_profile"]["country"], "KR")


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,47 @@
import unittest
from tradingagents.dataflows.alpha_vantage_news import normalize_alpha_vantage_article
from tradingagents.dataflows.yfinance_news import normalize_yfinance_article
class NewsNormalizationTests(unittest.TestCase):
    """Check that vendor-specific article payloads normalize consistently."""

    def test_yfinance_article_normalizes_to_news_item(self):
        # yfinance nests everything under a "content" envelope.
        payload = {
            "content": {
                "title": "Samsung wins order",
                "summary": "Large customer order announced.",
                "provider": {"displayName": "Unit Test"},
                "canonicalUrl": {"url": "https://example.com/article"},
                "pubDate": "2026-04-08T09:00:00Z",
                "relatedTickers": ["005930.KS"],
            }
        }
        item = normalize_yfinance_article(payload, fallback_symbol="005930.KS")
        self.assertEqual(item.raw_vendor, "yfinance")
        self.assertEqual(item.title, "Samsung wins order")
        self.assertIn("005930.KS", item.symbols)

    def test_alpha_vantage_article_normalizes_to_news_item(self):
        # Alpha Vantage returns a flat dict with string-typed sentiment.
        payload = {
            "title": "Apple demand improves",
            "summary": "Demand commentary improved after launch.",
            "source": "Alpha Source",
            "url": "https://example.com/alpha",
            "time_published": "20260408T090000",
            "ticker_sentiment": [{"ticker": "AAPL"}],
            "topics": [{"topic": "earnings"}],
            "overall_sentiment_score": "0.25",
        }
        item = normalize_alpha_vantage_article(payload, fallback_symbol="AAPL")
        self.assertEqual(item.raw_vendor, "alpha_vantage")
        self.assertEqual(item.title, "Apple demand improves")
        # The string score must be coerced to a float during normalization.
        self.assertEqual(item.sentiment, 0.25)
        self.assertIn("AAPL", item.symbols)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,25 @@
import unittest
from pathlib import Path
# Repository root, resolved relative to this test file's location.
ROOT = Path(__file__).resolve().parents[1]


class PromptToolConsistencyTests(unittest.TestCase):
    """Guard against prompts advertising tool signatures the code lacks."""

    @staticmethod
    def _analyst_source(filename):
        # Read an analyst module's source text relative to the repo root.
        path = ROOT / "tradingagents" / "agents" / "analysts" / filename
        return path.read_text(encoding="utf-8")

    def test_social_prompt_matches_real_tool_signatures(self):
        text = self._analyst_source("social_media_analyst.py")
        self.assertIn("get_social_sentiment(symbol, start_date, end_date)", text)
        self.assertIn("get_company_news(symbol, start_date, end_date)", text)
        # The retired generic-search tool must no longer be referenced.
        self.assertNotIn("get_news(query, start_date, end_date)", text)
        self.assertNotIn("social media posts", text.lower())

    def test_news_prompt_matches_real_tool_signatures(self):
        text = self._analyst_source("news_analyst.py")
        self.assertIn("get_company_news(symbol, start_date, end_date)", text)
        self.assertIn("get_macro_news(curr_date, look_back_days, limit, region, language)", text)
        self.assertIn("get_disclosures(symbol, start_date, end_date)", text)
        self.assertNotIn("get_news(query, start_date, end_date)", text)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,16 @@
import unittest
from unittest.mock import patch
from tradingagents.dataflows.yfinance_news import get_social_sentiment_yfinance
class SocialSentimentTests(unittest.TestCase):
    """Behavior of the sentiment tool when no news articles come back."""

    @patch(
        "tradingagents.dataflows.yfinance_news.fetch_company_news_yfinance",
        return_value=([], None, None),
    )
    def test_social_sentiment_reports_news_derived_fallback_when_empty(self, _mock_fetch):
        report = get_social_sentiment_yfinance("AAPL", "2026-04-01", "2026-04-02")
        # The tool must be explicit that it degraded to news-derived sentiment.
        self.assertIn("Dedicated social provider unavailable", report)
        self.assertIn("news-derived sentiment", report)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,35 @@
import unittest
from tradingagents.graph.signal_processing import SignalProcessor
from tradingagents.schemas import StructuredDecisionValidationError, parse_structured_decision
class StructuredDecisionTests(unittest.TestCase):
    """Deterministic parsing and strict validation of decision JSON."""

    def test_valid_schema_parses_deterministically(self):
        payload = """
        {
            "rating": "BUY",
            "confidence": 0.78,
            "time_horizon": "medium",
            "entry_logic": "Buy on pullbacks above support.",
            "exit_logic": "Exit on a break below support.",
            "position_sizing": "Half position.",
            "risk_limits": "Risk 1% of capital.",
            "catalysts": ["Earnings beat"],
            "invalidators": ["Guidance cut"]
        }
        """
        decision = parse_structured_decision(payload)
        # Both the schema parser and the signal processor must agree on BUY.
        self.assertEqual(decision.rating.value, "BUY")
        processor = SignalProcessor(None)
        self.assertEqual(processor.process_signal(payload), "BUY")

    def test_invalid_schema_raises_validation_error(self):
        # Missing the mandatory "rating" field must be a hard failure.
        payload = '{"confidence": 0.5, "time_horizon": "short"}'
        with self.assertRaises(StructuredDecisionValidationError):
            parse_structured_decision(payload)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,92 @@
import unittest
from unittest.mock import patch
from tradingagents.dataflows.alpha_vantage_common import AlphaVantageRateLimitError
from tradingagents.dataflows.interface import route_to_vendor
from tradingagents.dataflows.vendor_exceptions import VendorConfigurationError, VendorInputError, VendorMalformedResponseError
class VendorFallbackTests(unittest.TestCase):
    """Routing behavior: fallback across vendors and hard input failures."""

    @staticmethod
    def _raising(exc):
        # Build a vendor stub that raises the given exception on every call.
        def _stub(*_args, **_kwargs):
            raise exc

        return _stub

    def _route_with(self, vendor_order, stubs, method, *call_args):
        # Patch the configured vendor order and the method table, then route.
        with patch("tradingagents.dataflows.interface.get_vendor", return_value=vendor_order), patch.dict(
            "tradingagents.dataflows.interface.VENDOR_METHODS",
            {method: stubs},
            clear=False,
        ):
            return route_to_vendor(method, *call_args)

    def test_rate_limit_falls_back_to_next_vendor(self):
        result = self._route_with(
            "alpha_vantage,yfinance",
            {
                "alpha_vantage": self._raising(AlphaVantageRateLimitError("rate")),
                "yfinance": lambda *_args, **_kwargs: "yfinance result",
            },
            "get_company_news",
            "AAPL",
            "2026-04-01",
            "2026-04-02",
        )
        self.assertEqual(result, "yfinance result")

    def test_generic_exception_falls_back_to_next_vendor(self):
        result = self._route_with(
            "alpha_vantage,yfinance",
            {
                "alpha_vantage": self._raising(RuntimeError("boom")),
                "yfinance": lambda *_args, **_kwargs: "fallback result",
            },
            "get_company_news",
            "AAPL",
            "2026-04-01",
            "2026-04-02",
        )
        self.assertEqual(result, "fallback result")

    def test_empty_result_falls_back_to_next_vendor(self):
        # A "no results" sentinel string should be treated like a miss.
        result = self._route_with(
            "alpha_vantage,yfinance",
            {
                "alpha_vantage": lambda *_args, **_kwargs: "No news found for AAPL",
                "yfinance": lambda *_args, **_kwargs: "usable result",
            },
            "get_company_news",
            "AAPL",
            "2026-04-01",
            "2026-04-02",
        )
        self.assertEqual(result, "usable result")

    def test_malformed_payload_falls_back_to_next_vendor(self):
        result = self._route_with(
            "alpha_vantage,yfinance",
            {
                "alpha_vantage": self._raising(VendorMalformedResponseError("bad payload")),
                "yfinance": lambda *_args, **_kwargs: "usable result",
            },
            "get_company_news",
            "AAPL",
            "2026-04-01",
            "2026-04-02",
        )
        self.assertEqual(result, "usable result")

    def test_invalid_user_input_raises_without_fallback(self):
        # Malformed dates are the caller's bug; no vendor fallback applies.
        with self.assertRaises(VendorInputError):
            route_to_vendor("get_company_news", "AAPL", "2026/04/01", "2026-04-02")

    def test_social_sentiment_degrades_gracefully_when_primary_provider_missing(self):
        result = self._route_with(
            "naver,yfinance",
            {
                "naver": self._raising(VendorConfigurationError("missing naver key")),
                "yfinance": lambda *_args, **_kwargs: "Dedicated social provider unavailable; using news-derived sentiment for AAPL.",
            },
            "get_social_sentiment",
            "AAPL",
            "2026-04-01",
            "2026-04-02",
        )
        self.assertIn("Dedicated social provider unavailable", result)
        self.assertIn("news-derived sentiment", result)


if __name__ == "__main__":
    unittest.main()

View File

@ -1,4 +1,5 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_balance_sheet,
@ -8,26 +9,31 @@ from tradingagents.agents.utils.agent_utils import (
get_insider_transactions,
get_language_instruction,
)
from tradingagents.dataflows.config import get_config
def create_fundamentals_analyst(llm):
def fundamentals_analyst_node(state):
current_date = state["trade_date"]
instrument_context = build_instrument_context(state["company_of_interest"])
instrument_context = build_instrument_context(
state["company_of_interest"],
state.get("instrument_profile"),
)
tools = [
get_fundamentals,
get_balance_sheet,
get_cashflow,
get_income_statement,
get_insider_transactions,
]
system_message = (
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."
+ " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements."
+ get_language_instruction(),
"You are a fundamentals analyst focused on medium-term business quality and event risk. "
"Center the report on recent disclosures, earnings quality, guidance changes, capital structure, cash flow, margins, insider transactions, and any notable balance-sheet shifts. "
"Use `get_fundamentals(ticker, curr_date)` for the overview, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for statement detail, and `get_insider_transactions(ticker)` for insider activity. "
"Do not frame this as only a past-week exercise; emphasize the latest reported fundamentals and the most recent event-driven changes that matter for traders."
" End with a Markdown table summarizing the main fundamental strengths, weaknesses, and watch items."
+ get_language_instruction()
)
prompt = ChatPromptTemplate.from_messages(
@ -38,10 +44,9 @@ def create_fundamentals_analyst(llm):
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" Return the completed fundamentals report directly once you have enough evidence."
" You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. {instrument_context}",
" For your reference, the current date is {current_date}. {instrument_context}",
),
MessagesPlaceholder(variable_name="messages"),
]
@ -53,11 +58,9 @@ def create_fundamentals_analyst(llm):
prompt = prompt.partial(instrument_context=instrument_context)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content

View File

@ -12,7 +12,10 @@ def create_market_analyst(llm):
def market_analyst_node(state):
current_date = state["trade_date"]
instrument_context = build_instrument_context(state["company_of_interest"])
instrument_context = build_instrument_context(
state["company_of_interest"],
state.get("instrument_profile"),
)
tools = [
get_stock_data,
@ -20,7 +23,7 @@ def create_market_analyst(llm):
]
system_message = (
"""You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
"""You are a trading assistant tasked with analyzing market regime first and indicators second. Start by classifying the regime as trending up, trending down, range-bound, high-volatility, or event-driven. Then select the **most relevant indicators** for that regime from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
Moving Averages:
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
@ -44,7 +47,7 @@ Volatility Indicators:
Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."""
- Select indicators that provide diverse and complementary information. Avoid redundancy. Also briefly explain why they are suitable for the detected regime and connect them to volatility, liquidity, and event risk where possible. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a detailed report grounded in observable features rather than vague narrative. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."""
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
+ get_language_instruction()
)
@ -57,8 +60,7 @@ Volume-Based Indicators:
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" Return the completed market report directly once you have enough evidence."
" You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. {instrument_context}",
),

View File

@ -1,26 +1,37 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_global_news,
get_company_news,
get_disclosures,
get_language_instruction,
get_news,
get_macro_news,
)
from tradingagents.dataflows.config import get_config
def create_news_analyst(llm):
def news_analyst_node(state):
current_date = state.get("analysis_date") or state["trade_date"]
instrument_context = build_instrument_context(state["company_of_interest"])
instrument_context = build_instrument_context(
state["company_of_interest"],
state.get("instrument_profile"),
)
tools = [
get_news,
get_global_news,
get_company_news,
get_macro_news,
get_disclosures,
]
system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
"You are a news and event analyst. "
"Build the report from three evidence blocks: company news, macro news, and disclosures. "
"Use `get_company_news(symbol, start_date, end_date)` for company-specific coverage, "
"`get_macro_news(curr_date, look_back_days, limit, region, language)` for broader market context, "
"and `get_disclosures(symbol, start_date, end_date)` for filing or disclosure events when available. "
"Do not describe unsupported tool signatures or imaginary search capabilities. "
"Present 3 to 5 key events with event type, source, why it matters, bullish implication, bearish implication, and confidence. "
"Finish with a concise Markdown table summarizing the evidence."
+ get_language_instruction()
)
@ -32,10 +43,9 @@ def create_news_analyst(llm):
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" Return the completed news report directly once you have enough evidence."
" You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. {instrument_context}",
" For your reference, the current date is {current_date}. {instrument_context}",
),
MessagesPlaceholder(variable_name="messages"),
]
@ -50,7 +60,6 @@ def create_news_analyst(llm):
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content

View File

@ -1,20 +1,33 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
from tradingagents.dataflows.config import get_config
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_company_news,
get_language_instruction,
get_social_sentiment,
)
def create_social_media_analyst(llm):
def social_media_analyst_node(state):
current_date = state.get("analysis_date") or state["trade_date"]
instrument_context = build_instrument_context(state["company_of_interest"])
instrument_context = build_instrument_context(
state["company_of_interest"],
state.get("instrument_profile"),
)
tools = [
get_news,
get_social_sentiment,
get_company_news,
]
system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
"You are a company sentiment analyst. "
"Your job is to assess public narrative, sentiment, and crowd positioning around the company without claiming direct social-media coverage unless a tool explicitly provides it. "
"Use `get_social_sentiment(symbol, start_date, end_date)` for dedicated or clearly labeled news-derived sentiment context, and `get_company_news(symbol, start_date, end_date)` for direct company-news evidence. "
"If the sentiment tool says a dedicated social provider is unavailable, explicitly state that you are working from news-derived sentiment instead of pretending you saw social posts. "
"Write a detailed report covering sentiment drivers, tone shifts, narrative concentration, what is improving, what is deteriorating, and the main trading implications."
" End with a Markdown table that summarizes key signals, evidence, and confidence."
+ get_language_instruction()
)
@ -26,10 +39,9 @@ def create_social_media_analyst(llm):
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" Return the completed sentiment report directly once you have enough evidence."
" You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. {instrument_context}",
" For your reference, the current date is {current_date}. {instrument_context}",
),
MessagesPlaceholder(variable_name="messages"),
]
@ -41,11 +53,9 @@ def create_social_media_analyst(llm):
prompt = prompt.partial(instrument_context=instrument_context)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content

View File

@ -1,10 +1,17 @@
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_language_instruction,
get_memory_matches,
)
from tradingagents.schemas import build_decision_output_instructions, ensure_structured_decision_json
def create_portfolio_manager(llm, memory):
def portfolio_manager_node(state) -> dict:
instrument_context = build_instrument_context(state["company_of_interest"])
instrument_context = build_instrument_context(
state["company_of_interest"],
state.get("instrument_profile"),
)
history = state["risk_debate_state"]["history"]
risk_debate_state = state["risk_debate_state"]
@ -16,48 +23,35 @@ def create_portfolio_manager(llm, memory):
trader_plan = state["trader_investment_plan"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memories = get_memory_matches(memory, curr_situation)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision.
{instrument_context}
---
Use the common decision schema and be explicit about rating, confidence, time horizon, entry logic, exit logic, position sizing, risk limits, catalysts, and invalidators.
NO_TRADE is allowed and should be used whenever the setup is not compelling enough to allocate capital.
**Rating Scale** (use exactly one):
- **Buy**: Strong conviction to enter or add to position
- **Overweight**: Favorable outlook, gradually increase exposure
- **Hold**: Maintain current position, no action needed
- **Underweight**: Reduce exposure, take partial profits
- **Sell**: Exit position or avoid entry
Context:
- Research Manager investment plan JSON: {research_plan}
- Trader execution plan JSON: {trader_plan}
- Lessons from past decisions: {past_memory_str or "No past reflections available."}
**Context:**
- Research Manager's investment plan: **{research_plan}**
- Trader's transaction proposal: **{trader_plan}**
- Lessons from past decisions: **{past_memory_str}**
**Required Output Structure:**
1. **Rating**: State one of Buy / Overweight / Hold / Underweight / Sell.
2. **Executive Summary**: A concise action plan covering entry strategy, position sizing, key risk levels, and time horizon.
3. **Investment Thesis**: Detailed reasoning anchored in the analysts' debate and past reflections.
---
**Risk Analysts Debate History:**
Risk Analysts Debate History:
{history}
---
Be decisive and ground every conclusion in specific evidence from the analysts.{get_language_instruction()}"""
Ground every conclusion in specific evidence from the analysts. {get_language_instruction()}
{build_decision_output_instructions("portfolio manager final decision")}"""
response = llm.invoke(prompt)
decision_json = ensure_structured_decision_json(response.content)
new_risk_debate_state = {
"judge_decision": response.content,
"judge_decision": decision_json,
"history": risk_debate_state["history"],
"aggressive_history": risk_debate_state["aggressive_history"],
"conservative_history": risk_debate_state["conservative_history"],
@ -71,7 +65,7 @@ Be decisive and ground every conclusion in specific evidence from the analysts.{
return {
"risk_debate_state": new_risk_debate_state,
"final_trade_decision": response.content,
"final_trade_decision": decision_json,
}
return portfolio_manager_node

View File

@ -1,58 +1,72 @@
from tradingagents.agents.utils.agent_utils import build_instrument_context
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_memory_matches
from tradingagents.schemas import build_decision_output_instructions, ensure_structured_decision_json
def create_research_manager(llm, memory):
def research_manager_node(state) -> dict:
instrument_context = build_instrument_context(state["company_of_interest"])
instrument_context = build_instrument_context(
state["company_of_interest"],
state.get("instrument_profile"),
)
history = state["investment_debate_state"].get("history", "")
market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"]
news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"]
investment_debate_state = state["investment_debate_state"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memories = get_memory_matches(memory, curr_situation)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendationBuy, Sell, or Holdmust be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
Additionally, develop a detailed investment plan for the trader. This should include:
Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes:
\"{past_memory_str}\"
prompt = f"""As the research manager and evidence arbiter, critically evaluate the bull and bear debate and produce a structured investment view for the trader.
{instrument_context}
Here is the debate:
Your job:
- weigh the strongest bullish and bearish evidence
- reduce action bias; use NO_TRADE when the evidence is too weak or conflicted
- focus on evidence arbitration rather than rhetorical style
- make the catalysts and invalidators explicit
Use these objective reports for grounding:
Market Report:
{market_research_report}
Sentiment Report:
{sentiment_report}
News Report:
{news_report}
Fundamentals Report:
{fundamentals_report}
Debate History:
{history}"""
{history}
Lessons from past mistakes:
{past_memory_str or "No past reflections available."}
{build_decision_output_instructions("research manager investment plan")}"""
response = llm.invoke(prompt)
decision_json = ensure_structured_decision_json(response.content)
new_investment_debate_state = {
"judge_decision": response.content,
"judge_decision": decision_json,
"history": investment_debate_state.get("history", ""),
"bear_history": investment_debate_state.get("bear_history", ""),
"bull_history": investment_debate_state.get("bull_history", ""),
"current_response": response.content,
"current_response": decision_json,
"count": investment_debate_state["count"],
}
return {
"investment_debate_state": new_investment_debate_state,
"investment_plan": response.content,
"investment_plan": decision_json,
}
return research_manager_node

View File

@ -1,3 +1,4 @@
from tradingagents.agents.utils.agent_utils import get_memory_matches
def create_bear_researcher(llm, memory):
@ -13,7 +14,7 @@ def create_bear_researcher(llm, memory):
fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memories = get_memory_matches(memory, curr_situation)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):

View File

@ -1,3 +1,4 @@
from tradingagents.agents.utils.agent_utils import get_memory_matches
def create_bull_researcher(llm, memory):
@ -13,7 +14,7 @@ def create_bull_researcher(llm, memory):
fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memories = get_memory_matches(memory, curr_situation)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):

View File

@ -1,12 +1,13 @@
import functools
from tradingagents.agents.utils.agent_utils import build_instrument_context
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_memory_matches
from tradingagents.schemas import build_decision_output_instructions, ensure_structured_decision_json
def create_trader(llm, memory):
def trader_node(state, name):
company_name = state["company_of_interest"]
instrument_context = build_instrument_context(company_name)
instrument_context = build_instrument_context(company_name, state.get("instrument_profile"))
investment_plan = state["investment_plan"]
market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"]
@ -14,33 +15,45 @@ def create_trader(llm, memory):
fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memories = get_memory_matches(memory, curr_situation)
past_memory_str = ""
if past_memories:
for i, rec in enumerate(past_memories, 1):
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
else:
past_memory_str = "No past memories found."
context = {
"role": "user",
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. {instrument_context} This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
"content": (
f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. "
f"{instrument_context} This plan incorporates insights from market trends, macro context, sentiment, news, and fundamentals. "
f"Use this plan as a foundation for your execution decision.\n\nProposed Investment Plan JSON: {investment_plan}\n\n"
"Leverage these insights to make an informed and strategic decision."
),
}
messages = [
{
"role": "system",
"content": f"""You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Apply lessons from past decisions to strengthen your analysis. Here are reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
"content": (
"You are a trading agent analyzing market data to make execution-ready investment decisions. "
"Translate the research manager's view into a concrete trade recommendation with entry logic, exit logic, position sizing, risk limits, catalysts, and invalidators. "
"Use NO_TRADE when the setup is not actionable or lacks a favorable risk/reward. "
f"Apply lessons from similar situations: {past_memory_str} "
f"{build_decision_output_instructions('trader execution plan')}"
),
},
context,
]
result = llm.invoke(messages)
decision_json = ensure_structured_decision_json(result.content)
return {
"messages": [result],
"trader_investment_plan": result.content,
"trader_investment_plan": decision_json,
"sender": name,
}

View File

@ -44,7 +44,9 @@ class RiskDebateState(TypedDict):
class AgentState(MessagesState):
input_instrument: Annotated[str, "Original instrument input provided by the user"]
company_of_interest: Annotated[str, "Company that we are interested in trading"]
instrument_profile: Annotated[dict, "Normalized instrument metadata"]
trade_date: Annotated[str, "What date we are trading at"]
analysis_date: Annotated[str, "What date the full analysis is being generated on"]

View File

@ -15,10 +15,15 @@ from tradingagents.agents.utils.fundamental_data_tools import (
get_income_statement
)
from tradingagents.agents.utils.news_data_tools import (
get_company_news,
get_disclosures,
get_macro_news,
get_news,
get_insider_transactions,
get_global_news
get_global_news,
get_social_sentiment,
)
from tradingagents.agents.utils.instrument_resolver import InstrumentProfile
def get_language_instruction() -> str:
@ -111,14 +116,37 @@ def _normalize_localized_finance_terms(content: str, language: str) -> str:
return normalized
def build_instrument_context(ticker: str) -> str:
def build_instrument_context(
ticker: str,
instrument_profile: dict | InstrumentProfile | None = None,
) -> str:
"""Describe the exact instrument so agents preserve exchange-qualified tickers."""
profile = instrument_profile.to_dict() if isinstance(instrument_profile, InstrumentProfile) else instrument_profile
if not profile:
return (
f"The instrument to analyze is `{ticker}`. "
"Use this exact ticker in every tool call, report, and recommendation, "
"preserving any exchange suffix (e.g. `.TO`, `.L`, `.HK`, `.T`)."
)
display_name = profile.get("display_name") or ticker
primary_symbol = profile.get("primary_symbol") or ticker
exchange = profile.get("exchange") or "unknown exchange"
country = profile.get("country") or "unknown country"
timezone = profile.get("timezone") or "unknown timezone"
currency = profile.get("currency") or "unknown currency"
return (
f"The instrument to analyze is `{ticker}`. "
"Use this exact ticker in every tool call, report, and recommendation, "
"preserving any exchange suffix (e.g. `.TO`, `.L`, `.HK`, `.T`)."
f"The instrument to analyze is `{primary_symbol}` ({display_name}). "
f"It trades on {exchange} in {country}, with market timezone {timezone} and reporting currency {currency}. "
"Use the normalized primary symbol in every tool call, report, and recommendation, "
"preserving any exchange suffix."
)
def get_memory_matches(memory, current_situation: str, n_matches: int | None = None):
"""Retrieve memory matches while respecting configurable defaults."""
return memory.get_memories(current_situation, n_matches=n_matches)
def create_msg_delete():
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""

View File

@ -0,0 +1,103 @@
from __future__ import annotations
import re
from dataclasses import dataclass
class InstrumentResolutionError(ValueError):
    """Raised when a user-provided instrument cannot be normalized.

    Subclasses ``ValueError`` so callers that already catch ``ValueError``
    keep working unchanged.
    """
@dataclass(frozen=True)
class InstrumentProfile:
display_name: str
primary_symbol: str
exchange: str
country: str
timezone: str
currency: str
yahoo_symbol: str | None = None
krx_code: str | None = None
dart_corp_code: str | None = None
def to_dict(self) -> dict[str, str | None]:
return {
"display_name": self.display_name,
"primary_symbol": self.primary_symbol,
"exchange": self.exchange,
"country": self.country,
"timezone": self.timezone,
"currency": self.currency,
"yahoo_symbol": self.yahoo_symbol or self.primary_symbol,
"krx_code": self.krx_code,
"dart_corp_code": self.dart_corp_code,
}
# Maps user-facing aliases (Korean company names, English names, bare 6-digit
# codes, and exchange-qualified symbols) to a
# (display_name, primary_symbol, krx_code) tuple.
# Keys are either Korean text or ASCII uppercase, so lookups on the
# uppercased input cover every entry.
_KRX_ALIAS_MAP = {
    "삼성전자": ("삼성전자", "005930.KS", "005930"),
    "SAMSUNG ELECTRONICS": ("삼성전자", "005930.KS", "005930"),
    "005930": ("삼성전자", "005930.KS", "005930"),
    "005930.KS": ("삼성전자", "005930.KS", "005930"),
    "NAVER": ("NAVER", "035420.KS", "035420"),
    "035420": ("NAVER", "035420.KS", "035420"),
    "035420.KS": ("NAVER", "035420.KS", "035420"),
}
def is_krx_symbol(symbol: str) -> bool:
    """Return True when *symbol* is an exchange-qualified KRX/KOSDAQ ticker.

    Matches six digits followed by a ``.KS`` or ``.KQ`` suffix,
    case-insensitively (e.g. ``005930.KS``, ``035420.kq``).
    """
    normalized = symbol.upper()
    return re.fullmatch(r"\d{6}\.(?:KS|KQ)", normalized) is not None
def resolve_instrument(user_input: str) -> InstrumentProfile:
    """Normalize a user-provided instrument string into an InstrumentProfile.

    Resolution order:
      1. Known KRX aliases (company names, bare codes, qualified symbols).
      2. Exchange-qualified KRX symbols (``######.KS`` / ``######.KQ``).
      3. Bare six-digit codes, defaulted to the ``.KS`` (KOSPI) suffix.
      4. Generic alphanumeric tickers, treated as US-listed.

    Raises:
        InstrumentResolutionError: if the input is empty or unrecognizable.
    """
    raw_value = (user_input or "").strip()
    if not raw_value:
        raise InstrumentResolutionError("Instrument input is empty.")
    upper = raw_value.upper()
    # Alias keys are either Korean text or ASCII uppercase, so a single
    # lookup on the uppercased input covers the raw value as well.
    if upper in _KRX_ALIAS_MAP:
        display_name, symbol, code = _KRX_ALIAS_MAP[upper]
        return _build_krx_profile(display_name, symbol, code)
    if is_krx_symbol(upper):
        code = upper.split(".", 1)[0]
        display_name = _KRX_ALIAS_MAP.get(upper, (code, upper, code))[0]
        return _build_krx_profile(display_name, upper, code)
    # Use [0-9] instead of \d: \d also matches full-width Unicode digits,
    # which would otherwise produce a bogus ".KS" symbol.
    if re.fullmatch(r"[0-9]{6}", raw_value):
        symbol = f"{raw_value}.KS"
        display_name = _KRX_ALIAS_MAP.get(raw_value, (raw_value, symbol, raw_value))[0]
        return _build_krx_profile(display_name, symbol, raw_value)
    if re.fullmatch(r"[A-Za-z][A-Za-z0-9.\-]{0,14}", raw_value):
        return InstrumentProfile(
            display_name=upper,
            primary_symbol=upper,
            exchange="US",
            country="US",
            timezone="US/Eastern",
            currency="USD",
            yahoo_symbol=upper,
        )
    raise InstrumentResolutionError(
        f"Could not resolve instrument '{user_input}'. Pass an exchange-qualified ticker or a known company name/code."
    )
def _build_krx_profile(display_name: str, primary_symbol: str, krx_code: str) -> InstrumentProfile:
    """Create a Korean-market InstrumentProfile for the given symbol/code pair."""
    # A ".KQ" suffix marks KOSDAQ listings; everything else is reported as KRX.
    is_kosdaq = primary_symbol.endswith(".KQ")
    return InstrumentProfile(
        display_name=display_name,
        primary_symbol=primary_symbol,
        exchange="KOSDAQ" if is_kosdaq else "KRX",
        country="KR",
        timezone="Asia/Seoul",
        currency="KRW",
        yahoo_symbol=primary_symbol,
        krx_code=krx_code,
    )

View File

@ -1,144 +1,120 @@
"""Financial situation memory using BM25 for lexical similarity matching.
"""Financial situation memory using hybrid BM25 plus regime-tag retrieval."""
Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
no token limits, works offline with any LLM provider.
"""
from __future__ import annotations
from rank_bm25 import BM25Okapi
from typing import List, Tuple
from typing import Any, List, Tuple
import re
class FinancialSituationMemory:
"""Memory system for storing and retrieving financial situations using BM25."""
"""Memory system for storing and retrieving financial situations."""
def __init__(self, name: str, config: dict = None):
"""Initialize the memory system.
Args:
name: Name identifier for this memory instance
config: Configuration dict (kept for API compatibility, not used for BM25)
"""
def __init__(self, name: str, config: dict | None = None):
self.name = name
self.config = config or {}
self.documents: List[str] = []
self.recommendations: List[str] = []
self.metadata: List[dict[str, Any]] = []
self.bm25 = None
self.default_n_matches = int(self.config.get("memory_n_matches", 2))
def _tokenize(self, text: str) -> List[str]:
"""Tokenize text for BM25 indexing.
return re.findall(r"\b\w+\b", text.lower())
Simple whitespace + punctuation tokenization with lowercasing.
"""
# Lowercase and split on non-alphanumeric characters
tokens = re.findall(r'\b\w+\b', text.lower())
return tokens
def _extract_regime_tags(self, text: str) -> set[str]:
lowered = text.lower()
tags: set[str] = set()
keyword_groups = {
"volatility": ("volatility", "atr", "drawdown", "swing", "high-volatility"),
"trend_up": ("uptrend", "trending up", "breakout", "bullish", "momentum"),
"trend_down": ("downtrend", "trending down", "selloff", "bearish", "breakdown"),
"range_bound": ("range-bound", "sideways", "consolidation", "choppy"),
"rates": ("interest rate", "fed", "fomc", "yield", "monetary"),
"earnings": ("earnings", "guidance", "quarter", "revenue", "eps"),
"insider": ("insider", "buyback", "share issuance"),
"kr": ("krx", ".ks", ".kq", "korea", "한국", "", "krw"),
"us": ("nasdaq", "nyse", "usd", "federal reserve", "u.s.", "us/eastern"),
"sentiment": ("sentiment", "narrative", "social", "headline"),
"macro": ("inflation", "cpi", "gdp", "macro", "employment"),
}
for tag, keywords in keyword_groups.items():
if any(keyword in lowered for keyword in keywords):
tags.add(tag)
return tags
def _rebuild_index(self):
"""Rebuild the BM25 index after adding documents."""
if self.documents:
tokenized_docs = [self._tokenize(doc) for doc in self.documents]
self.bm25 = BM25Okapi(tokenized_docs)
else:
self.bm25 = None
def add_situations(self, situations_and_advice: List[Tuple[str, str]]):
"""Add financial situations and their corresponding advice.
def add_situations(self, situations_and_advice: List[Tuple]):
for item in situations_and_advice:
if len(item) == 2:
situation, recommendation = item
metadata = {}
elif len(item) == 3:
situation, recommendation, metadata = item
else:
raise ValueError("Each memory entry must be (situation, recommendation) or (situation, recommendation, metadata).")
Args:
situations_and_advice: List of tuples (situation, recommendation)
"""
for situation, recommendation in situations_and_advice:
self.documents.append(situation)
self.recommendations.append(recommendation)
combined_metadata = dict(metadata or {})
combined_metadata.setdefault("regime_tags", sorted(self._extract_regime_tags(str(situation))))
self.documents.append(str(situation))
self.recommendations.append(str(recommendation))
self.metadata.append(combined_metadata)
# Rebuild BM25 index with new documents
self._rebuild_index()
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
"""Find matching recommendations using BM25 similarity.
Args:
current_situation: The current financial situation to match against
n_matches: Number of top matches to return
Returns:
List of dicts with matched_situation, recommendation, and similarity_score
"""
def get_memories(
self,
current_situation: str,
n_matches: int | None = None,
metadata_filters: dict[str, Any] | None = None,
) -> List[dict]:
if not self.documents or self.bm25 is None:
return []
# Tokenize query
limit = n_matches if n_matches is not None else self.default_n_matches
query_tokens = self._tokenize(current_situation)
# Get BM25 scores for all documents
query_tags = self._extract_regime_tags(current_situation)
scores = self.bm25.get_scores(query_tokens)
max_score = max(scores) if max(scores) > 0 else 1
# Get top-n indices sorted by score (descending)
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_matches]
ranked_results = []
for idx, score in enumerate(scores):
metadata = self.metadata[idx] if idx < len(self.metadata) else {}
if metadata_filters:
if any(metadata.get(key) != value for key, value in metadata_filters.items()):
continue
normalized_bm25 = score / max_score if max_score > 0 else 0
doc_tags = set(metadata.get("regime_tags", []))
tag_score = len(query_tags & doc_tags) / len(query_tags | doc_tags) if (query_tags or doc_tags) else 0
hybrid_score = 0.75 * normalized_bm25 + 0.25 * tag_score
ranked_results.append((hybrid_score, normalized_bm25, tag_score, idx, metadata))
ranked_results.sort(key=lambda item: item[0], reverse=True)
# Build results
results = []
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
for idx in top_indices:
# Normalize score to 0-1 range for consistency
normalized_score = scores[idx] / max_score if max_score > 0 else 0
results.append({
"matched_situation": self.documents[idx],
"recommendation": self.recommendations[idx],
"similarity_score": normalized_score,
})
for hybrid_score, normalized_bm25, tag_score, idx, metadata in ranked_results[:limit]:
results.append(
{
"matched_situation": self.documents[idx],
"recommendation": self.recommendations[idx],
"similarity_score": hybrid_score,
"bm25_score": normalized_bm25,
"tag_overlap_score": tag_score,
"metadata": metadata,
}
)
return results
def clear(self):
"""Clear all stored memories."""
self.documents = []
self.recommendations = []
self.metadata = []
self.bm25 = None
if __name__ == "__main__":
# Example usage
matcher = FinancialSituationMemory("test_memory")
# Example data
example_data = [
(
"High inflation rate with rising interest rates and declining consumer spending",
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
),
(
"Tech sector showing high volatility with increasing institutional selling pressure",
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
),
(
"Strong dollar affecting emerging markets with increasing forex volatility",
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
),
(
"Market showing signs of sector rotation with rising yields",
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
),
]
# Add the example situations and recommendations
matcher.add_situations(example_data)
# Example query
current_situation = """
Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations
"""
try:
recommendations = matcher.get_memories(current_situation, n_matches=2)
for i, rec in enumerate(recommendations, 1):
print(f"\nMatch {i}:")
print(f"Similarity Score: {rec['similarity_score']:.2f}")
print(f"Matched Situation: {rec['matched_situation']}")
print(f"Recommendation: {rec['recommendation']}")
except Exception as e:
print(f"Error during recommendation: {str(e)}")

View File

@ -1,53 +1,104 @@
from langchain_core.tools import tool
from typing import Annotated
from langchain_core.tools import tool
from tradingagents.dataflows.interface import route_to_vendor
@tool
def get_news(
ticker: Annotated[str, "Ticker symbol"],
def get_company_news(
symbol: Annotated[str, "Exchange-qualified ticker symbol, such as AAPL or 005930.KS"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve news data for a given ticker symbol.
Uses the configured news_data vendor.
Args:
ticker (str): Ticker symbol
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A formatted string containing news data
Retrieve company-specific news for a ticker symbol.
Uses the configured news_data vendor chain.
"""
return route_to_vendor("get_company_news", symbol, start_date, end_date)
@tool
def get_news(
    ticker: Annotated[str, "Exchange-qualified ticker symbol, such as AAPL or 005930.KS"],
    start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
    end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
    """
    Backward-compatible thin wrapper for company news.
    This tool is equivalent to get_company_news(ticker, start_date, end_date).
    """
    # NOTE(review): the docstring doubles as the LLM-facing tool description —
    # keep it stable. Dispatch goes through the vendor router by method name.
    return route_to_vendor("get_news", ticker, start_date, end_date)
@tool
def get_macro_news(
    curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
    look_back_days: Annotated[int, "Number of calendar days to look back"] = 7,
    limit: Annotated[int, "Maximum number of articles or macro items to return"] = 10,
    region: Annotated[str | None, "Optional region hint such as US, KR, or GLOBAL"] = None,
    language: Annotated[str | None, "Optional language hint such as en or ko"] = None,
) -> str:
    """
    Retrieve macro and broader market context news.
    Uses the configured macro_data vendor chain.
    """
    # NOTE(review): the docstring doubles as the LLM-facing tool description.
    # region/language are forwarded as keyword hints; vendors may ignore them.
    return route_to_vendor(
        "get_macro_news",
        curr_date,
        look_back_days,
        limit,
        region=region,
        language=language,
    )
@tool
def get_global_news(
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "Number of days to look back"] = 7,
limit: Annotated[int, "Maximum number of articles to return"] = 5,
look_back_days: Annotated[int, "Number of calendar days to look back"] = 7,
limit: Annotated[int, "Maximum number of articles or macro items to return"] = 10,
) -> str:
"""
Retrieve global news data.
Uses the configured news_data vendor.
Args:
curr_date (str): Current date in yyyy-mm-dd format
look_back_days (int): Number of days to look back (default 7)
limit (int): Maximum number of articles to return (default 5)
Returns:
str: A formatted string containing global news data
Backward-compatible thin wrapper for macro news.
This tool is equivalent to get_macro_news(curr_date, look_back_days, limit).
"""
return route_to_vendor("get_global_news", curr_date, look_back_days, limit)
@tool
def get_insider_transactions(
ticker: Annotated[str, "ticker symbol"],
def get_disclosures(
symbol: Annotated[str, "Exchange-qualified ticker symbol, such as 005930.KS"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve insider transaction information about a company.
Uses the configured news_data vendor.
Args:
ticker (str): Ticker symbol of the company
Returns:
str: A report of insider transaction data
Retrieve company disclosures and filing events for a ticker symbol.
Uses the configured disclosure_data vendor chain.
"""
return route_to_vendor("get_disclosures", symbol, start_date, end_date)
@tool
def get_social_sentiment(
    symbol: Annotated[str, "Exchange-qualified ticker symbol, such as AAPL or 005930.KS"],
    start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
    end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
    """
    Retrieve social or public-narrative sentiment for a ticker symbol.
    When a dedicated social vendor is unavailable, vendors may return a clearly labeled
    news-derived sentiment summary instead of claiming direct social-media coverage.
    """
    # NOTE(review): the docstring doubles as the LLM-facing tool description.
    # The fallback labeling described above is the vendors' responsibility;
    # this wrapper only routes by method name.
    return route_to_vendor("get_social_sentiment", symbol, start_date, end_date)
@tool
def get_insider_transactions(
    ticker: Annotated[str, "Exchange-qualified ticker symbol"],
) -> str:
    """
    Retrieve insider transaction information for a company.
    Uses the configured fundamental_data vendor chain unless overridden at the tool level.
    """
    # No date-range parameters are forwarded; the vendor implementation
    # determines what window of transactions is returned.
    return route_to_vendor("get_insider_transactions", ticker)

View File

@ -2,4 +2,10 @@
from .alpha_vantage_stock import get_stock
from .alpha_vantage_indicator import get_indicator
from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
from .alpha_vantage_news import get_news, get_global_news, get_insider_transactions
from .alpha_vantage_news import (
get_company_news_alpha_vantage,
get_global_news,
get_insider_transactions,
get_macro_news_alpha_vantage,
get_news,
)

View File

@ -5,13 +5,17 @@ import json
from datetime import datetime
from io import StringIO
from .api_keys import get_api_key as get_documented_api_key
from .config import get_config
from .vendor_exceptions import VendorConfigurationError, VendorTransientError
API_BASE_URL = "https://www.alphavantage.co/query"
def get_api_key() -> str:
"""Retrieve the API key for Alpha Vantage from environment variables."""
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
api_key = os.getenv("ALPHA_VANTAGE_API_KEY") or get_documented_api_key("ALPHA_VANTAGE_API_KEY")
if not api_key:
raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
raise VendorConfigurationError("ALPHA_VANTAGE_API_KEY is not configured.")
return api_key
def format_datetime_for_api(date_input) -> str:
@ -63,8 +67,15 @@ def _make_api_request(function_name: str, params: dict) -> dict | str:
# Remove entitlement if it's None or empty
api_params.pop("entitlement", None)
response = requests.get(API_BASE_URL, params=api_params)
response.raise_for_status()
try:
response = requests.get(
API_BASE_URL,
params=api_params,
timeout=float(get_config().get("vendor_timeout", 15)),
)
response.raise_for_status()
except requests.RequestException as exc:
raise VendorTransientError(f"Alpha Vantage request failed: {exc}") from exc
response_text = response.text

View File

@ -1,55 +1,144 @@
from __future__ import annotations
import json
from datetime import datetime, timedelta
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
from .news_models import NewsItem, dedupe_news_items, format_news_items_report, normalize_datetime
from .vendor_exceptions import VendorMalformedResponseError
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
Covers stocks, cryptocurrencies, forex, and topics like fiscal policy, mergers & acquisitions, IPOs.
def _parse_news_sentiment_response(response_text: str) -> list[dict]:
try:
payload = json.loads(response_text)
except json.JSONDecodeError as exc:
raise VendorMalformedResponseError("Alpha Vantage returned malformed NEWS_SENTIMENT payload.") from exc
Args:
ticker: Stock symbol for news articles.
start_date: Start date for news search.
end_date: End date for news search.
feed = payload.get("feed")
if feed is None:
raise VendorMalformedResponseError("Alpha Vantage NEWS_SENTIMENT payload did not include a feed.")
if not isinstance(feed, list):
raise VendorMalformedResponseError("Alpha Vantage NEWS_SENTIMENT feed must be a list.")
return feed
Returns:
Dictionary containing news sentiment data or JSON string.
"""
def normalize_alpha_vantage_article(article: dict, *, fallback_symbol: str | None = None) -> NewsItem:
    """Map one Alpha Vantage NEWS_SENTIMENT article into a normalized NewsItem.

    If *fallback_symbol* is provided and absent from the article's ticker
    list, it is appended so company-news queries always carry the requested
    symbol.
    """
    symbols: list[str] = []
    for entry in article.get("ticker_sentiment") or []:
        # Only dict entries with a non-blank ticker are kept; the stored
        # value is the uppercased raw ticker string.
        if isinstance(entry, dict) and str(entry.get("ticker", "")).strip():
            symbols.append(str(entry.get("ticker", "")).upper())
    if fallback_symbol and fallback_symbol.upper() not in symbols:
        symbols.append(fallback_symbol.upper())

    topic_tags: list[str] = []
    for entry in article.get("topics", []):
        if isinstance(entry, dict):
            topic = str(entry.get("topic", "")).strip()
            if topic:
                topic_tags.append(topic)

    # Sentiment score is optional and may arrive as a string; unparseable or
    # missing values become None rather than raising.
    raw_sentiment = article.get("overall_sentiment_score")
    if raw_sentiment in (None, ""):
        sentiment_value = None
    else:
        try:
            sentiment_value = float(raw_sentiment)
        except (TypeError, ValueError):
            sentiment_value = None

    return NewsItem(
        title=str(article.get("title", "No title")),
        source=str(article.get("source", "Alpha Vantage")),
        published_at=normalize_datetime(article.get("time_published")),
        language=article.get("language"),
        country=article.get("source_domain"),
        symbols=symbols,
        topic_tags=topic_tags,
        sentiment=sentiment_value,
        relevance=None,
        reliability=None,
        url=str(article.get("url", "")),
        summary=str(article.get("summary", "")),
        raw_vendor="alpha_vantage",
    )
def fetch_company_news_alpha_vantage(ticker: str, start_date: str, end_date: str) -> list[NewsItem]:
    """Fetch, normalize, and dedupe company news for *ticker* over the date range."""
    request_params = {
        "tickers": ticker,
        "time_from": format_datetime_for_api(start_date),
        "time_to": format_datetime_for_api(end_date),
        "limit": "50",
    }
    raw_feed = _parse_news_sentiment_response(_make_api_request("NEWS_SENTIMENT", request_params))
    normalized = [
        normalize_alpha_vantage_article(article, fallback_symbol=ticker)
        for article in raw_feed
    ]
    return dedupe_news_items(normalized)
return _make_api_request("NEWS_SENTIMENT", params)
def get_global_news(curr_date, look_back_days: int = 7, limit: int = 50) -> dict[str, str] | str:
"""Returns global market news & sentiment data without ticker-specific filtering.
def get_company_news_alpha_vantage(ticker: str, start_date: str, end_date: str) -> str:
    """Return a formatted company-news report (up to 25 items) for *ticker*."""
    news_items = fetch_company_news_alpha_vantage(ticker, start_date, end_date)
    if not news_items:
        return f"No news found for {ticker} between {start_date} and {end_date}"
    report_title = f"{ticker} Company News, from {start_date} to {end_date}"
    return format_news_items_report(report_title, news_items, max_items=25)
Covers broad market topics like financial markets, economy, and more.
Args:
curr_date: Current date in yyyy-mm-dd format.
look_back_days: Number of days to look back (default 7).
limit: Maximum number of articles (default 50).
Returns:
Dictionary containing global news sentiment data or JSON string.
"""
from datetime import datetime, timedelta
# Calculate start date
def fetch_macro_news_alpha_vantage(
    curr_date: str,
    look_back_days: int = 7,
    limit: int = 50,
    region: str | None = None,
    language: str | None = None,
) -> list[NewsItem]:
    """Fetch macro/market NEWS_SENTIMENT articles for the window ending at *curr_date*.

    Args:
        curr_date: Window end, YYYY-MM-DD.
        look_back_days: Days before curr_date to include.
        limit: Maximum articles requested and returned.
        region: Optional market hint; "KR" narrows the topic list.
        language: Optional hint; when set, results are sorted LATEST-first.

    Returns:
        Deduplicated NewsItem list, truncated to *limit*.
    """
    curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
    start_dt = curr_dt - timedelta(days=look_back_days)
    start_date = start_dt.strftime("%Y-%m-%d")
    # KR runs drop the monetary-policy topic — presumably too US-centric for
    # Korean macro context (TODO confirm intent with topic owner).
    topics = "financial_markets,economy_macro,economy_monetary"
    if region and region.upper() == "KR":
        topics = "financial_markets,economy_macro"
    # Removed duplicate "topics"/"time_from" keys (diff residue): in a dict
    # literal the later key silently wins, which hid the stale values.
    params = {
        "topics": topics,
        "time_from": format_datetime_for_api(start_date),
        "time_to": format_datetime_for_api(curr_date),
        "limit": str(limit),
    }
    if language:
        params["sort"] = "LATEST"
    # Removed a premature `return _make_api_request(...)` (old-version residue)
    # that made normalization and dedupe unreachable.
    response_text = _make_api_request("NEWS_SENTIMENT", params)
    items = [
        normalize_alpha_vantage_article(article)
        for article in _parse_news_sentiment_response(response_text)
    ]
    return dedupe_news_items(items)[:limit]
def get_macro_news_alpha_vantage(
    curr_date: str,
    look_back_days: int = 7,
    limit: int = 50,
    region: str | None = None,
    language: str | None = None,
) -> str:
    """Render a markdown macro-news report for the trailing *look_back_days* window."""
    window_start = datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=look_back_days)
    start_date = window_start.strftime("%Y-%m-%d")
    items = fetch_macro_news_alpha_vantage(
        curr_date,
        look_back_days=look_back_days,
        limit=limit,
        region=region,
        language=language,
    )
    if not items:
        return f"No global news found for {curr_date}"
    label = (region or "GLOBAL").upper()
    heading = f"{label} Macro News, from {start_date} to {curr_date}"
    return format_news_items_report(heading, items, max_items=limit)
def get_insider_transactions(symbol: str) -> dict[str, str] | str:
@ -68,4 +157,9 @@ def get_insider_transactions(symbol: str) -> dict[str, str] | str:
"symbol": symbol,
}
return _make_api_request("INSIDER_TRANSACTIONS", params)
return _make_api_request("INSIDER_TRANSACTIONS", params)
# Backward-compatible aliases: older call sites import `get_news` and
# `get_global_news` from this module; keep them pointing at the renamed
# implementations so existing imports continue to work.
get_news = get_company_news_alpha_vantage
get_global_news = get_macro_news_alpha_vantage

View File

@ -0,0 +1,72 @@
from __future__ import annotations
import os
from functools import lru_cache
from pathlib import Path
# Maps environment-variable names to a label string — presumably the
# human-readable heading each key appears under in Docs/list_api_keys.md
# (TODO confirm: this mapping is not referenced in the visible code).
_DOC_ENV_MAP = {
    "ALPHA_VANTAGE_API_KEY": "Alpha Vantage",
    "NAVER_CLIENT_ID": "Naver.Client ID",
    "NAVER_CLIENT_SECRET": "Naver.Client Secret",
    "OPENDART_API_KEY": "OpenDart",
}
def _get_api_keys_doc_path() -> Path:
return Path(__file__).resolve().parents[2] / "Docs" / "list_api_keys.md"
@lru_cache(maxsize=1)
def _load_documented_keys() -> dict[str, str]:
    """Parse Docs/list_api_keys.md into {ENV_VAR_NAME: documented key value}.

    Recognizes flat `Alpha Vantage:` / `OpenDart:` entries, plus a `Naver:`
    section whose `- Client ID:` / `- Client Secret:` bullets hold the values.
    Returns an empty mapping when the doc file is absent; the result is cached
    for the lifetime of the process.
    """
    doc_path = _get_api_keys_doc_path()
    if not doc_path.exists():
        return {}
    flat_prefixes = {
        "Alpha Vantage:": "ALPHA_VANTAGE_API_KEY",
        "OpenDart:": "OPENDART_API_KEY",
    }
    keys: dict[str, str] = {}
    section = None
    for raw in doc_path.read_text(encoding="utf-8").splitlines():
        stripped = raw.strip()
        if not stripped:
            continue
        matched_flat = False
        for prefix, env_name in flat_prefixes.items():
            if stripped.startswith(prefix):
                keys[env_name] = stripped.split(":", 1)[1].strip()
                section = None  # a flat entry terminates any open section
                matched_flat = True
                break
        if matched_flat:
            continue
        if stripped.endswith(":") and not stripped.startswith("-"):
            section = stripped[:-1].strip()
            continue
        if stripped.startswith("-") and ":" in stripped and section == "Naver":
            bullet_key, bullet_value = stripped[1:].split(":", 1)
            label = bullet_key.strip().lower()
            if label == "client id":
                keys["NAVER_CLIENT_ID"] = bullet_value.strip()
            elif label == "client secret":
                keys["NAVER_CLIENT_SECRET"] = bullet_value.strip()
    return keys
def get_api_key(env_name: str) -> str | None:
    """Resolve an API key: environment variable first, then the checked-in docs file.

    Returns the stripped value, or None when neither source has a non-empty entry.
    """
    env_value = os.getenv(env_name)
    if env_value:
        return env_value.strip()
    doc_value = _load_documented_keys().get(env_name)
    return doc_value.strip() if doc_value else None

View File

@ -1,5 +1,7 @@
from copy import deepcopy
from typing import Any, Dict, Optional
import tradingagents.default_config as default_config
from typing import Dict, Optional
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
def initialize_config():
    """Initialize the module-level configuration from DEFAULT_CONFIG (idempotent).

    Uses deepcopy so later mutation of the live config can never leak back
    into DEFAULT_CONFIG — a shallow .copy() would share nested dicts.
    """
    global _config
    if _config is None:
        # Removed a duplicated shallow-copy assignment (diff residue of the
        # old `DEFAULT_CONFIG.copy()` line).
        _config = deepcopy(default_config.DEFAULT_CONFIG)
def _deep_merge_dicts(base: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
merged = deepcopy(base)
for key, value in updates.items():
if isinstance(value, dict) and isinstance(merged.get(key), dict):
merged[key] = _deep_merge_dicts(merged[key], value)
else:
merged[key] = deepcopy(value)
return merged
def set_config(config: Dict):
    """Overlay *config* onto the current configuration using a deep merge.

    Initializes from DEFAULT_CONFIG (deep copy) on first use. Nested dicts are
    merged key-by-key via _deep_merge_dicts rather than replaced wholesale.
    """
    global _config
    if _config is None:
        _config = deepcopy(default_config.DEFAULT_CONFIG)
    # Removed the interleaved old-version lines (`DEFAULT_CONFIG.copy()` +
    # `_config.update(config)`) left behind as diff residue.
    _config = _deep_merge_dicts(_config, config)
def get_config() -> Dict:
    """Return a deep copy of the current configuration.

    A deep copy means callers can freely mutate the returned dict (including
    nested dicts) without affecting shared state. Initializes defaults lazily.
    """
    if _config is None:
        initialize_config()
    # Removed the duplicated shallow `return _config.copy()` (diff residue).
    return deepcopy(_config)
# Initialize with default config

View File

@ -0,0 +1,40 @@
from __future__ import annotations
import requests
from .api_keys import get_api_key
from .config import get_config
from .vendor_exceptions import VendorConfigurationError, VendorTransientError
_ECOS_API_BASE = "https://ecos.bok.or.kr/api"
def get_macro_news_ecos(
    curr_date: str,
    look_back_days: int = 7,
    limit: int = 10,
    region: str | None = None,
    language: str | None = None,
) -> str:
    """Placeholder ECOS (Bank of Korea) macro adapter.

    Validates that an API key and `ecos_series` config exist, probes the ECOS
    StatisticSearch endpoint, then returns a fixed explanatory message — it
    does not yet build a summary from the response body. `curr_date`,
    `look_back_days`, `region`, and `language` are accepted for routing-layer
    signature compatibility but are currently unused here.

    Raises:
        VendorConfigurationError: missing API key or `ecos_series` config.
        VendorTransientError: network/HTTP failure while probing the endpoint.
    """
    api_key = get_api_key("ECOS_API_KEY")
    if not api_key:
        raise VendorConfigurationError("ECOS API key is not configured.")
    series = get_config().get("ecos_series", [])
    if not series:
        raise VendorConfigurationError("ECOS series configuration is missing.")
    try:
        # Connectivity/credential probe only — the JSON body is intentionally
        # discarded below until per-series summarization is implemented.
        response = requests.get(
            f"{_ECOS_API_BASE}/StatisticSearch/{api_key}/json/kr/1/{limit}",
            timeout=float(get_config().get("vendor_timeout", 15)),
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        raise VendorTransientError(f"ECOS request failed: {exc}") from exc
    return (
        "ECOS macro adapter is configured but requires project-specific series codes. "
        "Provide `ecos_series` in config to enable Korean macro summaries."
    )

View File

@ -1,45 +1,56 @@
from typing import Annotated
from __future__ import annotations
from datetime import datetime
from typing import Any
from tradingagents.agents.utils.instrument_resolver import resolve_instrument
# Import from vendor-specific modules
from .y_finance import (
get_YFin_data_online,
get_stock_stats_indicators_window,
get_fundamentals as get_yfinance_fundamentals,
get_balance_sheet as get_yfinance_balance_sheet,
get_cashflow as get_yfinance_cashflow,
get_income_statement as get_yfinance_income_statement,
get_insider_transactions as get_yfinance_insider_transactions,
)
from .yfinance_news import get_news_yfinance, get_global_news_yfinance
from .alpha_vantage import (
get_stock as get_alpha_vantage_stock,
get_indicator as get_alpha_vantage_indicator,
get_fundamentals as get_alpha_vantage_fundamentals,
get_balance_sheet as get_alpha_vantage_balance_sheet,
get_cashflow as get_alpha_vantage_cashflow,
get_company_news_alpha_vantage,
get_fundamentals as get_alpha_vantage_fundamentals,
get_income_statement as get_alpha_vantage_income_statement,
get_indicator as get_alpha_vantage_indicator,
get_insider_transactions as get_alpha_vantage_insider_transactions,
get_news as get_alpha_vantage_news,
get_global_news as get_alpha_vantage_global_news,
get_macro_news_alpha_vantage,
get_stock as get_alpha_vantage_stock,
)
from .alpha_vantage_common import AlphaVantageRateLimitError
# Configuration and routing logic
from .config import get_config
from .ecos import get_macro_news_ecos
from .naver_news import get_company_news_naver, get_social_sentiment_naver
from .opendart import get_disclosures_opendart
from .vendor_exceptions import (
VendorConfigurationError,
VendorInputError,
VendorMalformedResponseError,
VendorTransientError,
)
from .y_finance import (
get_YFin_data_online,
get_balance_sheet as get_yfinance_balance_sheet,
get_cashflow as get_yfinance_cashflow,
get_fundamentals as get_yfinance_fundamentals,
get_income_statement as get_yfinance_income_statement,
get_insider_transactions as get_yfinance_insider_transactions,
get_stock_stats_indicators_window,
)
from .yfinance_news import (
get_company_news_yfinance,
get_macro_news_yfinance,
get_social_sentiment_yfinance,
)
# Tools organized by category; vendor routing resolves a tool name to its
# category via get_category_for_method(), so every routable tool must appear
# in exactly one "tools" list below.
# NOTE: reconstructed from interleaved old/new diff lines — the rendered
# source mixed both versions of this literal, which is invalid Python.
TOOLS_CATEGORIES = {
    "core_stock_apis": {
        "description": "OHLCV stock price data",
        "tools": ["get_stock_data"],
    },
    "technical_indicators": {
        "description": "Technical analysis indicators",
        "tools": ["get_indicators"],
    },
    "fundamental_data": {
        "description": "Company fundamentals",
        "tools": [
            "get_fundamentals",
            "get_balance_sheet",
            "get_cashflow",
            "get_income_statement",
            "get_insider_transactions",
        ],
    },
    "news_data": {
        "description": "Company news feeds",
        "tools": ["get_news", "get_company_news"],
    },
    "macro_data": {
        "description": "Macro and market context feeds",
        "tools": ["get_global_news", "get_macro_news"],
    },
    "disclosure_data": {
        "description": "Corporate disclosures and filings",
        "tools": ["get_disclosures"],
    },
    "social_data": {
        "description": "Social and public narrative sentiment",
        "tools": ["get_social_sentiment"],
    },
}
# Known vendor identifiers, in default preference order. Removed a duplicated
# "yfinance" entry left behind by the diff rendering (old position + new
# position both present).
VENDOR_LIST = [
    "alpha_vantage",
    "yfinance",
    "naver",
    "opendart",
    "ecos",
]
# Mapping of methods to their vendor-specific implementations
VENDOR_METHODS = {
# core_stock_apis
"get_stock_data": {
"alpha_vantage": get_alpha_vantage_stock,
"yfinance": get_YFin_data_online,
},
# technical_indicators
"get_indicators": {
"alpha_vantage": get_alpha_vantage_indicator,
"yfinance": get_stock_stats_indicators_window,
},
# fundamental_data
"get_fundamentals": {
"alpha_vantage": get_alpha_vantage_fundamentals,
"yfinance": get_yfinance_fundamentals,
@ -94,69 +113,240 @@ VENDOR_METHODS = {
"alpha_vantage": get_alpha_vantage_income_statement,
"yfinance": get_yfinance_income_statement,
},
# news_data
"get_news": {
"alpha_vantage": get_alpha_vantage_news,
"yfinance": get_news_yfinance,
"alpha_vantage": get_company_news_alpha_vantage,
"yfinance": get_company_news_yfinance,
"naver": get_company_news_naver,
},
"get_company_news": {
"alpha_vantage": get_company_news_alpha_vantage,
"yfinance": get_company_news_yfinance,
"naver": get_company_news_naver,
},
"get_global_news": {
"yfinance": get_global_news_yfinance,
"alpha_vantage": get_alpha_vantage_global_news,
"alpha_vantage": get_macro_news_alpha_vantage,
"yfinance": get_macro_news_yfinance,
"ecos": get_macro_news_ecos,
},
"get_macro_news": {
"alpha_vantage": get_macro_news_alpha_vantage,
"yfinance": get_macro_news_yfinance,
"ecos": get_macro_news_ecos,
},
"get_disclosures": {
"opendart": get_disclosures_opendart,
},
"get_insider_transactions": {
"alpha_vantage": get_alpha_vantage_insider_transactions,
"yfinance": get_yfinance_insider_transactions,
},
"get_social_sentiment": {
"naver": get_social_sentiment_naver,
"yfinance": get_social_sentiment_yfinance,
},
}
# Lowercase substrings that mark a vendor's string response as semantically
# empty ("polite" no-data messages). should_fallback() scans for these to
# decide whether to try the next vendor in the chain.
_SEMANTIC_EMPTY_MARKERS = (
    "no news found",
    "no global news found",
    "no disclosures found",
    "no insider transactions data found",
    "no data found",
    "no fundamentals data found",
    "provider unavailable",
    "no social provider",
    "no social sentiment",
)
def get_category_for_method(method: str) -> str:
    """Return the TOOLS_CATEGORIES key whose tool list contains *method*.

    Raises ValueError when the method is not registered in any category.
    """
    match = next(
        (name for name, info in TOOLS_CATEGORIES.items() if method in info["tools"]),
        None,
    )
    if match is None:
        raise ValueError(f"Method '{method}' not found in any category")
    return match
def get_vendor(category: str, method: str | None = None) -> str:
    """Return the configured vendor spec for a data category or a specific tool.

    Tool-level configuration (`tool_vendors`) takes precedence over
    category-level (`data_vendors`); the value may be a single vendor name or a
    comma-separated preference chain. Falls back to "yfinance" when nothing is
    configured. (Removed the interleaved old-version def/docstring lines that
    were diff residue — two function headers were present.)
    """
    config = get_config()
    if method:
        tool_vendors = config.get("tool_vendors", {})
        if method in tool_vendors:
            return tool_vendors[method]
    return config.get("data_vendors", {}).get(category, "yfinance")
def _normalize_vendor_chain(method: str, vendor_config: str) -> list[str]:
    """Expand a comma-separated vendor spec into a full fallback chain.

    Configured vendors come first in their given order; every other vendor that
    implements *method* is appended afterwards. Raises ValueError for an empty
    spec or for configured vendors with no implementation of *method*.
    """
    raw_parts = (vendor_config or "").split(",")
    configured = [part.strip() for part in raw_parts if part.strip()]
    if not configured:
        raise ValueError(f"No vendors configured for '{method}'.")
    implementations = VENDOR_METHODS.get(method, {})
    unknown = [name for name in configured if name not in implementations]
    if unknown:
        raise ValueError(
            f"Unsupported vendors for '{method}': {', '.join(sorted(unknown))}."
        )
    chain = list(configured)
    for name in implementations:
        if name not in chain:
            chain.append(name)
    return chain
def _prioritize_market_specific_vendors(method: str, vendor_chain: list[str], args: tuple[Any, ...], kwargs: dict[str, Any]) -> list[str]:
    """Reorder a vendor chain so market-specific vendors lead for KR instruments.

    For Korean symbols, Naver is promoted for news/social methods and OpenDART
    for disclosures; for macro methods with region="KR", ECOS is promoted.
    Best-effort: any failure during inspection returns the chain reordered so
    far rather than raising, so routing never breaks on a resolver error.
    """
    reordered = vendor_chain[:]

    def promote(vendor_name: str) -> None:
        # Move an already-present vendor to the front; never adds new vendors.
        if vendor_name in reordered:
            reordered.remove(vendor_name)
            reordered.insert(0, vendor_name)

    try:
        if method in {"get_news", "get_company_news", "get_social_sentiment", "get_disclosures"}:
            # Symbol may arrive as kwarg ("symbol"/"ticker") or first positional.
            symbol = kwargs.get("symbol") or kwargs.get("ticker") or (args[0] if args else None)
            if isinstance(symbol, str):
                profile = resolve_instrument(symbol)
                if profile.country == "KR":
                    if method in {"get_news", "get_company_news", "get_social_sentiment"}:
                        promote("naver")
                    if method == "get_disclosures":
                        promote("opendart")
        if method in {"get_global_news", "get_macro_news"}:
            # Macro signature: (curr_date, look_back_days, limit, region, ...),
            # so region is positional index 3 when not passed by keyword.
            region = kwargs.get("region") or (args[3] if len(args) > 3 else None)
            if isinstance(region, str) and region.upper() == "KR":
                promote("ecos")
    except Exception:
        # Deliberate broad catch: prioritization is an optimization, not a
        # requirement — fall back to the (partially) reordered chain.
        return reordered
    return reordered
def _validate_date(value: str, field_name: str) -> None:
try:
datetime.strptime(value, "%Y-%m-%d")
except ValueError as exc:
raise VendorInputError(f"Field '{field_name}' must be in YYYY-MM-DD format.") from exc
def _validate_input_for_method(method: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
    """Validate routed-call arguments before any vendor is tried.

    Raises VendorInputError (which never triggers vendor fallback) for a
    missing/blank symbol, malformed dates, negative look_back_days, or a
    non-positive limit. Validation is positional-signature based: symbol-style
    methods take (symbol, start_date, end_date, ...); macro methods take
    (curr_date, look_back_days, limit, ...).
    """
    def arg(index: int, name: str) -> Any:
        # Keyword wins over positional; missing arguments resolve to None.
        if name in kwargs:
            return kwargs[name]
        return args[index] if len(args) > index else None

    if method in {
        "get_stock_data",
        "get_indicators",
        "get_fundamentals",
        "get_balance_sheet",
        "get_cashflow",
        "get_income_statement",
        "get_news",
        "get_company_news",
        "get_disclosures",
        "get_insider_transactions",
        "get_social_sentiment",
    }:
        # Accept either keyword spelling ("symbol" or "ticker") for position 0.
        symbol = arg(0, "symbol") or arg(0, "ticker")
        if not isinstance(symbol, str) or not symbol.strip():
            raise VendorInputError("A non-empty ticker or symbol is required.")
    if method in {"get_stock_data", "get_news", "get_company_news", "get_disclosures", "get_social_sentiment"}:
        _validate_date(str(arg(1, "start_date")), "start_date")
        _validate_date(str(arg(2, "end_date")), "end_date")
    if method in {"get_global_news", "get_macro_news"}:
        _validate_date(str(arg(0, "curr_date")), "curr_date")
        look_back_days = arg(1, "look_back_days")
        limit = arg(2, "limit")
        # int(...) may itself raise ValueError for non-numeric input — that
        # surfaces as-is rather than as VendorInputError (TODO confirm intended).
        if look_back_days is not None and int(look_back_days) < 0:
            raise VendorInputError("'look_back_days' must be non-negative.")
        if limit is not None and int(limit) <= 0:
            raise VendorInputError("'limit' must be positive.")
def should_fallback(result_or_exc: Any, method: str | None = None) -> bool:
    """Decide whether a vendor outcome should trigger fallback to the next vendor.

    Policy:
      * VendorInputError never falls back — bad caller input fails everywhere.
      * Every other exception (rate limits, config, transient, malformed
        responses, unexpected errors) falls back.
      * Empty/unusable results fall back only when `empty_result_fallback`
        (default True) is enabled: None, empty containers, empty/whitespace
        strings, strings starting with "error", or strings containing one of
        the _SEMANTIC_EMPTY_MARKERS substrings.

    Args:
        result_or_exc: A vendor's return value or the exception it raised.
        method: Kept for interface compatibility; currently unused.
    """
    if isinstance(result_or_exc, VendorInputError):
        return False
    # The previous explicit isinstance check against the specific vendor
    # exception classes was redundant — this generic check returned True for
    # them anyway — so it is folded into one branch.
    if isinstance(result_or_exc, Exception):
        return True
    # Config is only consulted for the result-shaped (non-exception) path.
    if not bool(get_config().get("empty_result_fallback", True)):
        return False
    if result_or_exc is None:
        return True
    if isinstance(result_or_exc, (list, tuple, dict, set)) and len(result_or_exc) == 0:
        return True
    if isinstance(result_or_exc, str):
        normalized = result_or_exc.strip().lower()
        if not normalized:
            return True
        if normalized.startswith("error"):
            return True
        if any(marker in normalized for marker in _SEMANTIC_EMPTY_MARKERS):
            return True
    return False
def route_to_vendor(method: str, *args, **kwargs):
    """Route a tool call to the first vendor in the chain that yields a usable result.

    Reconstructed from interleaved old/new diff lines: the rendered source
    mixed both implementations (duplicate loops and a stray raise).

    Flow: validate the method and its inputs, build the vendor fallback chain
    (configured vendors first, then all remaining implementations), promote
    market-specific vendors for KR instruments, then try vendors in order.
    Results/exceptions that should_fallback() accepts move on to the next
    vendor; VendorInputError and non-fallback exceptions propagate immediately.

    Returns:
        The first non-fallback vendor result; otherwise the last "empty"
        result seen, so callers still get the semantic no-data message.

    Raises:
        ValueError: unknown method.
        VendorInputError: invalid caller input (never retried).
        RuntimeError: every vendor failed, chained to the last exception.
    """
    if method not in VENDOR_METHODS:
        raise ValueError(f"Method '{method}' not supported")
    _validate_input_for_method(method, args, kwargs)
    category = get_category_for_method(method)
    vendor_chain = _normalize_vendor_chain(method, get_vendor(category, method))
    vendor_chain = _prioritize_market_specific_vendors(method, vendor_chain, args, kwargs)
    fallback_notes: list[str] = []
    last_result = None
    last_exception: Exception | None = None
    for vendor in vendor_chain:
        vendor_impl = VENDOR_METHODS[method].get(vendor)
        if vendor_impl is None:
            continue
        try:
            result = vendor_impl(*args, **kwargs)
            if should_fallback(result, method):
                # Keep the empty result so we can return it if nobody succeeds.
                last_result = result
                fallback_notes.append(f"{vendor}: empty or unusable result")
                continue
            return result
        except VendorInputError:
            raise
        except Exception as exc:
            if should_fallback(exc, method):
                last_exception = exc
                fallback_notes.append(f"{vendor}: {exc}")
                continue
            raise
    if last_result is not None:
        return last_result
    if last_exception is not None:
        note = " | ".join(fallback_notes)
        raise RuntimeError(f"No available vendor for '{method}'. Fallback attempts: {note}") from last_exception
    raise RuntimeError(f"No available vendor for '{method}'.")

View File

@ -0,0 +1,29 @@
from __future__ import annotations
import requests
from .api_keys import get_api_key
from .config import get_config
from .vendor_exceptions import VendorConfigurationError, VendorTransientError
_KRX_API_BASE = "https://openapi.krx.co.kr"
def call_krx_open_api(api_path: str, params: dict[str, str] | None = None) -> dict:
    """Issue an authenticated GET against the KRX Open API and return the JSON body.

    Raises:
        VendorConfigurationError: no KRX API key is configured.
        VendorTransientError: network error or non-2xx HTTP status.
    """
    api_key = get_api_key("KRX_API_KEY")
    if not api_key:
        raise VendorConfigurationError("KRX Open API key is not configured.")
    url = f"{_KRX_API_BASE.rstrip('/')}/{api_path.lstrip('/')}"
    timeout_seconds = float(get_config().get("vendor_timeout", 15))
    try:
        response = requests.get(
            url,
            params=params or {},
            headers={"AUTH_KEY": api_key},
            timeout=timeout_seconds,
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        raise VendorTransientError(f"KRX Open API request failed: {exc}") from exc
    return response.json()

View File

@ -0,0 +1,114 @@
from __future__ import annotations
import html
import re
from datetime import datetime, timedelta
from email.utils import parsedate_to_datetime
import requests
from tradingagents.agents.utils.instrument_resolver import resolve_instrument
from .api_keys import get_api_key
from .config import get_config
from .news_models import NewsItem, dedupe_news_items, filter_news_items_by_date, format_news_items_report
from .vendor_exceptions import VendorConfigurationError, VendorMalformedResponseError, VendorTransientError
_NAVER_NEWS_ENDPOINT = "https://openapi.naver.com/v1/search/news.json"
def _strip_html(text: str) -> str:
return re.sub(r"<[^>]+>", "", html.unescape(text or "")).strip()
def _get_headers() -> dict[str, str]:
    """Build Naver Open API auth headers; raise when either credential is missing."""
    client_id = get_api_key("NAVER_CLIENT_ID")
    client_secret = get_api_key("NAVER_CLIENT_SECRET")
    if not (client_id and client_secret):
        raise VendorConfigurationError("Naver News credentials are not configured.")
    headers = {
        "X-Naver-Client-Id": client_id,
        "X-Naver-Client-Secret": client_secret,
    }
    return headers
def normalize_naver_article(article: dict, *, fallback_symbol: str) -> NewsItem:
    """Convert one Naver search hit into a normalized NewsItem (ko/KR defaults)."""
    published = None
    raw_pub_date = article.get("pubDate")
    if raw_pub_date:
        try:
            published = parsedate_to_datetime(raw_pub_date)
        except (TypeError, ValueError, IndexError):
            published = None
    # Prefer the original outlet link over Naver's redirect link.
    link = article.get("originallink") or article.get("link") or ""
    return NewsItem(
        title=_strip_html(article.get("title", "No title")),
        source="Naver News",
        published_at=published,
        language="ko",
        country="KR",
        symbols=[fallback_symbol.upper()],
        topic_tags=[],
        sentiment=None,
        relevance=None,
        reliability=None,
        url=link,
        summary=_strip_html(article.get("description", "")),
        raw_vendor="naver",
    )
def fetch_company_news_naver(symbol: str, start_date: str, end_date: str, display: int = 20) -> list[NewsItem]:
    """Search Naver News for *symbol* and return normalized, date-filtered items.

    For KR instruments the query uses the resolved display name (Korean company
    name) instead of the raw ticker, since Naver search is name-based.

    Raises:
        VendorTransientError: network/HTTP failure.
        VendorMalformedResponseError: payload lacks an "items" list.
    """
    profile = resolve_instrument(symbol)
    search_query = profile.display_name if profile.country == "KR" else symbol
    try:
        response = requests.get(
            _NAVER_NEWS_ENDPOINT,
            headers=_get_headers(),
            params={"query": search_query, "display": display, "sort": "date"},
            timeout=float(get_config().get("vendor_timeout", 15)),
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        raise VendorTransientError(f"Naver News request failed: {exc}") from exc
    payload = response.json()
    items = payload.get("items")
    if not isinstance(items, list):
        raise VendorMalformedResponseError("Naver News payload did not include an items list.")
    # end_dt is exclusive-next-day so articles published on end_date itself pass.
    start_dt = datetime.strptime(start_date, "%Y-%m-%d")
    end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
    normalized = dedupe_news_items(
        [normalize_naver_article(article, fallback_symbol=profile.primary_symbol) for article in items]
    )
    return filter_news_items_by_date(normalized, start_date=start_dt, end_date=end_dt)
def get_company_news_naver(symbol: str, start_date: str, end_date: str) -> str:
    """Render Naver company news for *symbol* as a markdown report string."""
    articles = fetch_company_news_naver(symbol, start_date, end_date)
    if articles:
        heading = f"{symbol} Company News, from {start_date} to {end_date}"
        return format_news_items_report(heading, articles, max_items=15)
    return f"No news found for {symbol} between {start_date} and {end_date}"
def get_social_sentiment_naver(symbol: str, start_date: str, end_date: str) -> str:
    """Approximate social sentiment for *symbol* from recent Korean news headlines.

    There is no dedicated Korean social-sentiment provider, so the Naver news
    feed (capped at 10 items) is reused as a public-narrative proxy.
    """
    items = fetch_company_news_naver(symbol, start_date, end_date, display=10)
    if not items:
        return (
            f"Dedicated social provider unavailable; Naver company-news sentiment was unavailable for {symbol} "
            f"between {start_date} and {end_date}."
        )
    lines = [
        f"Dedicated social provider unavailable; using Korean news-derived public narrative for {symbol} from {start_date} to {end_date}.",
        "",
    ]
    for entry in items[:10]:
        day = entry.published_at.strftime("%Y-%m-%d") if entry.published_at else "undated"
        lines.append(f"- {day}: {entry.title}")
        if entry.summary:
            lines.append(f"  Narrative: {entry.summary}")
    return "\n".join(lines)

View File

@ -0,0 +1,157 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Iterable
@dataclass(frozen=True)
class NewsItem:
    """Vendor-agnostic, normalized representation of a single news article."""

    title: str
    source: str
    # Aware datetime when the vendor supplied a parseable timestamp; else None.
    published_at: datetime | None
    language: str | None = None
    country: str | None = None
    symbols: list[str] = field(default_factory=list)
    topic_tags: list[str] = field(default_factory=list)
    # Vendor-reported scores; scale is vendor-native — TODO confirm ranges.
    sentiment: float | None = None
    relevance: float | None = None
    reliability: float | None = None
    url: str = ""
    summary: str = ""
    # Identifier of the vendor adapter that produced this item (e.g. "naver").
    raw_vendor: str = ""
@dataclass(frozen=True)
class DisclosureItem:
    """Normalized corporate disclosure/filing record (e.g. from OpenDART)."""

    title: str
    source: str
    # Filing receipt date when known; else None.
    published_at: datetime | None
    url: str
    summary: str
    # The symbol the caller queried with (not necessarily the filer's code).
    symbol: str
    raw_vendor: str
def normalize_datetime(value: datetime | str | int | float | None) -> datetime | None:
    """Coerce a datetime, epoch number, or timestamp string into an aware datetime.

    Naive datetimes are stamped UTC; epoch numbers become UTC datetimes; strings
    are tried as ISO-8601 (with Z), compact Alpha Vantage formats, then a float
    epoch. Returns None for anything unparseable.
    """
    if value in (None, ""):
        return None
    if isinstance(value, datetime):
        return value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
    if isinstance(value, (int, float)):
        try:
            return datetime.fromtimestamp(value, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            return None
    if not isinstance(value, str):
        return None
    text = value.strip()
    if not text:
        return None
    try:
        return datetime.fromisoformat(text.replace("Z", "+00:00"))
    except ValueError:
        pass
    for compact_format in ("%Y%m%dT%H%M", "%Y%m%dT%H%M%S"):
        try:
            return datetime.strptime(text, compact_format).replace(tzinfo=timezone.utc)
        except ValueError:
            continue
    try:
        return datetime.fromtimestamp(float(text), tz=timezone.utc)
    except (OverflowError, OSError, ValueError):
        return None
def dedupe_news_items(items: Iterable[NewsItem]) -> list[NewsItem]:
    """Drop duplicate articles, keeping the first occurrence of each identity.

    Identity is build_news_identity(): URL when present, else source/title/time.
    """
    seen_ids: set[str] = set()
    unique: list[NewsItem] = []
    for candidate in items:
        key = build_news_identity(candidate)
        if key not in seen_ids:
            seen_ids.add(key)
            unique.append(candidate)
    return unique
def build_news_identity(item: NewsItem) -> str:
    """Return a stable dedupe key: the URL when present, else source::title::timestamp."""
    link = item.url
    if link:
        return link.strip()
    timestamp = item.published_at.isoformat() if item.published_at else ""
    return f"{item.source.strip()}::{item.title.strip()}::{timestamp}"
def filter_news_items_by_date(
    items: Iterable[NewsItem],
    *,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> list[NewsItem]:
    """Keep items whose publish time falls within [start_date, end_date].

    Timestamps are compared as naive UTC (bounds are expected naive); items
    without a publish time always pass the filter.
    """
    kept: list[NewsItem] = []
    for entry in items:
        when = entry.published_at
        if when is not None:
            as_naive_utc = when.astimezone(timezone.utc).replace(tzinfo=None)
            if start_date is not None and as_naive_utc < start_date:
                continue
            if end_date is not None and as_naive_utc > end_date:
                continue
        kept.append(entry)
    return kept
def format_news_items_report(
    heading: str,
    items: Iterable[NewsItem],
    *,
    max_items: int = 10,
) -> str:
    """Render up to *max_items* news items as a markdown report under *heading*.

    Returns a "No news found" sentinel string when there is nothing to show.
    """
    chosen = list(items)[:max_items]
    if not chosen:
        return f"No news found for {heading}"
    parts = [f"## {heading}", ""]
    for entry in chosen:
        stamp = f"[{entry.published_at.strftime('%Y-%m-%d')}] " if entry.published_at else ""
        parts.append(f"### {stamp}{entry.title} (source: {entry.source})")
        if entry.summary:
            parts.append(entry.summary)
        if entry.sentiment is not None:
            parts.append(f"Sentiment score: {entry.sentiment:.2f}")
        if entry.url:
            parts.append(f"Link: {entry.url}")
        parts.append("")
    return "\n".join(parts).strip()
def format_disclosure_items_report(
    heading: str,
    items: Iterable[DisclosureItem],
    *,
    max_items: int = 10,
) -> str:
    """Render up to *max_items* disclosures as a markdown report under *heading*.

    Returns a "No disclosures found" sentinel string when the list is empty.
    """
    chosen = list(items)[:max_items]
    if not chosen:
        return f"No disclosures found for {heading}"
    parts = [f"## {heading}", ""]
    for entry in chosen:
        stamp = f"[{entry.published_at.strftime('%Y-%m-%d')}] " if entry.published_at else ""
        parts.append(f"### {stamp}{entry.title} (source: {entry.source})")
        if entry.summary:
            parts.append(entry.summary)
        if entry.url:
            parts.append(f"Link: {entry.url}")
        parts.append("")
    return "\n".join(parts).strip()

View File

@ -0,0 +1,147 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from xml.etree import ElementTree as ET
from zipfile import ZipFile
import io
import requests
from tradingagents.agents.utils.instrument_resolver import is_krx_symbol, resolve_instrument
from .api_keys import get_api_key
from .config import get_config
from .news_models import DisclosureItem, format_disclosure_items_report
from .vendor_exceptions import VendorConfigurationError, VendorMalformedResponseError, VendorTransientError
_OPENDART_API_BASE = "https://opendart.fss.or.kr/api"
def _get_opendart_key() -> str:
    """Return the OpenDART API key, raising VendorConfigurationError when absent."""
    key = get_api_key("OPENDART_API_KEY")
    if key:
        return key
    raise VendorConfigurationError("OpenDART API key is not configured.")
def _corp_code_cache_path() -> Path:
    """Return the path of the cached corp-code JSON map, creating its directory."""
    default_dir = Path(__file__).resolve().parent / "data_cache"
    cache_dir = Path(get_config().get("data_cache_dir", default_dir))
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir / "opendart_corp_codes.json"
def _download_corp_code_map() -> dict[str, str]:
    """Download OpenDART's corpCode archive and build a stock-code → corp-code map.

    The endpoint returns a ZIP containing one XML file of <list> records; only
    records with both a stock code and a corp code (i.e. listed companies) are
    kept. The map is written to the JSON cache before being returned.

    Raises:
        VendorTransientError: network/HTTP failure while downloading.
    """
    try:
        response = requests.get(
            f"{_OPENDART_API_BASE}/corpCode.xml",
            params={"crtfc_key": _get_opendart_key()},
            timeout=float(get_config().get("vendor_timeout", 15)),
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        raise VendorTransientError(f"OpenDART corpCode download failed: {exc}") from exc
    # The archive contains a single XML member; take the first name.
    with ZipFile(io.BytesIO(response.content)) as zipped:
        xml_name = zipped.namelist()[0]
        xml_bytes = zipped.read(xml_name)
    root = ET.fromstring(xml_bytes)
    corp_codes: dict[str, str] = {}
    for item in root.findall("list"):
        stock_code = (item.findtext("stock_code") or "").strip()
        corp_code = (item.findtext("corp_code") or "").strip()
        if stock_code and corp_code:
            corp_codes[stock_code] = corp_code
    # Persist for _load_corp_code_map() so the archive is fetched at most once.
    _corp_code_cache_path().write_text(json.dumps(corp_codes, ensure_ascii=False), encoding="utf-8")
    return corp_codes
def _load_corp_code_map() -> dict[str, str]:
    """Return the stock-code → corp-code map, preferring the on-disk cache.

    Downloads a fresh map when the cache file is missing or corrupt.
    """
    cache_file = _corp_code_cache_path()
    if not cache_file.exists():
        return _download_corp_code_map()
    try:
        return json.loads(cache_file.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        # Corrupt cache: rebuild from the vendor.
        return _download_corp_code_map()
def _resolve_corp_code(symbol: str) -> tuple[str | None, str | None]:
    """Resolve *symbol* to OpenDART's (corp_code, stock_code) pair.

    Returns (None, None) for non-Korean instruments. corp_code may also be
    None when the stock code is not in the corp-code map. The stock code falls
    back to the symbol prefix before the first "." (e.g. "005930.KS" → "005930").
    """
    profile = resolve_instrument(symbol)
    if profile.country != "KR" and not is_krx_symbol(profile.primary_symbol):
        return None, None
    stock_code = profile.krx_code or profile.primary_symbol.split(".", 1)[0]
    corp_code = _load_corp_code_map().get(stock_code)
    return corp_code, stock_code
def fetch_disclosures_opendart(symbol: str, start_date: str, end_date: str, *, page_count: int = 10) -> list[DisclosureItem]:
    """Fetch corporate disclosures for *symbol* from OpenDART's list.json endpoint.

    Returns [] for non-KR symbols, unknown corp codes, or the OpenDART
    statuses treated as "no data" ("013"/"020" — presumably no-results and
    over-quota codes; TODO confirm against the OpenDART status table).

    Raises:
        VendorTransientError: network/HTTP failure.
        VendorMalformedResponseError: unexpected status code or payload shape.
    """
    corp_code, stock_code = _resolve_corp_code(symbol)
    if not corp_code or not stock_code:
        return []
    params = {
        "crtfc_key": _get_opendart_key(),
        "corp_code": corp_code,
        # OpenDART expects compact YYYYMMDD dates.
        "bgn_de": start_date.replace("-", ""),
        "end_de": end_date.replace("-", ""),
        "page_count": str(page_count),
    }
    try:
        response = requests.get(
            f"{_OPENDART_API_BASE}/list.json",
            params=params,
            timeout=float(get_config().get("vendor_timeout", 15)),
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        raise VendorTransientError(f"OpenDART disclosure request failed: {exc}") from exc
    payload = response.json()
    status = payload.get("status")
    if status in {"013", "020"}:
        return []
    if status != "000":  # "000" is OpenDART's success status
        message = payload.get("message", "Unknown OpenDART error")
        raise VendorMalformedResponseError(f"OpenDART returned status {status}: {message}")
    disclosure_items = payload.get("list")
    if not isinstance(disclosure_items, list):
        raise VendorMalformedResponseError("OpenDART list.json payload did not include a disclosure list.")
    result: list[DisclosureItem] = []
    for item in disclosure_items:
        receipt_no = str(item.get("rcept_no", ""))
        receipt_dt = item.get("rcept_dt")
        published_at = None
        if receipt_dt:
            # rcept_dt is compact YYYYMMDD; a malformed value raises ValueError.
            published_at = datetime.strptime(receipt_dt, "%Y%m%d")
        report_name = str(item.get("report_nm", "Disclosure"))
        corp_name = str(item.get("corp_name", stock_code))
        result.append(
            DisclosureItem(
                title=f"{corp_name}: {report_name}",
                source="OpenDART",
                published_at=published_at,
                # Deep-link into DART's filing viewer by receipt number.
                url=f"https://dart.fss.or.kr/dsaf001/main.do?rcpNo={receipt_no}" if receipt_no else "",
                summary=f"Filer: {item.get('flr_nm', '')} | Receipt no: {receipt_no}",
                symbol=symbol,
                raw_vendor="opendart",
            )
        )
    return result
def get_disclosures_opendart(symbol: str, start_date: str, end_date: str) -> str:
    """Render OpenDART disclosures for *symbol* as a markdown report string."""
    filings = fetch_disclosures_opendart(symbol, start_date, end_date)
    if filings:
        heading = f"{symbol} Disclosures, from {start_date} to {end_date}"
        return format_disclosure_items_report(heading, filings, max_items=10)
    return f"No disclosures found for {symbol} between {start_date} and {end_date}"

View File

@ -0,0 +1,14 @@
# Vendor exception taxonomy used by the routing layer's should_fallback():
# VendorInputError never triggers fallback to another vendor; the other three
# (and any unexpected exception) do.
class VendorInputError(ValueError):
    """Raised when the user input is invalid and should not trigger vendor fallback."""


class VendorConfigurationError(RuntimeError):
    """Raised when a vendor is unavailable because credentials or config are missing."""


class VendorTransientError(RuntimeError):
    """Raised when a vendor hit a transient/network/service issue."""


class VendorMalformedResponseError(RuntimeError):
    """Raised when a vendor response is structurally invalid."""

View File

@ -1,101 +1,99 @@
"""yfinance-based news, macro, and sentiment helpers."""
from __future__ import annotations
from datetime import datetime, timezone
from dateutil.relativedelta import relativedelta
import yfinance as yf
from .news_models import (
NewsItem,
dedupe_news_items,
filter_news_items_by_date,
format_news_items_report,
normalize_datetime,
)
from .stockstats_utils import yf_retry
# Progressively larger yfinance fetch sizes, tried in order until the
# requested date window is covered by the returned feed.
_TICKER_NEWS_FETCH_COUNTS = (20, 50, 100)
# Upper bound on date-filtered articles returned for a single ticker query.
_MAX_FILTERED_TICKER_ARTICLES = 25
# Region-keyed macro/global search queries fed to yfinance Search.
# The "KR" preset uses Korean-language market/economy terms; "GLOBAL" is the
# fallback preset when no region (or an unknown region) is requested.
_GLOBAL_QUERY_PRESETS = {
    "US": [
        "stock market economy",
        "Federal Reserve interest rates",
        "inflation economic outlook",
        "global markets trading",
    ],
    "KR": [
        "한국 증시",
        "한국은행 기준금리",
        "원달러 환율",
        "반도체 수출",
    ],
    "GLOBAL": [
        "stock market economy",
        "global markets trading",
        "economy monetary policy",
        "inflation growth outlook",
    ],
}
def _parse_pub_date(raw_value) -> datetime | None:
"""Normalize yfinance pub date values into a timezone-aware datetime."""
if raw_value in (None, ""):
return None
if isinstance(raw_value, datetime):
return raw_value
if isinstance(raw_value, (int, float)):
try:
return datetime.fromtimestamp(raw_value, tz=timezone.utc)
except (OverflowError, OSError, ValueError):
return None
if isinstance(raw_value, str):
normalized = raw_value.strip()
if not normalized:
return None
try:
return datetime.fromisoformat(normalized.replace("Z", "+00:00"))
except ValueError:
try:
return datetime.fromtimestamp(float(normalized), tz=timezone.utc)
except (OverflowError, OSError, ValueError):
return None
return None
def _extract_article_fields(article: dict) -> dict:
    """Extract normalized fields from a yfinance news article payload.

    Handles both the nested ``{"content": {...}}`` structure and the older
    flat structure, returning title/summary/publisher/link/pub_date/raw_symbols.
    (This span previously interleaved the old ``_extract_article_data`` body
    with the renamed implementation — a merge/diff artifact now resolved.)
    """
    if "content" in article:
        content = article["content"]
        provider = content.get("provider", {})
        # The URL lives under canonicalUrl or clickThroughUrl depending on feed.
        url_obj = content.get("canonicalUrl") or content.get("clickThroughUrl") or {}
        return {
            "title": content.get("title", "No title"),
            "summary": content.get("summary", ""),
            "publisher": provider.get("displayName", "Unknown"),
            "link": url_obj.get("url", ""),
            "pub_date": normalize_datetime(content.get("pubDate")),
            "raw_symbols": content.get("relatedTickers") or [],
        }
    # Fallback for the flat (legacy) article structure.
    return {
        "title": article.get("title", "No title"),
        "summary": article.get("summary", ""),
        "publisher": article.get("publisher", "Unknown"),
        "link": article.get("link", ""),
        "pub_date": normalize_datetime(article.get("providerPublishTime")),
        "raw_symbols": article.get("relatedTickers") or [],
    }


# Backward-compatible alias for the pre-rename helper name.
_extract_article_data = _extract_article_fields
def _article_identity(article: dict) -> str:
"""Return a stable identity key for deduplicating news articles."""
link = article.get("link", "").strip()
if link:
return link
title = article.get("title", "").strip()
publisher = article.get("publisher", "").strip()
pub_date = article.get("pub_date")
stamp = pub_date.isoformat() if isinstance(pub_date, datetime) else ""
return f"{publisher}::{title}::{stamp}"
def normalize_yfinance_article(article: dict, *, fallback_symbol: str | None = None, country: str | None = None) -> NewsItem:
    """Convert a raw yfinance article dict into a NewsItem.

    Symbols are upper-cased; *fallback_symbol* is appended when not already
    present among the article's related tickers.
    """
    fields = _extract_article_fields(article)
    symbols: list[str] = []
    for raw in fields["raw_symbols"]:
        if str(raw).strip():
            symbols.append(str(raw).upper())
    if fallback_symbol:
        fallback_upper = fallback_symbol.upper()
        if fallback_upper not in symbols:
            symbols.append(fallback_upper)
    return NewsItem(
        title=fields["title"],
        source=fields["publisher"],
        published_at=fields["pub_date"],
        language=None,
        country=country,
        symbols=symbols,
        topic_tags=[],
        sentiment=None,
        relevance=None,
        reliability=None,
        url=fields["link"],
        summary=fields["summary"],
        raw_vendor="yfinance",
    )
def _collect_ticker_news(
ticker: str,
start_dt: datetime,
) -> tuple[list[dict], datetime | None, datetime | None]:
) -> tuple[list[NewsItem], datetime | None, datetime | None]:
"""Fetch increasingly larger ticker feeds until the requested window is covered."""
collected: list[dict] = []
seen: set[str] = set()
collected: list[NewsItem] = []
oldest_pub_date = None
newest_pub_date = None
@ -104,15 +102,13 @@ def _collect_ticker_news(
if not news:
continue
for article in news:
data = _extract_article_data(article)
identity = _article_identity(data)
if identity in seen:
continue
seen.add(identity)
collected.append(data)
batch = dedupe_news_items(
[normalize_yfinance_article(article, fallback_symbol=ticker) for article in news]
)
pub_date = data.get("pub_date")
for item in batch:
collected.append(item)
pub_date = item.published_at
if pub_date:
if newest_pub_date is None or pub_date > newest_pub_date:
newest_pub_date = pub_date
@ -124,15 +120,15 @@ def _collect_ticker_news(
if len(news) < count:
break
collected = dedupe_news_items(collected)
collected.sort(
key=lambda article: article["pub_date"].timestamp() if article.get("pub_date") else float("-inf"),
key=lambda article: article.published_at.timestamp() if article.published_at else float("-inf"),
reverse=True,
)
return collected, oldest_pub_date, newest_pub_date
def _format_coverage_note(oldest_pub_date: datetime | None, newest_pub_date: datetime | None) -> str:
"""Describe the yfinance coverage window when no article matches the requested range."""
if oldest_pub_date and newest_pub_date:
return (
"; the current yfinance ticker feed only covered "
@ -145,152 +141,137 @@ def _format_coverage_note(oldest_pub_date: datetime | None, newest_pub_date: dat
return ""
def fetch_company_news_yfinance(
    ticker: str,
    start_date: str,
    end_date: str,
) -> tuple[list[NewsItem], datetime | None, datetime | None]:
    """Fetch ticker news filtered to [start_date, end_date] plus feed coverage bounds.

    Args:
        ticker: Stock ticker symbol (e.g., "AAPL").
        start_date: Start date in yyyy-mm-dd format (inclusive).
        end_date: End date in yyyy-mm-dd format (inclusive; +1 day internally).

    Returns:
        (filtered items capped at _MAX_FILTERED_TICKER_ARTICLES,
         oldest feed pub date, newest feed pub date).
    """
    # A dangling, pre-rename "def get_news_yfinance(" header left by a bad
    # merge has been removed from this span.
    start_dt = datetime.strptime(start_date, "%Y-%m-%d")
    end_dt = datetime.strptime(end_date, "%Y-%m-%d") + relativedelta(days=1)
    articles, oldest_pub_date, newest_pub_date = _collect_ticker_news(ticker, start_dt)
    filtered = filter_news_items_by_date(articles, start_date=start_dt, end_date=end_dt)
    return filtered[:_MAX_FILTERED_TICKER_ARTICLES], oldest_pub_date, newest_pub_date
def get_company_news_yfinance(
    ticker: str,
    start_date: str,
    end_date: str,
) -> str:
    """
    Retrieve news for a specific stock ticker using yfinance.

    Args:
        ticker: Stock ticker symbol (e.g., "AAPL")
        start_date: Start date in yyyy-mm-dd format
        end_date: End date in yyyy-mm-dd format

    Returns:
        Formatted string containing news articles, or a coverage/error note.
    """
    # This span previously interleaved the removed manual-formatting body with
    # the new report-helper body (merge/diff artifact); only the new flow remains.
    try:
        filtered, oldest_pub_date, newest_pub_date = fetch_company_news_yfinance(ticker, start_date, end_date)
        if not filtered:
            coverage_note = _format_coverage_note(oldest_pub_date, newest_pub_date)
            return f"No news found for {ticker} between {start_date} and {end_date}{coverage_note}"
        return format_news_items_report(
            f"{ticker} Company News, from {start_date} to {end_date}",
            filtered,
            max_items=_MAX_FILTERED_TICKER_ARTICLES,
        )
    except Exception as exc:
        return f"Error fetching news for {ticker}: {exc}"
def _get_query_preset(region: str | None) -> list[str]:
    """Return the search-query preset for *region*, falling back to GLOBAL.

    A dangling, pre-rename "def get_global_news_yfinance(" header left by a
    bad merge has been removed from this span.
    """
    if not region:
        return _GLOBAL_QUERY_PRESETS["GLOBAL"]
    return _GLOBAL_QUERY_PRESETS.get(region.upper(), _GLOBAL_QUERY_PRESETS["GLOBAL"])
def fetch_macro_news_yfinance(
    curr_date: str,
    look_back_days: int = 7,
    limit: int = 10,
    region: str | None = None,
    language: str | None = None,
) -> list[NewsItem]:
    """Collect macro/global news items around *curr_date* via yfinance Search.

    Args:
        curr_date: Anchor date (yyyy-mm-dd); items after curr_date + 1 day are dropped.
        look_back_days: Size of the backward window.
        limit: Maximum number of items returned.
        region: Preset key ("US", "KR", ...); unknown/None falls back to GLOBAL.
        language: Optional language hint appended to each search query.

    Returns:
        Deduplicated, date-filtered items, newest first, capped at *limit*.
        (This span previously interleaved the removed implementation with the
        new one — a merge/diff artifact now resolved.)
    """
    curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
    start_dt = curr_dt - relativedelta(days=look_back_days)
    country = (region or "GLOBAL").upper()
    queries = _get_query_preset(region)  # hoisted: loop-invariant
    all_news: list[NewsItem] = []
    for query in queries:
        search = yf_retry(
            lambda q=query: yf.Search(
                query=q if not language else f"{q} {language}",
                news_count=limit,
                enable_fuzzy_query=True,
            )
        )
        search_news = getattr(search, "news", None) or []
        all_news.extend(
            normalize_yfinance_article(article, country=country) for article in search_news
        )
        if len(all_news) >= limit * len(queries):
            break
    filtered = []
    for item in dedupe_news_items(all_news):
        if item.published_at:
            published = item.published_at.replace(tzinfo=None)
            # Backward-window filter plus look-ahead guard (no future articles).
            if published < start_dt or published > curr_dt + relativedelta(days=1):
                continue
        filtered.append(item)
    filtered.sort(
        key=lambda item: item.published_at.timestamp() if item.published_at else float("-inf"),
        reverse=True,
    )
    return filtered[:limit]
def get_macro_news_yfinance(
    curr_date: str,
    look_back_days: int = 7,
    limit: int = 10,
    region: str | None = None,
    language: str | None = None,
) -> str:
    """Render the macro/global news feed around *curr_date* as a report string."""
    try:
        items = fetch_macro_news_yfinance(
            curr_date,
            look_back_days=look_back_days,
            limit=limit,
            region=region,
            language=language,
        )
        if not items:
            return f"No global news found for {curr_date}"
        window_start = datetime.strptime(curr_date, "%Y-%m-%d") - relativedelta(days=look_back_days)
        start_date = window_start.strftime("%Y-%m-%d")
        region_label = (region or "GLOBAL").upper()
        return format_news_items_report(
            f"{region_label} Macro News, from {start_date} to {curr_date}",
            items,
            max_items=limit,
        )
    except Exception as exc:
        return f"Error fetching global news: {exc}"
# Calculate date range
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
start_dt = curr_dt - relativedelta(days=look_back_days)
start_date = start_dt.strftime("%Y-%m-%d")
news_str = ""
for article in all_news[:limit]:
# Handle both flat and nested structures
if "content" in article:
data = _extract_article_data(article)
# Skip articles published after curr_date (look-ahead guard)
if data.get("pub_date"):
pub_naive = data["pub_date"].replace(tzinfo=None) if hasattr(data["pub_date"], "replace") else data["pub_date"]
if pub_naive > curr_dt + relativedelta(days=1):
continue
title = data["title"]
publisher = data["publisher"]
link = data["link"]
summary = data["summary"]
else:
title = article.get("title", "No title")
publisher = article.get("publisher", "Unknown")
link = article.get("link", "")
summary = ""
def get_social_sentiment_yfinance(
    symbol: str,
    start_date: str,
    end_date: str,
) -> str:
    """Build a news-derived sentiment report when no dedicated social provider exists.

    This span previously interleaved removed report-formatting lines with the
    new implementation (merge/diff artifact); only the new flow remains.
    """
    articles, _, _ = fetch_company_news_yfinance(symbol, start_date, end_date)
    if not articles:
        return (
            f"Dedicated social provider unavailable; no news-derived sentiment was found for {symbol} "
            f"between {start_date} and {end_date}."
        )
    report_lines = [
        f"Dedicated social provider unavailable; using news-derived sentiment for {symbol} from {start_date} to {end_date}.",
        "Use this as public-narrative context rather than a literal social-media feed.",
        "",
    ]
    for item in articles[:10]:
        date_prefix = item.published_at.strftime("%Y-%m-%d") if item.published_at else "undated"
        summary = item.summary or "No summary available."
        report_lines.append(f"- {date_prefix}: {item.title} ({item.source})")
        report_lines.append(f"  Narrative: {summary}")
    return "\n".join(report_lines)
# Backward-compatible aliases
get_news_yfinance = get_company_news_yfinance
get_global_news_yfinance = get_macro_news_yfinance

View File

@ -29,19 +29,31 @@ DEFAULT_CONFIG = {
# Internal agent debate stays in English for reasoning quality
"output_language": "English",
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
"max_debate_rounds": 2,
"max_risk_discuss_rounds": 2,
"max_recur_limit": 100,
"market_country": "US",
"timezone": "US/Eastern",
"enable_no_trade": True,
"vendor_timeout": 15,
"empty_result_fallback": True,
"memory_n_matches": 3,
# Data vendor configuration
# Category-level configuration (default for all tools in category)
"data_vendors": {
"core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance
"technical_indicators": "yfinance", # Options: alpha_vantage, yfinance
"fundamental_data": "yfinance", # Options: alpha_vantage, yfinance
"news_data": "yfinance", # Options: alpha_vantage, yfinance
"news_data": "alpha_vantage,yfinance", # Options: alpha_vantage, yfinance, naver
"macro_data": "alpha_vantage,yfinance", # Options: alpha_vantage, yfinance, ecos
"disclosure_data": "opendart", # Options: opendart
"social_data": "yfinance", # Options: yfinance, naver
},
# Tool-level configuration (takes precedence over category-level)
"tool_vendors": {
# Example: "get_stock_data": "alpha_vantage", # Override category default
# Example: "get_company_news": "naver,yfinance", # Override category default
# Example: "get_macro_news": "ecos,alpha_vantage,yfinance",
# Example: "get_stock_data": "alpha_vantage",
},
"api_keys_path": str(Path(__file__).resolve().parents[1] / "Docs" / "list_api_keys.md"),
}

View File

@ -0,0 +1 @@
"""Evaluation helpers for TradingAgents."""

View File

@ -0,0 +1,155 @@
from __future__ import annotations
import argparse
import json
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path
from typing import Iterable
import pandas as pd
import yfinance as yf
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
# Signed portfolio exposure implied by each decision rating, in [-1.0, 1.0].
# Unknown ratings fall back to 0.0 at the lookup site (RATING_TO_EXPOSURE.get).
RATING_TO_EXPOSURE = {
    "BUY": 1.0,
    "OVERWEIGHT": 0.5,
    "HOLD": 0.0,
    "UNDERWEIGHT": -0.5,
    "SELL": -1.0,
    "NO_TRADE": 0.0,
}
def _fetch_forward_return(symbol: str, trade_date: str, holding_period: int) -> float | None:
    """Forward simple return from *trade_date* over ~*holding_period* sessions, or None."""
    window_start = datetime.strptime(trade_date, "%Y-%m-%d")
    # Over-fetch calendar days so enough trading sessions land in the window.
    window_end = window_start + timedelta(days=max(holding_period * 3, 10))
    history = yf.Ticker(symbol).history(start=trade_date, end=window_end.strftime("%Y-%m-%d"))
    if history.empty or "Close" not in history:
        return None
    closes = history["Close"].dropna()
    if len(closes) < 2:
        return None
    entry_price = float(closes.iloc[0])
    exit_price = float(closes.iloc[min(holding_period, len(closes) - 1)])
    return exit_price / entry_price - 1.0
def _compute_max_drawdown(returns: Iterable[float]) -> float:
cumulative = pd.Series(list(returns)).fillna(0.0).add(1.0).cumprod()
running_max = cumulative.cummax()
drawdown = (cumulative / running_max) - 1.0
return float(drawdown.min()) if not drawdown.empty else 0.0
def run_walk_forward_evaluation(
    symbols: list[str],
    trade_dates: list[str],
    *,
    holding_period: int = 5,
    benchmark_symbol: str = "SPY",
    graph_config: dict | None = None,
    selected_analysts: list[str] | None = None,
    enable_reflection: bool = False,
) -> dict:
    """Run the trading graph over each (trade_date, symbol) pair and score results.

    Args:
        symbols: Instrument inputs to evaluate (e.g. "AAPL", "005930").
        trade_dates: Trade dates in YYYY-MM-DD format, processed in order.
        holding_period: Forward holding period (trading days) for return measurement.
        benchmark_symbol: Ticker used for excess-return comparison.
        graph_config: Optional config overriding DEFAULT_CONFIG (deep-copied).
        selected_analysts: Analyst subset; defaults to all four analyst roles.
        enable_reflection: When True, feed each realized strategy return back
            into graph.reflect_and_remember after the trade is scored.

    Returns:
        {"records": [...per-trade rows...], "metrics": {...aggregates...}};
        both are empty when no trade produced a measurable forward return.
    """
    config = deepcopy(graph_config or DEFAULT_CONFIG)
    graph = TradingAgentsGraph(
        config=config,
        selected_analysts=selected_analysts or ["market", "social", "news", "fundamentals"],
    )
    records: list[dict] = []
    previous_exposure = 0.0
    for trade_date in trade_dates:
        # One benchmark fetch per date, reused for every symbol on that date.
        benchmark_return = _fetch_forward_return(benchmark_symbol, trade_date, holding_period)
        for symbol in symbols:
            final_state, rating = graph.propagate(symbol, trade_date)
            asset_return = _fetch_forward_return(final_state["company_of_interest"], trade_date, holding_period)
            if asset_return is None:
                # No price data for the window: skip without recording a trade.
                continue
            exposure = RATING_TO_EXPOSURE.get(rating, 0.0)
            strategy_return = exposure * asset_return
            # NOTE(review): previous_exposure carries across symbols and dates,
            # so "turnover" mixes instruments when len(symbols) > 1 — confirm
            # this is intended.
            turnover = abs(exposure - previous_exposure)
            previous_exposure = exposure
            if enable_reflection:
                graph.reflect_and_remember(strategy_return)
            country = (final_state.get("instrument_profile") or {}).get("country", "UNKNOWN")
            records.append(
                {
                    "symbol": final_state["company_of_interest"],
                    "input_instrument": final_state.get("input_instrument", symbol),
                    "country": country,
                    "trade_date": trade_date,
                    "rating": rating,
                    "asset_return": asset_return,
                    "strategy_return": strategy_return,
                    "benchmark_return": benchmark_return,
                    "excess_return": None if benchmark_return is None else strategy_return - benchmark_return,
                    "turnover": turnover,
                }
            )
    if not records:
        return {"records": [], "metrics": {}}
    df = pd.DataFrame(records)
    # Mean realized asset return grouped by rating bucket.
    bucket_metrics = df.groupby("rating")["asset_return"].mean().to_dict()
    # Per-country mean strategy return and trade count.
    region_metrics = (
        df.groupby("country")["strategy_return"]
        .agg(["mean", "count"])
        .rename(columns={"mean": "avg_strategy_return"})
        .to_dict(orient="index")
    )
    metrics = {
        "hit_rate": float((df["strategy_return"] > 0).mean()),
        "forward_return_by_rating_bucket": bucket_metrics,
        "turnover": float(df["turnover"].mean()),
        "max_drawdown": _compute_max_drawdown(df["strategy_return"].tolist()),
        "benchmark_excess_return": float(df["excess_return"].dropna().mean()) if df["excess_return"].notna().any() else None,
        "abstain_frequency": float((df["rating"] == "NO_TRADE").mean()),
        "region_split_metrics": region_metrics,
    }
    return {"records": records, "metrics": metrics}
def main():
    """CLI entry point: parse arguments, run the walk-forward evaluation, emit JSON."""
    parser = argparse.ArgumentParser(description="Run a simple walk-forward evaluation for TradingAgents.")
    parser.add_argument("--symbols", nargs="+", required=True, help="Instrument inputs, such as AAPL or 005930")
    parser.add_argument("--trade-dates", nargs="+", required=True, help="Trade dates in YYYY-MM-DD format")
    parser.add_argument("--holding-period", type=int, default=5, help="Forward holding period in trading days")
    parser.add_argument("--benchmark", default="SPY", help="Benchmark ticker for excess-return comparison")
    parser.add_argument("--enable-reflection", action="store_true", help="Call reflect_and_remember after each evaluated trade")
    parser.add_argument("--output", default=None, help="Optional JSON output path")
    args = parser.parse_args()
    result = run_walk_forward_evaluation(
        symbols=args.symbols,
        trade_dates=args.trade_dates,
        holding_period=args.holding_period,
        benchmark_symbol=args.benchmark,
        enable_reflection=args.enable_reflection,
    )
    # ensure_ascii=False keeps non-ASCII symbols/names readable in the output.
    rendered = json.dumps(result, indent=2, ensure_ascii=False)
    if args.output:
        output_path = Path(args.output)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(rendered, encoding="utf-8")
    else:
        print(rendered)


if __name__ == "__main__":
    main()

View File

@ -1,17 +1,19 @@
# TradingAgents/graph/propagation.py
from typing import Dict, Any, List, Optional
from typing import Any, Dict, List, Optional
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
RiskDebateState,
)
from tradingagents.agents.utils.instrument_resolver import resolve_instrument
class Propagator:
"""Handles state initialization and propagation through the graph."""
def __init__(self, max_recur_limit=100):
def __init__(self, max_recur_limit: int = 100):
"""Initialize with configuration parameters."""
self.max_recur_limit = max_recur_limit
@ -22,9 +24,12 @@ class Propagator:
analysis_date: str | None = None,
) -> Dict[str, Any]:
"""Create the initial state for the agent graph."""
instrument_profile = resolve_instrument(company_name)
return {
"messages": [("human", company_name)],
"company_of_interest": company_name,
"messages": [("human", instrument_profile.primary_symbol)],
"input_instrument": company_name,
"company_of_interest": instrument_profile.primary_symbol,
"instrument_profile": instrument_profile.to_dict(),
"trade_date": str(trade_date),
"analysis_date": str(analysis_date or trade_date),
"investment_debate_state": InvestDebateState(

View File

@ -2,6 +2,8 @@
from typing import Any, Dict
from tradingagents.schemas import parse_structured_decision
class Reflector:
"""Handles reflection on decisions and updating memory."""
@ -93,28 +95,43 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
"""Reflect on trader's decision and update memory."""
situation = self._extract_current_situation(current_state)
trader_decision = current_state["trader_investment_plan"]
metadata = {"returns_losses": returns_losses, "component": "trader"}
try:
metadata["rating"] = parse_structured_decision(trader_decision).rating.value
except Exception:
pass
result = self._reflect_on_component(
"TRADER", trader_decision, situation, returns_losses
)
trader_memory.add_situations([(situation, result)])
trader_memory.add_situations([(situation, result, metadata)])
def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory):
    """Reflect on investment judge's decision and update memory.

    Args:
        current_state: Final graph state containing the investment debate outcome.
        returns_losses: Realized returns/losses used to ground the reflection.
        invest_judge_memory: Memory store receiving (situation, reflection, metadata).
    """
    situation = self._extract_current_situation(current_state)
    judge_decision = current_state["investment_debate_state"]["judge_decision"]
    metadata = {"returns_losses": returns_losses, "component": "research_manager"}
    # Best-effort: tag the memory with the structured rating when the decision
    # parses; a malformed payload must not block reflection, hence the broad except.
    try:
        metadata["rating"] = parse_structured_decision(judge_decision).rating.value
    except Exception:
        pass
    result = self._reflect_on_component(
        "INVEST JUDGE", judge_decision, situation, returns_losses
    )
    invest_judge_memory.add_situations([(situation, result, metadata)])
def reflect_portfolio_manager(self, current_state, returns_losses, portfolio_manager_memory):
    """Reflect on portfolio manager's decision and update memory.

    Args:
        current_state: Final graph state containing the risk debate outcome.
        returns_losses: Realized returns/losses used to ground the reflection.
        portfolio_manager_memory: Memory store receiving (situation, reflection, metadata).
    """
    situation = self._extract_current_situation(current_state)
    judge_decision = current_state["risk_debate_state"]["judge_decision"]
    metadata = {"returns_losses": returns_losses, "component": "portfolio_manager"}
    # Best-effort: tag the memory with the structured rating when the decision
    # parses; a malformed payload must not block reflection, hence the broad except.
    try:
        metadata["rating"] = parse_structured_decision(judge_decision).rating.value
    except Exception:
        pass
    result = self._reflect_on_component(
        "PORTFOLIO MANAGER", judge_decision, situation, returns_losses
    )
    portfolio_manager_memory.add_situations([(situation, result, metadata)])

View File

@ -1,33 +1,14 @@
# TradingAgents/graph/signal_processing.py
from typing import Any
from tradingagents.schemas import parse_structured_decision
class SignalProcessor:
    """Processes structured trading signals deterministically.

    This span previously interleaved the removed LLM-extraction implementation
    with the new deterministic parser (duplicate __init__ and process_signal
    bodies from a bad merge); only the new implementation remains.
    """

    def __init__(self, quick_thinking_llm):
        # The LLM handle is kept for interface compatibility even though
        # signal extraction no longer invokes it.
        self.quick_thinking_llm = quick_thinking_llm

    def process_signal(self, full_signal: str) -> str:
        """
        Extract the rating from a structured decision payload.

        Args:
            full_signal: Complete trading signal text (structured JSON decision).

        Returns:
            Rating string (BUY, OVERWEIGHT, HOLD, UNDERWEIGHT, SELL, or NO_TRADE).
        """
        decision = parse_structured_decision(full_signal)
        return decision.rating.value

View File

@ -19,6 +19,7 @@ from tradingagents.agents.utils.agent_states import (
RiskDebateState,
)
from tradingagents.dataflows.config import set_config
from tradingagents.schemas import StructuredDecisionValidationError, parse_structured_decision
# Import the new abstract tool methods from agent_utils
from tradingagents.agents.utils.agent_utils import (
@ -28,9 +29,13 @@ from tradingagents.agents.utils.agent_utils import (
get_balance_sheet,
get_cashflow,
get_income_statement,
get_company_news,
get_disclosures,
get_macro_news,
get_news,
get_insider_transactions,
get_global_news,
get_social_sentiment,
get_output_language,
rewrite_in_output_language,
)
@ -123,7 +128,7 @@ class TradingAgentsGraph:
self.conditional_logic,
)
self.propagator = Propagator()
self.propagator = Propagator(self.config["max_recur_limit"])
self.reflector = Reflector(self.quick_thinking_llm)
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
@ -179,15 +184,17 @@ class TradingAgentsGraph:
),
"social": ToolNode(
[
# News tools for social media analysis
get_news,
# Dedicated or news-derived sentiment tools
get_social_sentiment,
get_company_news,
]
),
"news": ToolNode(
[
# News and insider information
get_news,
get_global_news,
# News, macro, and disclosure information
get_company_news,
get_macro_news,
get_disclosures,
get_insider_transactions,
]
),
@ -198,6 +205,7 @@ class TradingAgentsGraph:
get_balance_sheet,
get_cashflow,
get_income_statement,
get_insider_transactions,
]
),
}
@ -205,12 +213,11 @@ class TradingAgentsGraph:
def propagate(self, company_name, trade_date, analysis_date=None):
"""Run the trading agents graph for a company on a specific date."""
self.ticker = company_name
# Initialize state
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date, analysis_date=analysis_date
)
self.ticker = init_agent_state["company_of_interest"]
args = self.propagator.get_graph_args()
if self.debug:
@ -243,7 +250,9 @@ class TradingAgentsGraph:
def _log_state(self, trade_date, final_state):
"""Log the final state to a JSON file."""
self.log_states_dict[str(trade_date)] = {
"input_instrument": final_state.get("input_instrument", final_state["company_of_interest"]),
"company_of_interest": final_state["company_of_interest"],
"instrument_profile": final_state.get("instrument_profile", {}),
"trade_date": final_state["trade_date"],
"analysis_date": final_state.get("analysis_date", final_state["trade_date"]),
"market_report": final_state["market_report"],
@ -311,6 +320,17 @@ class TradingAgentsGraph:
localized = dict(final_state)
def maybe_localize(content: str, *, content_type: str) -> str:
try:
parse_structured_decision(content)
return content
except StructuredDecisionValidationError:
return rewrite_in_output_language(
self.quick_thinking_llm,
content,
content_type=content_type,
)
for field_name, content_type in (
("market_report", "market analyst report"),
("sentiment_report", "social sentiment report"),
@ -320,8 +340,7 @@ class TradingAgentsGraph:
("trader_investment_plan", "trader plan"),
("final_trade_decision", "portfolio manager final decision"),
):
localized[field_name] = rewrite_in_output_language(
self.quick_thinking_llm,
localized[field_name] = maybe_localize(
localized.get(field_name, ""),
content_type=content_type,
)
@ -334,8 +353,7 @@ class TradingAgentsGraph:
("current_response", "investment debate latest response"),
("judge_decision", "research manager decision"),
):
investment_debate[field_name] = rewrite_in_output_language(
self.quick_thinking_llm,
investment_debate[field_name] = maybe_localize(
investment_debate.get(field_name, ""),
content_type=content_type,
)
@ -352,8 +370,7 @@ class TradingAgentsGraph:
("current_neutral_response", "neutral risk analyst latest response"),
("judge_decision", "portfolio manager decision"),
):
risk_debate[field_name] = rewrite_in_output_language(
self.quick_thinking_llm,
risk_debate[field_name] = maybe_localize(
risk_debate.get(field_name, ""),
content_type=content_type,
)

View File

@ -11,6 +11,7 @@ from zoneinfo import ZoneInfo
import yfinance as yf
from tradingagents.agents.utils.instrument_resolver import resolve_instrument
from cli.stats_handler import StatsCallbackHandler
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
@ -119,6 +120,7 @@ def resolve_trade_date(
ticker: str,
config: ScheduledAnalysisConfig,
) -> str:
normalized_symbol = resolve_instrument(ticker).primary_symbol
mode = config.run.trade_date_mode
if mode == "explicit" and config.run.explicit_trade_date:
return config.run.explicit_trade_date
@ -129,14 +131,14 @@ def resolve_trade_date(
if mode == "previous_business_day":
return _previous_business_day(now.date()).isoformat()
history = yf.Ticker(ticker).history(
history = yf.Ticker(normalized_symbol).history(
period=f"{config.run.latest_market_data_lookback_days}d",
interval="1d",
auto_adjust=False,
)
if history.empty:
raise RuntimeError(
f"Could not resolve the latest available trade date for {ticker}; yfinance returned no rows."
f"Could not resolve the latest available trade date for {ticker} ({normalized_symbol}); yfinance returned no rows."
)
last_index = history.index[-1]

View File

@ -0,0 +1,19 @@
from .decision import (
DecisionRating,
StructuredDecision,
StructuredDecisionValidationError,
TimeHorizon,
build_decision_output_instructions,
ensure_structured_decision_json,
parse_structured_decision,
)
__all__ = [
"DecisionRating",
"StructuredDecision",
"StructuredDecisionValidationError",
"TimeHorizon",
"build_decision_output_instructions",
"ensure_structured_decision_json",
"parse_structured_decision",
]

View File

@ -0,0 +1,182 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from enum import Enum
from json import JSONDecodeError
from typing import Any, Mapping
class StructuredDecisionValidationError(ValueError):
    """Signals that a decision payload failed schema validation."""
class DecisionRating(str, Enum):
    """Closed set of ratings a structured decision payload may carry.

    str-mixin: members serialize and compare as their literal uppercase values.
    """

    BUY = "BUY"
    OVERWEIGHT = "OVERWEIGHT"
    HOLD = "HOLD"
    UNDERWEIGHT = "UNDERWEIGHT"
    SELL = "SELL"
    # NOTE(review): presumably an explicit abstention distinct from HOLD — confirm.
    NO_TRADE = "NO_TRADE"
class TimeHorizon(str, Enum):
    """Time-horizon buckets referenced by StructuredDecision.time_horizon."""

    SHORT = "short"
    MEDIUM = "medium"
    LONG = "long"
@dataclass(frozen=True)
class StructuredDecision:
    """Immutable trading decision parsed from an agent's structured JSON output."""

    # Final rating; one of the DecisionRating members (BUY ... NO_TRADE).
    rating: DecisionRating
    # Confidence score; the output instructions request a value in [0, 1] —
    # TODO confirm parsing actually enforces that range.
    confidence: float
    # Intended holding-horizon bucket.
    time_horizon: TimeHorizon
    # Free-text trade rationale fields.
    entry_logic: str
    exit_logic: str
    position_sizing: str
    risk_limits: str
    # Events expected to support / invalidate the thesis.
    catalysts: tuple[str, ...]
    invalidators: tuple[str, ...]

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict: enums as strings, tuples as lists."""
        return {
            "rating": self.rating.value,
            "confidence": self.confidence,
            "time_horizon": self.time_horizon.value,
            "entry_logic": self.entry_logic,
            "exit_logic": self.exit_logic,
            "position_sizing": self.position_sizing,
            "risk_limits": self.risk_limits,
            "catalysts": list(self.catalysts),
            "invalidators": list(self.invalidators),
        }

    def to_json(self, *, indent: int = 2) -> str:
        """Serialize to_dict() as JSON; ensure_ascii=False preserves non-ASCII text."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)
def build_decision_output_instructions(context: str) -> str:
    """Render the JSON-only output contract shown to decision-making agents."""
    schema = (
        '{"rating":"BUY | OVERWEIGHT | HOLD | UNDERWEIGHT | SELL | NO_TRADE",'
        '"confidence":0.0,'
        '"time_horizon":"short | medium | long",'
        '"entry_logic":"...",'
        '"exit_logic":"...",'
        '"position_sizing":"...",'
        '"risk_limits":"...",'
        '"catalysts":["..."],'
        '"invalidators":["..."]}'
    )
    return (
        f"Return only one valid JSON object for the {context}. "
        "Do not wrap it in markdown fences. "
        f"The schema is: {schema}. "
        "Use an uppercase rating, confidence between 0 and 1 inclusive, and concise but specific strings."
    )
def _extract_json_object(payload: str | Mapping[str, Any]) -> Mapping[str, Any]:
if isinstance(payload, Mapping):
return payload
if not isinstance(payload, str) or not payload.strip():
raise StructuredDecisionValidationError("Decision payload must be a non-empty JSON string or mapping.")
text = payload.strip()
if text.startswith("```"):
lines = text.splitlines()
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
text = "\n".join(lines[1:-1]).strip()
try:
parsed = json.loads(text)
if isinstance(parsed, Mapping):
return parsed
except JSONDecodeError:
pass
decoder = json.JSONDecoder()
for index, char in enumerate(text):
if char != "{":
continue
try:
parsed, _ = decoder.raw_decode(text[index:])
except JSONDecodeError:
continue
if isinstance(parsed, Mapping):
return parsed
raise StructuredDecisionValidationError("Could not locate a valid JSON object in the decision payload.")
def _require_string(data: Mapping[str, Any], field_name: str) -> str:
value = data.get(field_name)
if not isinstance(value, str) or not value.strip():
raise StructuredDecisionValidationError(f"Field '{field_name}' must be a non-empty string.")
return value.strip()
def _require_string_list(data: Mapping[str, Any], field_name: str) -> tuple[str, ...]:
value = data.get(field_name)
if not isinstance(value, list):
raise StructuredDecisionValidationError(f"Field '{field_name}' must be a list of strings.")
normalized: list[str] = []
for item in value:
if not isinstance(item, str) or not item.strip():
raise StructuredDecisionValidationError(
f"Field '{field_name}' must contain only non-empty strings."
)
normalized.append(item.strip())
return tuple(normalized)
def parse_structured_decision(payload: str | Mapping[str, Any]) -> StructuredDecision:
    """Parse *payload* into a validated :class:`StructuredDecision`.

    The payload may be a mapping or a JSON string (optionally fenced or
    embedded in surrounding text). Every schema field is required; rating
    is normalized to uppercase and time horizon to lowercase before enum
    lookup, and confidence must be numeric within ``[0, 1]``.

    Raises:
        StructuredDecisionValidationError: on any missing or invalid field.
    """
    data = _extract_json_object(payload)
    required = (
        "rating",
        "confidence",
        "time_horizon",
        "entry_logic",
        "exit_logic",
        "position_sizing",
        "risk_limits",
        "catalysts",
        "invalidators",
    )
    absent = sorted(name for name in required if name not in data)
    if absent:
        missing = ", ".join(absent)
        raise StructuredDecisionValidationError(f"Decision payload is missing required fields: {missing}.")
    # Normalize case before enum lookup so "buy" / "Short" are accepted.
    raw_rating = str(data["rating"]).strip().upper()
    try:
        rating = DecisionRating(raw_rating)
    except ValueError as exc:
        raise StructuredDecisionValidationError(f"Unsupported rating: {data.get('rating')!r}.") from exc
    try:
        confidence = float(data["confidence"])
    except (TypeError, ValueError) as exc:
        raise StructuredDecisionValidationError("Field 'confidence' must be numeric.") from exc
    if not (0.0 <= confidence <= 1.0):
        raise StructuredDecisionValidationError("Field 'confidence' must be between 0 and 1 inclusive.")
    raw_horizon = str(data["time_horizon"]).strip().lower()
    try:
        horizon = TimeHorizon(raw_horizon)
    except ValueError as exc:
        raise StructuredDecisionValidationError(
            f"Unsupported time horizon: {data.get('time_horizon')!r}."
        ) from exc
    return StructuredDecision(
        rating=rating,
        confidence=confidence,
        time_horizon=horizon,
        entry_logic=_require_string(data, "entry_logic"),
        exit_logic=_require_string(data, "exit_logic"),
        position_sizing=_require_string(data, "position_sizing"),
        risk_limits=_require_string(data, "risk_limits"),
        catalysts=_require_string_list(data, "catalysts"),
        invalidators=_require_string_list(data, "invalidators"),
    )
def ensure_structured_decision_json(payload: str | Mapping[str, Any]) -> str:
    """Validate *payload* and return it re-serialized as canonical JSON.

    Raises:
        StructuredDecisionValidationError: if the payload fails validation.
    """
    decision = parse_structured_decision(payload)
    return decision.to_json()