Refine Chinese debate output and data routing
This commit is contained in:
parent
4641c03340
commit
9d9a159d37
|
|
@ -4,3 +4,8 @@ GOOGLE_API_KEY=
|
|||
ANTHROPIC_API_KEY=
|
||||
XAI_API_KEY=
|
||||
OPENROUTER_API_KEY=
|
||||
|
||||
# Data Providers
|
||||
ALPHA_VANTAGE_API_KEY=
|
||||
TUSHARE_TOKEN=
|
||||
BRAVE_SEARCH_API_KEY=
|
||||
|
|
|
|||
|
|
@ -217,3 +217,7 @@ __marimo__/
|
|||
|
||||
# Cache
|
||||
**/data_cache/
|
||||
|
||||
# Local generated outputs
|
||||
results/
|
||||
reports/
|
||||
|
|
|
|||
|
|
@ -128,7 +128,8 @@ export GOOGLE_API_KEY=... # Google (Gemini)
|
|||
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
|
||||
export XAI_API_KEY=... # xAI (Grok)
|
||||
export OPENROUTER_API_KEY=... # OpenRouter
|
||||
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
||||
export TUSHARE_TOKEN=... # Tushare (A-share / HK / US price and fundamentals)
|
||||
export BRAVE_SEARCH_API_KEY=... # Brave Search (news search)
|
||||
```
|
||||
|
||||
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
|
||||
|
|
|
|||
220
cli/main.py
220
cli/main.py
|
|
@ -3,6 +3,7 @@ import datetime
|
|||
import typer
|
||||
from pathlib import Path
|
||||
from functools import wraps
|
||||
import re
|
||||
from rich.console import Console
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
|
@ -29,6 +30,13 @@ from cli.models import AnalystType
|
|||
from cli.utils import *
|
||||
from cli.announcements import fetch_announcements, display_announcements
|
||||
from cli.stats_handler import StatsCallbackHandler
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
extract_feedback_snapshot,
|
||||
get_output_language,
|
||||
is_feedback_snapshot_inferred,
|
||||
localize_role_name,
|
||||
strip_feedback_snapshot,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -229,6 +237,136 @@ class MessageBuffer:
|
|||
message_buffer = MessageBuffer()
|
||||
|
||||
|
||||
RESEARCH_SPEAKER_ALIASES = {
|
||||
"Bull Researcher": ("Bull Researcher", "Bull Analyst", "多头分析师"),
|
||||
"Bear Researcher": ("Bear Researcher", "Bear Analyst", "空头分析师"),
|
||||
}
|
||||
|
||||
RISK_SPEAKER_ALIASES = {
|
||||
"Aggressive Analyst": ("Aggressive Analyst", "激进分析师"),
|
||||
"Conservative Analyst": ("Conservative Analyst", "保守分析师"),
|
||||
"Neutral Analyst": ("Neutral Analyst", "中性分析师"),
|
||||
}
|
||||
|
||||
DISPLAY_ROLE_NAMES = {
|
||||
"Bull Researcher": "多头分析师",
|
||||
"Bear Researcher": "空头分析师",
|
||||
"Research Manager": "研究经理",
|
||||
"Aggressive Analyst": "激进分析师",
|
||||
"Conservative Analyst": "保守分析师",
|
||||
"Neutral Analyst": "中性分析师",
|
||||
"Portfolio Manager": "投资组合经理",
|
||||
}
|
||||
|
||||
|
||||
def _build_speaker_pattern(speaker_aliases: dict[str, tuple[str, ...]]) -> re.Pattern[str]:
|
||||
aliases = []
|
||||
for names in speaker_aliases.values():
|
||||
aliases.extend(names)
|
||||
escaped = sorted((re.escape(name) for name in aliases), key=len, reverse=True)
|
||||
return re.compile(rf"(?m)^\s*({'|'.join(escaped)})\s*[::]\s*")
|
||||
|
||||
|
||||
def _split_history_into_turns(history: str, speaker_aliases: dict[str, tuple[str, ...]]) -> list[str]:
|
||||
if not history or not history.strip():
|
||||
return []
|
||||
|
||||
pattern = _build_speaker_pattern(speaker_aliases)
|
||||
matches = list(pattern.finditer(history))
|
||||
if not matches:
|
||||
return [history.strip()]
|
||||
|
||||
turns = []
|
||||
for idx, match in enumerate(matches):
|
||||
start = match.start()
|
||||
end = matches[idx + 1].start() if idx + 1 < len(matches) else len(history)
|
||||
turn = history[start:end].strip()
|
||||
if turn:
|
||||
turns.append(turn)
|
||||
return turns
|
||||
|
||||
|
||||
def _format_grouped_rounds(
|
||||
histories: dict[str, str],
|
||||
speaker_aliases: dict[str, tuple[str, ...]],
|
||||
manager_title: Optional[str] = None,
|
||||
manager_content: str = "",
|
||||
) -> str:
|
||||
output_language = get_output_language().strip().lower()
|
||||
is_chinese = output_language in {"chinese", "中文", "zh", "zh-cn", "zh-hans"}
|
||||
turns_by_speaker = {
|
||||
speaker: _split_history_into_turns(histories.get(speaker, ""), speaker_aliases)
|
||||
for speaker in speaker_aliases
|
||||
}
|
||||
max_rounds = max((len(turns) for turns in turns_by_speaker.values()), default=0)
|
||||
|
||||
parts = []
|
||||
for round_index in range(max_rounds):
|
||||
round_parts = []
|
||||
for speaker, turns in turns_by_speaker.items():
|
||||
if round_index < len(turns):
|
||||
turn = turns[round_index]
|
||||
argument_body = strip_feedback_snapshot(turn)
|
||||
snapshot = extract_feedback_snapshot(turn)
|
||||
speaker_title = DISPLAY_ROLE_NAMES.get(speaker, speaker) if is_chinese else speaker
|
||||
speaker_parts = [f"#### {speaker_title}"]
|
||||
if argument_body:
|
||||
speaker_parts.append(argument_body)
|
||||
if snapshot:
|
||||
inferred_snapshot = is_feedback_snapshot_inferred(turn)
|
||||
if is_chinese:
|
||||
snapshot_title = "自动复盘" if inferred_snapshot else "本轮复盘"
|
||||
else:
|
||||
snapshot_title = "Auto Review" if inferred_snapshot else "Round Review"
|
||||
speaker_parts.append(f"##### {snapshot_title}\n" + snapshot)
|
||||
round_parts.append("\n\n".join(speaker_parts))
|
||||
if round_parts:
|
||||
round_title = f"第 {round_index + 1} 轮" if is_chinese else f"Round {round_index + 1}"
|
||||
parts.append(f"### {round_title}\n\n" + "\n\n".join(round_parts))
|
||||
|
||||
if manager_title and manager_content and manager_content.strip():
|
||||
parts.append(f"### {manager_title}\n{manager_content.strip()}")
|
||||
|
||||
return "\n\n".join(parts).strip()
|
||||
|
||||
|
||||
def format_research_team_history(debate_state: dict) -> str:
|
||||
output_language = get_output_language().strip().lower()
|
||||
manager_title = (
|
||||
"研究经理结论"
|
||||
if output_language in {"chinese", "中文", "zh", "zh-cn", "zh-hans"}
|
||||
else "Research Manager Decision"
|
||||
)
|
||||
return _format_grouped_rounds(
|
||||
{
|
||||
"Bull Researcher": debate_state.get("bull_history", ""),
|
||||
"Bear Researcher": debate_state.get("bear_history", ""),
|
||||
},
|
||||
RESEARCH_SPEAKER_ALIASES,
|
||||
manager_title=manager_title,
|
||||
manager_content=debate_state.get("judge_decision", ""),
|
||||
)
|
||||
|
||||
|
||||
def format_risk_management_history(risk_state: dict) -> str:
|
||||
output_language = get_output_language().strip().lower()
|
||||
manager_title = (
|
||||
"投资组合经理结论"
|
||||
if output_language in {"chinese", "中文", "zh", "zh-cn", "zh-hans"}
|
||||
else "Portfolio Manager Decision"
|
||||
)
|
||||
return _format_grouped_rounds(
|
||||
{
|
||||
"Aggressive Analyst": risk_state.get("aggressive_history", ""),
|
||||
"Conservative Analyst": risk_state.get("conservative_history", ""),
|
||||
"Neutral Analyst": risk_state.get("neutral_history", ""),
|
||||
},
|
||||
RISK_SPEAKER_ALIASES,
|
||||
manager_title=manager_title,
|
||||
manager_content=risk_state.get("judge_decision", ""),
|
||||
)
|
||||
|
||||
|
||||
def create_layout():
|
||||
layout = Layout()
|
||||
layout.split_column(
|
||||
|
|
@ -679,9 +817,14 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
|||
if debate.get("judge_decision"):
|
||||
research_dir.mkdir(exist_ok=True)
|
||||
(research_dir / "manager.md").write_text(debate["judge_decision"])
|
||||
research_parts.append(("Research Manager", debate["judge_decision"]))
|
||||
formatted_research = format_research_team_history(debate)
|
||||
if formatted_research:
|
||||
research_dir.mkdir(exist_ok=True)
|
||||
(research_dir / "rounds.md").write_text(formatted_research)
|
||||
if research_parts:
|
||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts)
|
||||
content = formatted_research or "\n\n".join(
|
||||
f"### {name}\n{text}" for name, text in research_parts
|
||||
)
|
||||
sections.append(f"## II. Research Team Decision\n\n{content}")
|
||||
|
||||
# 3. Trading
|
||||
|
|
@ -708,8 +851,14 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
|||
risk_dir.mkdir(exist_ok=True)
|
||||
(risk_dir / "neutral.md").write_text(risk["neutral_history"])
|
||||
risk_parts.append(("Neutral Analyst", risk["neutral_history"]))
|
||||
formatted_risk = format_risk_management_history(risk)
|
||||
if formatted_risk:
|
||||
risk_dir.mkdir(exist_ok=True)
|
||||
(risk_dir / "rounds.md").write_text(formatted_risk)
|
||||
if risk_parts:
|
||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts)
|
||||
content = formatted_risk or "\n\n".join(
|
||||
f"### {name}\n{text}" for name, text in risk_parts
|
||||
)
|
||||
sections.append(f"## IV. Risk Management Team Decision\n\n{content}")
|
||||
|
||||
# 5. Portfolio Manager
|
||||
|
|
@ -748,17 +897,17 @@ def display_complete_report(final_state):
|
|||
# II. Research Team Reports
|
||||
if final_state.get("investment_debate_state"):
|
||||
debate = final_state["investment_debate_state"]
|
||||
research = []
|
||||
if debate.get("bull_history"):
|
||||
research.append(("Bull Researcher", debate["bull_history"]))
|
||||
if debate.get("bear_history"):
|
||||
research.append(("Bear Researcher", debate["bear_history"]))
|
||||
if debate.get("judge_decision"):
|
||||
research.append(("Research Manager", debate["judge_decision"]))
|
||||
if research:
|
||||
formatted_research = format_research_team_history(debate)
|
||||
if formatted_research:
|
||||
console.print(Panel("[bold]II. Research Team Decision[/bold]", border_style="magenta"))
|
||||
for title, content in research:
|
||||
console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2)))
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(formatted_research),
|
||||
title="Research Team",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
# III. Trading Team
|
||||
if final_state.get("trader_investment_plan"):
|
||||
|
|
@ -768,17 +917,17 @@ def display_complete_report(final_state):
|
|||
# IV. Risk Management Team
|
||||
if final_state.get("risk_debate_state"):
|
||||
risk = final_state["risk_debate_state"]
|
||||
risk_reports = []
|
||||
if risk.get("aggressive_history"):
|
||||
risk_reports.append(("Aggressive Analyst", risk["aggressive_history"]))
|
||||
if risk.get("conservative_history"):
|
||||
risk_reports.append(("Conservative Analyst", risk["conservative_history"]))
|
||||
if risk.get("neutral_history"):
|
||||
risk_reports.append(("Neutral Analyst", risk["neutral_history"]))
|
||||
if risk_reports:
|
||||
formatted_risk = format_risk_management_history(risk)
|
||||
if formatted_risk:
|
||||
console.print(Panel("[bold]IV. Risk Management Team Decision[/bold]", border_style="red"))
|
||||
for title, content in risk_reports:
|
||||
console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2)))
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(formatted_risk),
|
||||
title="Risk Management Team",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
# V. Portfolio Manager Decision
|
||||
if risk.get("judge_decision"):
|
||||
|
|
@ -1084,22 +1233,16 @@ def run_analysis():
|
|||
bull_hist = debate_state.get("bull_history", "").strip()
|
||||
bear_hist = debate_state.get("bear_history", "").strip()
|
||||
judge = debate_state.get("judge_decision", "").strip()
|
||||
formatted_research = format_research_team_history(debate_state)
|
||||
|
||||
# Only update status when there's actual content
|
||||
if bull_hist or bear_hist:
|
||||
update_research_team_status("in_progress")
|
||||
if bull_hist:
|
||||
if formatted_research:
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan", f"### Bull Researcher Analysis\n{bull_hist}"
|
||||
)
|
||||
if bear_hist:
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan", f"### Bear Researcher Analysis\n{bear_hist}"
|
||||
"investment_plan", formatted_research
|
||||
)
|
||||
if judge:
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan", f"### Research Manager Decision\n{judge}"
|
||||
)
|
||||
update_research_team_status("completed")
|
||||
message_buffer.update_agent_status("Trader", "in_progress")
|
||||
|
||||
|
|
@ -1119,31 +1262,24 @@ def run_analysis():
|
|||
con_hist = risk_state.get("conservative_history", "").strip()
|
||||
neu_hist = risk_state.get("neutral_history", "").strip()
|
||||
judge = risk_state.get("judge_decision", "").strip()
|
||||
formatted_risk = format_risk_management_history(risk_state)
|
||||
|
||||
if agg_hist:
|
||||
if message_buffer.agent_status.get("Aggressive Analyst") != "completed":
|
||||
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Aggressive Analyst Analysis\n{agg_hist}"
|
||||
)
|
||||
if con_hist:
|
||||
if message_buffer.agent_status.get("Conservative Analyst") != "completed":
|
||||
message_buffer.update_agent_status("Conservative Analyst", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Conservative Analyst Analysis\n{con_hist}"
|
||||
)
|
||||
if neu_hist:
|
||||
if message_buffer.agent_status.get("Neutral Analyst") != "completed":
|
||||
message_buffer.update_agent_status("Neutral Analyst", "in_progress")
|
||||
if formatted_risk:
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
|
||||
"final_trade_decision", formatted_risk
|
||||
)
|
||||
if judge:
|
||||
if message_buffer.agent_status.get("Portfolio Manager") != "completed":
|
||||
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
|
||||
)
|
||||
message_buffer.update_agent_status("Aggressive Analyst", "completed")
|
||||
message_buffer.update_agent_status("Conservative Analyst", "completed")
|
||||
message_buffer.update_agent_status("Neutral Analyst", "completed")
|
||||
|
|
|
|||
22
cli/utils.py
22
cli/utils.py
|
|
@ -189,21 +189,21 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
|
||||
def select_llm_provider() -> tuple[str, str]:
|
||||
"""Select the OpenAI api url using interactive selection."""
|
||||
# Define OpenAI api options with their corresponding endpoints
|
||||
# Define provider options as (display_name, provider_key, endpoint)
|
||||
BASE_URLS = [
|
||||
("OpenAI", "https://api.openai.com/v1"),
|
||||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
||||
("Anthropic", "https://api.anthropic.com/"),
|
||||
("xAI", "https://api.x.ai/v1"),
|
||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
("OpenAI", "openai", "https://api.openai.com/v1"),
|
||||
("Google", "google", "https://generativelanguage.googleapis.com/v1"),
|
||||
("Anthropic", "anthropic", "https://api.anthropic.com/"),
|
||||
("xAI", "xai", "https://api.x.ai/v1"),
|
||||
("Openrouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama / llama.cpp", "ollama", "http://localhost:4000/v1"),
|
||||
]
|
||||
|
||||
choice = questionary.select(
|
||||
"Select your LLM Provider:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=(display, value))
|
||||
for display, value in BASE_URLS
|
||||
questionary.Choice(display, value=(provider_key, endpoint, display))
|
||||
for display, provider_key, endpoint in BASE_URLS
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
|
|
@ -219,10 +219,10 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
display_name, url = choice
|
||||
provider_key, url, display_name = choice
|
||||
print(f"You selected: {display_name}\tURL: {url}")
|
||||
|
||||
return display_name, url
|
||||
return provider_key, url
|
||||
|
||||
|
||||
def ask_openai_reasoning_effort() -> str:
|
||||
|
|
|
|||
24
main.py
24
main.py
|
|
@ -11,13 +11,27 @@ config = DEFAULT_CONFIG.copy()
|
|||
config["deep_think_llm"] = "gpt-5.4-mini" # Use a different model
|
||||
config["quick_think_llm"] = "gpt-5.4-mini" # Use a different model
|
||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||
# Example for local OpenAI-compatible llama.cpp server:
|
||||
# config["llm_provider"] = "ollama"
|
||||
# config["backend_url"] = "http://localhost:4000/v1"
|
||||
|
||||
# Configure data vendors (default uses yfinance, no extra API keys needed)
|
||||
# Configure data vendors
|
||||
config["data_vendors"] = {
|
||||
"core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance
|
||||
"technical_indicators": "yfinance", # Options: alpha_vantage, yfinance
|
||||
"fundamental_data": "yfinance", # Options: alpha_vantage, yfinance
|
||||
"news_data": "yfinance", # Options: alpha_vantage, yfinance
|
||||
"core_stock_apis": "tushare,yfinance", # Options: tushare, yfinance
|
||||
"technical_indicators": "tushare,yfinance", # Options: tushare, yfinance
|
||||
"fundamental_data": "tushare,yfinance", # Options: tushare, yfinance
|
||||
"news_data": "opencli,brave,yfinance", # Options: opencli, brave, yfinance
|
||||
}
|
||||
config["tool_vendors"] = {
|
||||
"get_stock_data": "tushare",
|
||||
"get_indicators": "tushare",
|
||||
"get_fundamentals": "tushare",
|
||||
"get_balance_sheet": "tushare",
|
||||
"get_cashflow": "tushare",
|
||||
"get_income_statement": "tushare",
|
||||
"get_news": "opencli",
|
||||
"get_global_news": "opencli",
|
||||
"get_insider_transactions": "tushare,yfinance",
|
||||
}
|
||||
|
||||
# Initialize with custom config
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ dependencies = [
|
|||
"tqdm>=4.67.1",
|
||||
"typing-extensions>=4.14.0",
|
||||
"yfinance>=0.2.63",
|
||||
"tushare>=1.4.21",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from cli.utils import select_llm_provider
|
||||
|
||||
|
||||
class CliProviderSelectionTests(unittest.TestCase):
|
||||
@patch("cli.utils.questionary.select")
|
||||
def test_select_llm_provider_returns_internal_provider_key(self, mock_select):
|
||||
mock_select.return_value.ask.return_value = (
|
||||
"ollama",
|
||||
"http://localhost:4000/v1",
|
||||
"Ollama / llama.cpp",
|
||||
)
|
||||
|
||||
provider, url = select_llm_provider()
|
||||
|
||||
self.assertEqual(provider, "ollama")
|
||||
self.assertEqual(url, "http://localhost:4000/v1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,140 @@
|
|||
import unittest
|
||||
|
||||
from cli.main import format_research_team_history, format_risk_management_history
|
||||
from tradingagents.dataflows.config import get_config, set_config
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
|
||||
class CliRoundFormattingTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.original_config = get_config().copy()
|
||||
cfg = DEFAULT_CONFIG.copy()
|
||||
cfg["output_language"] = "Chinese"
|
||||
set_config(cfg)
|
||||
|
||||
def tearDown(self):
|
||||
set_config(self.original_config)
|
||||
|
||||
def test_research_team_history_is_grouped_by_round(self):
|
||||
debate_state = {
|
||||
"bull_history": (
|
||||
"多头分析师: 第一轮多头观点\n"
|
||||
"反馈快照:\n"
|
||||
"- 当前观点: 买入\n"
|
||||
"- 发生了什么变化: 强化多头\n"
|
||||
"- 为什么变化: 金价走强\n"
|
||||
"- 关键反驳: 估值担忧可控\n"
|
||||
"- 下一轮教训: 跟踪量价\n"
|
||||
"多头分析师: 第二轮多头补充\n"
|
||||
"反馈快照:\n"
|
||||
"- 当前观点: 强烈买入\n"
|
||||
"- 发生了什么变化: 更激进\n"
|
||||
"- 为什么变化: 避险升级\n"
|
||||
"- 关键反驳: 回撤是买点\n"
|
||||
"- 下一轮教训: 盯并购兑现"
|
||||
),
|
||||
"bear_history": (
|
||||
"空头分析师: 第一轮空头观点\n"
|
||||
"反馈快照:\n"
|
||||
"- 当前观点: 持有\n"
|
||||
"- 发生了什么变化: 维持谨慎\n"
|
||||
"- 为什么变化: 估值偏高\n"
|
||||
"- 关键反驳: 上涨已透支\n"
|
||||
"- 下一轮教训: 看现金流\n"
|
||||
"空头分析师: 第二轮空头反驳\n"
|
||||
"反馈快照:\n"
|
||||
"- 当前观点: 减持\n"
|
||||
"- 发生了什么变化: 转向更谨慎\n"
|
||||
"- 为什么变化: 风险升高\n"
|
||||
"- 关键反驳: 高位放量\n"
|
||||
"- 下一轮教训: 盯库存"
|
||||
),
|
||||
"judge_decision": "研究经理: 最终结论",
|
||||
}
|
||||
|
||||
formatted = format_research_team_history(debate_state)
|
||||
|
||||
self.assertIn("### 第 1 轮", formatted)
|
||||
self.assertIn("#### 多头分析师\n\n多头分析师: 第一轮多头观点", formatted)
|
||||
self.assertIn("##### 本轮复盘\n反馈快照:\n- 当前观点: 买入", formatted)
|
||||
self.assertIn("- 发生了什么变化: 强化多头", formatted)
|
||||
self.assertIn("- 下一轮教训: 跟踪量价", formatted)
|
||||
self.assertIn("#### 空头分析师\n\n空头分析师: 第一轮空头观点", formatted)
|
||||
self.assertIn("### 第 2 轮", formatted)
|
||||
self.assertIn("#### 多头分析师\n\n多头分析师: 第二轮多头补充", formatted)
|
||||
self.assertIn("- 发生了什么变化: 更激进", formatted)
|
||||
self.assertIn("- 下一轮教训: 盯并购兑现", formatted)
|
||||
self.assertIn("#### 空头分析师\n\n空头分析师: 第二轮空头反驳", formatted)
|
||||
self.assertIn("- 发生了什么变化: 转向更谨慎", formatted)
|
||||
self.assertIn("- 下一轮教训: 盯库存", formatted)
|
||||
self.assertTrue(formatted.endswith("### 研究经理结论\n研究经理: 最终结论"))
|
||||
|
||||
def test_risk_management_history_supports_english_prefixes(self):
|
||||
risk_state = {
|
||||
"aggressive_history": (
|
||||
"Aggressive Analyst: Round 1 aggressive case\n"
|
||||
"FEEDBACK SNAPSHOT:\n"
|
||||
"- Current thesis: Sell\n"
|
||||
"- What changed: More defensive\n"
|
||||
"- Why it changed: Momentum broke\n"
|
||||
"- Key rebuttal: Upside is capped\n"
|
||||
"- Lesson for next round: Watch liquidity\n"
|
||||
"Aggressive Analyst: Round 2 aggressive follow-up"
|
||||
),
|
||||
"conservative_history": (
|
||||
"Conservative Analyst: Round 1 conservative case\n"
|
||||
"FEEDBACK SNAPSHOT:\n"
|
||||
"- Current thesis: Hold\n"
|
||||
"- What changed: Stayed cautious\n"
|
||||
"- Why it changed: Valuation rich\n"
|
||||
"- Key rebuttal: Do not chase\n"
|
||||
"- Lesson for next round: Check earnings"
|
||||
),
|
||||
"neutral_history": (
|
||||
"Neutral Analyst: Round 1 neutral case\n"
|
||||
"FEEDBACK SNAPSHOT:\n"
|
||||
"- Current thesis: Hold\n"
|
||||
"- What changed: Balanced both sides\n"
|
||||
"- Why it changed: Conflicting signals\n"
|
||||
"- Key rebuttal: Need confirmation\n"
|
||||
"- Lesson for next round: Wait for breakout"
|
||||
),
|
||||
"judge_decision": "Portfolio Manager: Final allocation",
|
||||
}
|
||||
|
||||
formatted = format_risk_management_history(risk_state)
|
||||
|
||||
self.assertIn("### 第 1 轮", formatted)
|
||||
self.assertIn("#### 激进分析师\n\nAggressive Analyst: Round 1 aggressive case", formatted)
|
||||
self.assertIn("##### 本轮复盘\nFEEDBACK SNAPSHOT:\n- Current thesis: Sell", formatted)
|
||||
self.assertIn("- What changed: More defensive", formatted)
|
||||
self.assertIn("- Lesson for next round: Watch liquidity", formatted)
|
||||
self.assertIn("#### 保守分析师\n\nConservative Analyst: Round 1 conservative case", formatted)
|
||||
self.assertIn("#### 中性分析师\n\nNeutral Analyst: Round 1 neutral case", formatted)
|
||||
self.assertIn("### 第 2 轮", formatted)
|
||||
self.assertIn("#### 激进分析师\n\nAggressive Analyst: Round 2 aggressive follow-up", formatted)
|
||||
self.assertIn("### 投资组合经理结论\nPortfolio Manager: Final allocation", formatted)
|
||||
|
||||
def test_inferred_snapshot_uses_auto_review_title(self):
|
||||
debate_state = {
|
||||
"bull_history": (
|
||||
"多头分析师: 本轮新增了对库存风险的反驳,并强调需要继续跟踪金价与并购进度。\n"
|
||||
"反馈快照:\n"
|
||||
"- 当前观点: 暂无。\n"
|
||||
"- 发生了什么变化: 未明确说明。\n"
|
||||
"- 为什么变化: 未明确说明。\n"
|
||||
"- 关键反驳: 未明确说明。\n"
|
||||
"- 下一轮教训: 未明确说明。"
|
||||
),
|
||||
"bear_history": "",
|
||||
"judge_decision": "",
|
||||
}
|
||||
|
||||
formatted = format_research_team_history(debate_state)
|
||||
|
||||
self.assertIn("##### 自动复盘", formatted)
|
||||
self.assertNotIn("##### 本轮复盘", formatted)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
import unittest
|
||||
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
|
||||
|
||||
class ConditionalLogicLocalizationTests(unittest.TestCase):
|
||||
def test_should_continue_debate_uses_latest_speaker_not_localized_response_prefix(self):
|
||||
logic = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||
state = {
|
||||
"investment_debate_state": {
|
||||
"count": 1,
|
||||
"latest_speaker": "Bull Analyst",
|
||||
"current_response": "多头分析师: 这是中文前缀,不应影响路由",
|
||||
}
|
||||
}
|
||||
|
||||
self.assertEqual(logic.should_continue_debate(state), "Bear Researcher")
|
||||
|
||||
def test_should_continue_debate_still_returns_research_manager_when_rounds_complete(self):
|
||||
logic = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||
state = {
|
||||
"investment_debate_state": {
|
||||
"count": 2,
|
||||
"latest_speaker": "Bull Analyst",
|
||||
"current_response": "多头分析师: 已完成一轮",
|
||||
}
|
||||
}
|
||||
|
||||
self.assertEqual(logic.should_continue_debate(state), "Research Manager")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,193 @@
|
|||
import copy
|
||||
import unittest
|
||||
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.dataflows.config import get_config, set_config
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_debate_brief,
|
||||
extract_feedback_snapshot,
|
||||
get_snapshot_template,
|
||||
strip_feedback_snapshot,
|
||||
truncate_for_prompt,
|
||||
)
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
|
||||
|
||||
class ContextMemoryOptimizationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.original_config = copy.deepcopy(get_config())
|
||||
|
||||
def tearDown(self):
|
||||
set_config(self.original_config)
|
||||
|
||||
def test_truncate_for_prompt_uses_config_limit(self):
|
||||
cfg = copy.deepcopy(DEFAULT_CONFIG)
|
||||
cfg["report_context_char_limit"] = 10
|
||||
set_config(cfg)
|
||||
|
||||
text = "abcdefghijklmnopqrstuvwxyz"
|
||||
truncated = truncate_for_prompt(text)
|
||||
|
||||
self.assertIn("Content trimmed", truncated)
|
||||
self.assertTrue(truncated.endswith("qrstuvwxyz"))
|
||||
|
||||
def test_memory_similarity_threshold_filters_irrelevant_matches(self):
|
||||
situations = [
|
||||
("apple earnings beat with margin expansion", "prefer bullish setups"),
|
||||
("oil demand collapse and weak refinery margins", "reduce cyclical exposure"),
|
||||
]
|
||||
strict_memory = FinancialSituationMemory(
|
||||
"strict_memory",
|
||||
config={"memory_min_similarity": 0.99},
|
||||
)
|
||||
strict_memory.add_situations(situations)
|
||||
|
||||
unrelated = strict_memory.get_memories("football world cup final highlights", n_matches=2)
|
||||
self.assertEqual(unrelated, [])
|
||||
|
||||
related = strict_memory.get_memories("apple margin expansion after earnings", n_matches=2)
|
||||
self.assertEqual(related, [])
|
||||
|
||||
relaxed_memory = FinancialSituationMemory(
|
||||
"relaxed_memory",
|
||||
config={"memory_min_similarity": 0.0},
|
||||
)
|
||||
relaxed_memory.add_situations(situations)
|
||||
related = relaxed_memory.get_memories("apple margin expansion after earnings", n_matches=2)
|
||||
self.assertGreaterEqual(len(related), 1)
|
||||
|
||||
def test_feedback_snapshot_helpers(self):
|
||||
response = (
|
||||
"Argument body here.\n\n"
|
||||
"FEEDBACK SNAPSHOT:\n"
|
||||
"- Current thesis: Bull case improved.\n"
|
||||
"- What changed: Margin outlook improved.\n"
|
||||
"- Why it changed: Earnings beat.\n"
|
||||
"- Key rebuttal: Bear margin fears are weaker.\n"
|
||||
"- Lesson for next round: Track valuation risk."
|
||||
)
|
||||
|
||||
snapshot = extract_feedback_snapshot(response)
|
||||
body = strip_feedback_snapshot(response)
|
||||
brief = build_debate_brief(
|
||||
{
|
||||
"Bull Analyst": snapshot,
|
||||
"Bear Analyst": "FEEDBACK SNAPSHOT:\n- Current thesis: Bear case unchanged.",
|
||||
},
|
||||
latest_speaker="Bull Analyst",
|
||||
)
|
||||
|
||||
self.assertIn("Current thesis", snapshot)
|
||||
self.assertEqual(body, "Argument body here.")
|
||||
self.assertIn("Latest update came from: Bull Analyst", brief)
|
||||
self.assertIn("Bull Analyst latest snapshot", brief)
|
||||
|
||||
def test_build_debate_brief_localizes_summary_phrases_for_chinese(self):
|
||||
cfg = copy.deepcopy(DEFAULT_CONFIG)
|
||||
cfg["output_language"] = "Chinese"
|
||||
set_config(cfg)
|
||||
|
||||
brief = build_debate_brief(
|
||||
{
|
||||
"Bull Analyst": "反馈快照:\n- 当前观点: 多头增强。",
|
||||
"Bear Analyst": "反馈快照:\n- 当前观点: 空头不变。",
|
||||
},
|
||||
latest_speaker="Bull Analyst",
|
||||
)
|
||||
|
||||
self.assertIn("最新更新来自: 多头分析师", brief)
|
||||
self.assertIn("多头分析师 最新快照", brief)
|
||||
|
||||
def test_feedback_snapshot_helpers_support_chinese_template(self):
|
||||
cfg = copy.deepcopy(DEFAULT_CONFIG)
|
||||
cfg["output_language"] = "Chinese"
|
||||
set_config(cfg)
|
||||
|
||||
self.assertIn("反馈快照", get_snapshot_template())
|
||||
|
||||
response = (
|
||||
"论证正文。\n\n"
|
||||
"反馈快照:\n"
|
||||
"- 当前观点: 多头逻辑增强。\n"
|
||||
"- 发生了什么变化: 利润率预期改善。\n"
|
||||
"- 为什么变化: 财报超预期。\n"
|
||||
"- 关键反驳: 空头对利润率的担忧减弱。\n"
|
||||
"- 下一轮教训: 继续跟踪估值风险。"
|
||||
)
|
||||
|
||||
snapshot = extract_feedback_snapshot(response)
|
||||
body = strip_feedback_snapshot(response)
|
||||
|
||||
self.assertIn("当前观点", snapshot)
|
||||
self.assertEqual(body, "论证正文。")
|
||||
|
||||
def test_feedback_snapshot_infers_substantive_chinese_content_when_placeholder_block_is_used(self):
|
||||
cfg = copy.deepcopy(DEFAULT_CONFIG)
|
||||
cfg["output_language"] = "Chinese"
|
||||
set_config(cfg)
|
||||
|
||||
response = (
|
||||
"论证正文。本轮新增了对库存风险的反驳,并强调需要继续跟踪金价与并购进度。\n\n"
|
||||
"反馈快照:\n"
|
||||
"- 当前观点: 暂无。\n"
|
||||
"- 发生了什么变化: 未明确说明。\n"
|
||||
"- 为什么变化: 未明确说明。\n"
|
||||
"- 关键反驳: 未明确说明。\n"
|
||||
"- 下一轮教训: 未明确说明。"
|
||||
)
|
||||
|
||||
snapshot = extract_feedback_snapshot(response)
|
||||
|
||||
self.assertNotIn("未明确说明", snapshot)
|
||||
self.assertNotIn("暂无", snapshot)
|
||||
self.assertIn("- 当前观点: 持有", snapshot)
|
||||
self.assertIn("库存风险", snapshot)
|
||||
self.assertIn("继续跟踪金价与并购进度", snapshot)
|
||||
|
||||
def test_feedback_snapshot_fills_empty_fields_with_inferred_content(self):
|
||||
cfg = copy.deepcopy(DEFAULT_CONFIG)
|
||||
cfg["output_language"] = "Chinese"
|
||||
set_config(cfg)
|
||||
|
||||
response = (
|
||||
"论证正文。本轮转向更谨慎,核心原因是估值偏高且需要继续跟踪成交量是否萎缩。\n\n"
|
||||
"反馈快照:\n"
|
||||
"- 当前观点:\n"
|
||||
"- 发生了什么变化:\n"
|
||||
"- 为什么变化: 估值偏高。\n"
|
||||
"- 关键反驳:\n"
|
||||
"- 下一轮教训:"
|
||||
)
|
||||
|
||||
snapshot = extract_feedback_snapshot(response)
|
||||
|
||||
self.assertIn("- 当前观点: 持有", snapshot)
|
||||
self.assertIn("- 发生了什么变化:", snapshot)
|
||||
self.assertIn("- 为什么变化: 估值偏高。", snapshot)
|
||||
self.assertIn("- 关键反驳:", snapshot)
|
||||
self.assertIn("- 下一轮教训:", snapshot)
|
||||
self.assertNotIn("- 当前观点: \n", snapshot)
|
||||
self.assertNotIn("- 关键反驳: \n", snapshot)
|
||||
|
||||
def test_feedback_snapshot_detects_explicit_chinese_rating_terms(self):
|
||||
cfg = copy.deepcopy(DEFAULT_CONFIG)
|
||||
cfg["output_language"] = "Chinese"
|
||||
set_config(cfg)
|
||||
|
||||
response = (
|
||||
"我们维持减持观点,建议分批止盈,等待更好的风险收益比。\n\n"
|
||||
"反馈快照:\n"
|
||||
"- 当前观点:\n"
|
||||
"- 发生了什么变化:\n"
|
||||
"- 为什么变化:\n"
|
||||
"- 关键反驳:\n"
|
||||
"- 下一轮教训:"
|
||||
)
|
||||
|
||||
snapshot = extract_feedback_snapshot(response)
|
||||
|
||||
self.assertIn("- 当前观点: 减持", snapshot)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
import copy
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.dataflows.config import get_config, set_config
|
||||
from tradingagents.dataflows.exceptions import DataVendorUnavailable
|
||||
from tradingagents.dataflows.interface import VENDOR_LIST, VENDOR_METHODS, route_to_vendor
|
||||
|
||||
|
||||
class DataVendorRoutingTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.original_config = copy.deepcopy(get_config())
|
||||
|
||||
def tearDown(self):
|
||||
set_config(self.original_config)
|
||||
|
||||
def _base_config(self):
|
||||
cfg = copy.deepcopy(DEFAULT_CONFIG)
|
||||
cfg["tool_vendors"] = {}
|
||||
return cfg
|
||||
|
||||
def test_fallback_when_primary_vendor_unavailable(self):
|
||||
cfg = self._base_config()
|
||||
cfg["data_vendors"]["core_stock_apis"] = "tushare,yfinance"
|
||||
set_config(cfg)
|
||||
|
||||
def _primary(*_args, **_kwargs):
|
||||
raise DataVendorUnavailable("tushare unavailable")
|
||||
|
||||
def _fallback(*_args, **_kwargs):
|
||||
return "fallback-ok"
|
||||
|
||||
with patch.dict(
|
||||
VENDOR_METHODS,
|
||||
{
|
||||
"get_stock_data": {
|
||||
"tushare": _primary,
|
||||
"yfinance": _fallback,
|
||||
}
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
result = route_to_vendor("get_stock_data", "000001.SZ", "2024-01-01", "2024-01-02")
|
||||
|
||||
self.assertEqual(result, "fallback-ok")
|
||||
|
||||
def test_tool_level_vendor_overrides_category_vendor(self):
|
||||
cfg = self._base_config()
|
||||
cfg["data_vendors"]["news_data"] = "yfinance"
|
||||
cfg["tool_vendors"] = {"get_news": "opencli"}
|
||||
set_config(cfg)
|
||||
|
||||
def _opencli(*_args, **_kwargs):
|
||||
return "opencli-news"
|
||||
|
||||
def _yfinance(*_args, **_kwargs):
|
||||
return "yfinance-news"
|
||||
|
||||
with patch.dict(
|
||||
VENDOR_METHODS,
|
||||
{
|
||||
"get_news": {
|
||||
"opencli": _opencli,
|
||||
"yfinance": _yfinance,
|
||||
}
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
result = route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-02")
|
||||
|
||||
self.assertEqual(result, "opencli-news")
|
||||
|
||||
def test_global_news_is_pinned_to_opencli(self):
|
||||
cfg = self._base_config()
|
||||
cfg["tool_vendors"] = {"get_global_news": "opencli"}
|
||||
set_config(cfg)
|
||||
|
||||
def _opencli(*_args, **_kwargs):
|
||||
return "opencli-global"
|
||||
|
||||
def _fallback(*_args, **_kwargs):
|
||||
return "fallback-global"
|
||||
|
||||
with patch.dict(
|
||||
VENDOR_METHODS,
|
||||
{
|
||||
"get_global_news": {
|
||||
"opencli": _opencli,
|
||||
"yfinance": _fallback,
|
||||
}
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
result = route_to_vendor("get_global_news", "2024-01-02", 7, 5)
|
||||
|
||||
self.assertEqual(result, "opencli-global")
|
||||
|
||||
def test_price_and_fundamentals_are_hard_pinned_to_tushare(self):
|
||||
cfg = self._base_config()
|
||||
cfg["data_vendors"]["core_stock_apis"] = "yfinance"
|
||||
cfg["data_vendors"]["technical_indicators"] = "yfinance"
|
||||
cfg["data_vendors"]["fundamental_data"] = "yfinance"
|
||||
cfg["tool_vendors"] = {
|
||||
"get_stock_data": "tushare",
|
||||
"get_indicators": "tushare",
|
||||
"get_fundamentals": "tushare",
|
||||
"get_balance_sheet": "tushare",
|
||||
"get_cashflow": "tushare",
|
||||
"get_income_statement": "tushare",
|
||||
}
|
||||
set_config(cfg)
|
||||
|
||||
touched = []
|
||||
|
||||
def _record(name):
|
||||
def _inner(*_args, **_kwargs):
|
||||
touched.append(name)
|
||||
return name
|
||||
return _inner
|
||||
|
||||
with patch.dict(
|
||||
VENDOR_METHODS,
|
||||
{
|
||||
"get_stock_data": {"tushare": _record("stock_tushare"), "yfinance": _record("stock_yf")},
|
||||
"get_indicators": {"tushare": _record("ind_tushare"), "yfinance": _record("ind_yf")},
|
||||
"get_fundamentals": {"tushare": _record("fund_tushare"), "yfinance": _record("fund_yf")},
|
||||
"get_balance_sheet": {"tushare": _record("bs_tushare"), "yfinance": _record("bs_yf")},
|
||||
"get_cashflow": {"tushare": _record("cf_tushare"), "yfinance": _record("cf_yf")},
|
||||
"get_income_statement": {"tushare": _record("is_tushare"), "yfinance": _record("is_yf")},
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
self.assertEqual(route_to_vendor("get_stock_data", "000001.SZ", "2024-01-01", "2024-01-02"), "stock_tushare")
|
||||
self.assertEqual(route_to_vendor("get_indicators", "000001.SZ", "macd", "2024-01-02", 30), "ind_tushare")
|
||||
self.assertEqual(route_to_vendor("get_fundamentals", "000001.SZ", "2024-01-02"), "fund_tushare")
|
||||
self.assertEqual(route_to_vendor("get_balance_sheet", "000001.SZ", "quarterly", "2024-01-02"), "bs_tushare")
|
||||
self.assertEqual(route_to_vendor("get_cashflow", "000001.SZ", "quarterly", "2024-01-02"), "cf_tushare")
|
||||
self.assertEqual(route_to_vendor("get_income_statement", "000001.SZ", "quarterly", "2024-01-02"), "is_tushare")
|
||||
|
||||
self.assertEqual(
|
||||
touched,
|
||||
[
|
||||
"stock_tushare",
|
||||
"ind_tushare",
|
||||
"fund_tushare",
|
||||
"bs_tushare",
|
||||
"cf_tushare",
|
||||
"is_tushare",
|
||||
],
|
||||
)
|
||||
|
||||
def test_unsupported_market_returns_explicit_tushare_error(self):
|
||||
cfg = self._base_config()
|
||||
cfg["tool_vendors"] = {"get_stock_data": "tushare"}
|
||||
set_config(cfg)
|
||||
|
||||
def _unsupported(*_args, **_kwargs):
|
||||
raise DataVendorUnavailable(
|
||||
"Tushare currently supports A-share, Hong Kong, and US tickers only, got '7203.T'."
|
||||
)
|
||||
|
||||
with patch.dict(
|
||||
VENDOR_METHODS,
|
||||
{
|
||||
"get_stock_data": {"tushare": _unsupported},
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
with self.assertRaises(RuntimeError) as ctx:
|
||||
route_to_vendor("get_stock_data", "7203.T", "2024-01-01", "2024-01-02")
|
||||
|
||||
self.assertIn("A-share, Hong Kong, and US tickers only", str(ctx.exception))
|
||||
|
||||
def test_alpha_vantage_is_not_an_available_vendor(self):
|
||||
self.assertNotIn("alpha_vantage", VENDOR_LIST)
|
||||
|
||||
for vendor_map in VENDOR_METHODS.values():
|
||||
self.assertNotIn("alpha_vantage", vendor_map)
|
||||
|
||||
def test_a_share_insider_transactions_prefers_tushare(self):
|
||||
cfg = self._base_config()
|
||||
cfg["data_vendors"]["news_data"] = "opencli,brave,yfinance"
|
||||
cfg["tool_vendors"] = {"get_insider_transactions": "tushare,yfinance"}
|
||||
set_config(cfg)
|
||||
|
||||
touched = []
|
||||
|
||||
def _tushare(*_args, **_kwargs):
|
||||
touched.append("tushare")
|
||||
return [{"insider": "a-share"}]
|
||||
|
||||
def _yfinance(*_args, **_kwargs):
|
||||
touched.append("yfinance")
|
||||
return [{"insider": "example"}]
|
||||
|
||||
with patch.dict(
|
||||
VENDOR_METHODS,
|
||||
{
|
||||
"get_insider_transactions": {
|
||||
"tushare": _tushare,
|
||||
"yfinance": _yfinance,
|
||||
}
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
result = route_to_vendor("get_insider_transactions", "002155.SZ")
|
||||
|
||||
self.assertEqual(touched, ["tushare"])
|
||||
self.assertEqual(result, [{"insider": "a-share"}])
|
||||
|
||||
def test_non_a_share_insider_transactions_fall_back_to_yfinance(self):
|
||||
cfg = self._base_config()
|
||||
cfg["tool_vendors"] = {"get_insider_transactions": "tushare,yfinance"}
|
||||
set_config(cfg)
|
||||
|
||||
touched = []
|
||||
|
||||
def _tushare(*_args, **_kwargs):
|
||||
touched.append("tushare")
|
||||
raise DataVendorUnavailable("A-share only")
|
||||
|
||||
def _yfinance(*_args, **_kwargs):
|
||||
touched.append("yfinance")
|
||||
return [{"insider": "fallback"}]
|
||||
|
||||
with patch.dict(
|
||||
VENDOR_METHODS,
|
||||
{
|
||||
"get_insider_transactions": {
|
||||
"tushare": _tushare,
|
||||
"yfinance": _yfinance,
|
||||
}
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
result = route_to_vendor("get_insider_transactions", "AAPL")
|
||||
|
||||
self.assertEqual(touched, ["tushare", "yfinance"])
|
||||
self.assertEqual(result, [{"insider": "fallback"}])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -2,7 +2,7 @@ import unittest
|
|||
import warnings
|
||||
|
||||
from tradingagents.llm_clients.base_client import BaseLLMClient
|
||||
from tradingagents.llm_clients.model_catalog import get_known_models
|
||||
from tradingagents.llm_clients.model_catalog import get_known_models, get_model_options
|
||||
from tradingagents.llm_clients.validators import validate_model
|
||||
|
||||
|
||||
|
|
@ -20,6 +20,19 @@ class DummyLLMClient(BaseLLMClient):
|
|||
|
||||
|
||||
class ModelValidationTests(unittest.TestCase):
|
||||
def test_local_llamacpp_models_are_exposed_in_cli_catalog(self):
|
||||
quick_models = [value for _, value in get_model_options("ollama", "quick")]
|
||||
deep_models = [value for _, value in get_model_options("ollama", "deep")]
|
||||
|
||||
for model in (
|
||||
"Qwen3.5-27B",
|
||||
"Qwen3.5-35B-3A",
|
||||
"Qwen3.5-35B-A3B",
|
||||
"Qwen3.5-122B",
|
||||
):
|
||||
with self.subTest(model=model):
|
||||
self.assertIn(model, quick_models + deep_models)
|
||||
|
||||
def test_cli_catalog_models_are_all_validator_approved(self):
|
||||
for provider, models in get_known_models().items():
|
||||
if provider in ("ollama", "openrouter"):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from tradingagents.llm_clients.openai_client import OpenAIClient
|
||||
|
||||
|
||||
class OpenAICompatibleBaseUrlTests(unittest.TestCase):
|
||||
@patch("tradingagents.llm_clients.openai_client.NormalizedChatOpenAI")
|
||||
def test_ollama_provider_respects_explicit_base_url(self, mock_chat):
|
||||
client = OpenAIClient(
|
||||
"qwen3:latest",
|
||||
base_url="http://localhost:4000/v1",
|
||||
provider="ollama",
|
||||
)
|
||||
client.get_llm()
|
||||
|
||||
kwargs = mock_chat.call_args[1]
|
||||
self.assertEqual(kwargs["base_url"], "http://localhost:4000/v1")
|
||||
self.assertEqual(kwargs["api_key"], "ollama")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,163 @@
|
|||
import subprocess
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from tradingagents.dataflows.exceptions import DataVendorUnavailable
|
||||
from tradingagents.dataflows.opencli_news import _resolve_company_aliases, get_global_news, get_news
|
||||
|
||||
|
||||
class OpenCliNewsTests(unittest.TestCase):
|
||||
@patch("tradingagents.dataflows.tushare._get_pro_client")
|
||||
@patch("tradingagents.dataflows.tushare._classify_market", return_value="a_share")
|
||||
@patch("tradingagents.dataflows.tushare._normalize_ts_code", return_value="002155.SZ")
|
||||
def test_resolve_company_aliases_prefers_tushare_name(
|
||||
self,
|
||||
_mock_normalize,
|
||||
_mock_market,
|
||||
mock_pro_client,
|
||||
):
|
||||
class _BasicFrame:
|
||||
empty = False
|
||||
|
||||
class _Row(dict):
|
||||
def get(self, key, default=None):
|
||||
return super().get(key, default)
|
||||
|
||||
@property
|
||||
def iloc(self):
|
||||
class _ILoc:
|
||||
def __getitem__(_self, _idx):
|
||||
return _BasicFrame._Row({"name": "金博股份", "fullname": "湖南金博碳素股份有限公司"})
|
||||
|
||||
return _ILoc()
|
||||
|
||||
mock_pro_client.return_value.stock_basic.return_value = _BasicFrame()
|
||||
|
||||
aliases = _resolve_company_aliases("002155.SZ")
|
||||
|
||||
self.assertEqual(aliases[0], "金博股份")
|
||||
self.assertIn("湖南金博碳素股份有限公司", aliases)
|
||||
self.assertIn("湖南金博碳素", aliases)
|
||||
self.assertIn("002155.SZ", aliases)
|
||||
|
||||
@patch(
|
||||
"tradingagents.dataflows.opencli_news._resolve_company_aliases",
|
||||
return_value=["金博股份", "002155.SZ", "002155"],
|
||||
)
|
||||
@patch("tradingagents.dataflows.opencli_news.shutil.which", return_value="/usr/bin/opencli-rs")
|
||||
@patch("tradingagents.dataflows.opencli_news.subprocess.run")
|
||||
def test_get_news_aggregates_multiple_sources(self, mock_run, _mock_which, _mock_aliases):
|
||||
def _dispatch(cmd, **_kwargs):
|
||||
if cmd[1:3] == ["xueqiu", "search"]:
|
||||
return subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout='[{"name":"金博股份","symbol":"002155"}]',
|
||||
stderr="",
|
||||
)
|
||||
if cmd[1:3] == ["weibo", "search"]:
|
||||
return subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout='[{"text":"金博股份讨论热度上升","url":"https://example.com/weibo"}]',
|
||||
stderr="",
|
||||
)
|
||||
if cmd[1:3] == ["xiaohongshu", "search"]:
|
||||
return subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout='[{"title":"金博股份观察","url":"https://example.com/xhs"}]',
|
||||
stderr="",
|
||||
)
|
||||
if cmd[1:3] == ["sinafinance", "news"]:
|
||||
return subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout='[{"content":"金博股份公告带动碳基材料板块走强","time":"2026-04-01 10:00:00","views":"5万"}]',
|
||||
stderr="",
|
||||
)
|
||||
if cmd[1:3] == ["google", "news"]:
|
||||
return subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout='[{"title":"金博股份 headline","source":"CNBC","date":"2026-04-01","url":"https://example.com/news"}]',
|
||||
stderr="",
|
||||
)
|
||||
if cmd[1:3] == ["google", "search"]:
|
||||
return subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout='[{"title":"百度一下,你就知道 - 金博股份","url":"https://example.com/search"}]',
|
||||
stderr="",
|
||||
)
|
||||
raise AssertionError(f"Unexpected command: {cmd}")
|
||||
|
||||
mock_run.side_effect = _dispatch
|
||||
|
||||
result = get_news("NVDA", "2026-03-25", "2026-04-01")
|
||||
|
||||
self.assertIn("Xueqiu Search", result)
|
||||
self.assertIn("Weibo Search", result)
|
||||
self.assertIn("Xiaohongshu Search", result)
|
||||
self.assertIn("Sina Finance A-Share Flash", result)
|
||||
self.assertIn("Google News", result)
|
||||
self.assertIn("Google Search (ZH)", result)
|
||||
self.assertIn("金博股份 headline", result)
|
||||
first_call = mock_run.call_args_list[0].args[0]
|
||||
self.assertEqual(first_call[0:4], ["/usr/bin/opencli-rs", "xueqiu", "search", "金博股份"])
|
||||
|
||||
@patch("tradingagents.dataflows.opencli_news.shutil.which", return_value="/usr/bin/opencli-rs")
|
||||
@patch("tradingagents.dataflows.opencli_news.subprocess.run")
|
||||
def test_get_global_news_aggregates_market_sources(self, mock_run, _mock_which):
|
||||
mock_run.side_effect = [
|
||||
subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout='[{"title":"Macro headline","source":"Reuters","date":"2026-04-01","url":"https://example.com/google"}]',
|
||||
stderr="",
|
||||
),
|
||||
subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout='[{"content":"Flash item","time":"2026-04-01 15:50:00","views":"10万"}]',
|
||||
stderr="",
|
||||
),
|
||||
subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout='[{"text":"Hot Xueqiu post","author":"alice","likes":12,"url":"https://example.com/xq"}]',
|
||||
stderr="",
|
||||
),
|
||||
subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout='[{"word":"Top Weibo topic","category":"财经","hot_value":12345,"url":"https://example.com/wb"}]',
|
||||
stderr="",
|
||||
),
|
||||
]
|
||||
|
||||
result = get_global_news("2026-04-01", 7, 5)
|
||||
|
||||
self.assertIn("Google News Top Stories", result)
|
||||
self.assertIn("Sina Finance Flash News", result)
|
||||
self.assertIn("Xueqiu Hot Discussions", result)
|
||||
self.assertIn("Weibo Hot Topics", result)
|
||||
|
||||
@patch("tradingagents.dataflows.opencli_news.shutil.which", return_value="/usr/bin/opencli-rs")
|
||||
@patch("tradingagents.dataflows.opencli_news.subprocess.run")
|
||||
def test_opencli_failures_surface_in_no_results_message(self, mock_run, _mock_which):
|
||||
mock_run.return_value = subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=1,
|
||||
stdout="",
|
||||
stderr="browser disconnected",
|
||||
)
|
||||
|
||||
result = get_news("NVDA", "2026-03-25", "2026-04-01")
|
||||
|
||||
self.assertIn("No relevant news found via opencli-rs", result)
|
||||
self.assertIn("browser disconnected", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
import copy
|
||||
import unittest
|
||||
|
||||
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||
from tradingagents.agents.managers.research_manager import create_research_manager
|
||||
from tradingagents.agents.researchers.bear_researcher import create_bear_researcher
|
||||
from tradingagents.agents.researchers.bull_researcher import create_bull_researcher
|
||||
from tradingagents.agents.risk_mgmt.aggressive_debator import create_aggressive_debator
|
||||
from tradingagents.agents.risk_mgmt.conservative_debator import create_conservative_debator
|
||||
from tradingagents.agents.risk_mgmt.neutral_debator import create_neutral_debator
|
||||
from tradingagents.agents.trader.trader import create_trader
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
get_collaboration_stop_instruction,
|
||||
normalize_chinese_role_terms,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config, set_config
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, content="ok"):
|
||||
self.content = content
|
||||
|
||||
|
||||
class _CapturingLLM:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def invoke(self, prompt):
|
||||
self.calls.append(prompt)
|
||||
return _FakeResponse("测试输出\n反馈快照:\n- 当前观点: x\n- 发生了什么变化: y\n- 为什么变化: z\n- 关键反驳: r\n- 下一轮教训: l")
|
||||
|
||||
|
||||
class _EmptyMemory:
|
||||
def get_memories(self, *_args, **_kwargs):
|
||||
return []
|
||||
|
||||
|
||||
class OutputLanguagePropagationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.original_config = copy.deepcopy(get_config())
|
||||
cfg = copy.deepcopy(DEFAULT_CONFIG)
|
||||
cfg["output_language"] = "Chinese"
|
||||
set_config(cfg)
|
||||
|
||||
self.base_state = {
|
||||
"company_of_interest": "002155.SZ",
|
||||
"investment_plan": "Plan",
|
||||
"market_report": "Market report",
|
||||
"sentiment_report": "Sentiment report",
|
||||
"news_report": "News report",
|
||||
"fundamentals_report": "Fundamentals report",
|
||||
"trader_investment_plan": "Trader plan",
|
||||
"risk_debate_state": {
|
||||
"history": "",
|
||||
"aggressive_history": "",
|
||||
"conservative_history": "",
|
||||
"neutral_history": "",
|
||||
"latest_speaker": "",
|
||||
"current_aggressive_response": "",
|
||||
"current_conservative_response": "",
|
||||
"current_neutral_response": "",
|
||||
"aggressive_snapshot": "",
|
||||
"conservative_snapshot": "",
|
||||
"neutral_snapshot": "",
|
||||
"debate_brief": "",
|
||||
"judge_decision": "",
|
||||
"count": 0,
|
||||
},
|
||||
"investment_debate_state": {
|
||||
"history": "",
|
||||
"bear_history": "",
|
||||
"bull_history": "",
|
||||
"current_response": "",
|
||||
"bull_snapshot": "",
|
||||
"bear_snapshot": "",
|
||||
"debate_brief": "",
|
||||
"latest_speaker": "",
|
||||
"count": 0,
|
||||
},
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
set_config(self.original_config)
|
||||
|
||||
def test_trader_prompt_respects_output_language(self):
|
||||
llm = _CapturingLLM()
|
||||
node = create_trader(llm, _EmptyMemory())
|
||||
node(self.base_state)
|
||||
|
||||
system_prompt = llm.calls[0][0]["content"]
|
||||
self.assertIn("Write your entire response in Chinese.", system_prompt)
|
||||
|
||||
def test_research_manager_prompt_respects_output_language(self):
|
||||
llm = _CapturingLLM()
|
||||
node = create_research_manager(llm, _EmptyMemory())
|
||||
node(self.base_state)
|
||||
|
||||
prompt = llm.calls[0]
|
||||
self.assertIn("Write your entire response in Chinese.", prompt)
|
||||
self.assertIn("多头分析师", prompt)
|
||||
self.assertIn("空头分析师", prompt)
|
||||
self.assertIn("买入", prompt)
|
||||
self.assertIn("持有", prompt)
|
||||
|
||||
def test_bull_bear_researcher_prompts_require_chinese_body_and_chinese_final_title(self):
|
||||
for factory in (create_bull_researcher, create_bear_researcher):
|
||||
llm = _CapturingLLM()
|
||||
node = factory(llm, _EmptyMemory())
|
||||
node(self.base_state)
|
||||
|
||||
prompt = llm.calls[0]
|
||||
self.assertIn("written entirely in Chinese", prompt)
|
||||
self.assertIn("最终交易建议", prompt)
|
||||
self.assertIn("Write your entire response in Chinese.", prompt)
|
||||
self.assertIn("Do not use variants like", prompt)
|
||||
self.assertIn("牛派分析师", prompt)
|
||||
self.assertIn("熊派分析师", prompt)
|
||||
self.assertIn("禁止填写“未明确说明”“暂无”“同上”“无变化”", prompt)
|
||||
|
||||
def test_normalize_chinese_role_terms_replaces_bull_bear_variants(self):
|
||||
text = "我是熊派分析师,也不同意牛派分析师和熊派投资者的说法。"
|
||||
normalized = normalize_chinese_role_terms(text)
|
||||
|
||||
self.assertNotIn("熊派分析师", normalized)
|
||||
self.assertNotIn("牛派分析师", normalized)
|
||||
self.assertNotIn("熊派投资者", normalized)
|
||||
self.assertIn("空头分析师", normalized)
|
||||
self.assertIn("多头分析师", normalized)
|
||||
self.assertIn("空头投资者", normalized)
|
||||
|
||||
def test_risk_team_prompts_respect_output_language(self):
|
||||
for factory in (
|
||||
create_aggressive_debator,
|
||||
create_conservative_debator,
|
||||
create_neutral_debator,
|
||||
):
|
||||
llm = _CapturingLLM()
|
||||
node = factory(llm)
|
||||
node(self.base_state)
|
||||
|
||||
prompt = llm.calls[0]
|
||||
self.assertIn("Write your entire response in Chinese.", prompt)
|
||||
self.assertIn("反馈快照", prompt)
|
||||
self.assertIn("禁止填写“未明确说明”“暂无”“同上”“无变化”", prompt)
|
||||
|
||||
def test_portfolio_manager_prompt_respects_output_language(self):
|
||||
llm = _CapturingLLM()
|
||||
node = create_portfolio_manager(llm, _EmptyMemory())
|
||||
node(self.base_state)
|
||||
|
||||
prompt = llm.calls[0]
|
||||
self.assertIn("Write your entire response in Chinese.", prompt)
|
||||
self.assertIn("反馈快照", prompt)
|
||||
self.assertIn("激进分析师", prompt)
|
||||
self.assertIn("保守分析师", prompt)
|
||||
self.assertIn("中性分析师", prompt)
|
||||
self.assertIn("评级体系", prompt)
|
||||
self.assertIn("买入", prompt)
|
||||
self.assertIn("增持", prompt)
|
||||
self.assertIn("持有", prompt)
|
||||
self.assertIn("减持", prompt)
|
||||
self.assertIn("卖出", prompt)
|
||||
self.assertIn("禁止填写“未明确说明”“暂无”“同上”“无变化”", prompt)
|
||||
|
||||
def test_collaboration_stop_instruction_prefers_chinese_display(self):
|
||||
instruction = get_collaboration_stop_instruction()
|
||||
self.assertIn("最终交易建议: **买入/持有/卖出**", instruction)
|
||||
self.assertIn("FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**", instruction)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
import unittest
|
||||
|
||||
from tradingagents.graph.signal_processing import SignalProcessor
|
||||
|
||||
|
||||
class _UnusedLLM:
|
||||
def invoke(self, _messages):
|
||||
raise AssertionError("LLM fallback should not be used for normalized ratings")
|
||||
|
||||
|
||||
class SignalProcessingLocalizationTests(unittest.TestCase):
|
||||
def test_normalizes_chinese_final_proposal_markers(self):
|
||||
processor = SignalProcessor(_UnusedLLM())
|
||||
|
||||
self.assertEqual(processor.process_signal("最终交易建议: **买入**"), "BUY")
|
||||
self.assertEqual(processor.process_signal("最终交易建议: **增持**"), "OVERWEIGHT")
|
||||
self.assertEqual(processor.process_signal("最终交易建议: **持有**"), "HOLD")
|
||||
self.assertEqual(processor.process_signal("最终交易建议: **减持**"), "UNDERWEIGHT")
|
||||
self.assertEqual(processor.process_signal("最终交易建议: **卖出**"), "SELL")
|
||||
|
||||
def test_normalizes_english_internal_markers(self):
|
||||
processor = SignalProcessor(_UnusedLLM())
|
||||
|
||||
self.assertEqual(
|
||||
processor.process_signal("FINAL TRANSACTION PROPOSAL: **BUY**"),
|
||||
"BUY",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -5,6 +5,7 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
build_instrument_context,
|
||||
get_balance_sheet,
|
||||
get_cashflow,
|
||||
get_collaboration_stop_instruction,
|
||||
get_fundamentals,
|
||||
get_income_statement,
|
||||
get_insider_transactions,
|
||||
|
|
@ -26,24 +27,25 @@ def create_fundamentals_analyst(llm):
|
|||
]
|
||||
|
||||
system_message = (
|
||||
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
||||
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Your fundamental analysis must rely only on the structured accounting and financial statement data returned by the fundamentals tools. Do not use news flow, sentiment, rumors, or macro headlines to justify accounting conclusions. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
||||
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."
|
||||
+ " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements."
|
||||
+ get_language_instruction(),
|
||||
+ get_language_instruction()
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||
(
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
+ get_collaboration_stop_instruction()
|
||||
+ " You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
+ "For your reference, the current date is {current_date}. {instrument_context}"
|
||||
),
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
import json
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_collaboration_stop_instruction,
|
||||
get_indicators,
|
||||
get_language_instruction,
|
||||
get_stock_data,
|
||||
|
|
@ -22,7 +23,7 @@ def create_market_analyst(llm):
|
|||
]
|
||||
|
||||
system_message = (
|
||||
"""You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
||||
"""You are a trading assistant tasked with analyzing financial markets. Your price and technical analysis must rely only on the data returned by `get_stock_data` and `get_indicators`. Do not infer price action, volume behavior, momentum, or technical signals from company news, macro news, or social sentiment. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
||||
|
||||
Moving Averages:
|
||||
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
|
||||
|
|
@ -55,14 +56,15 @@ Volume-Based Indicators:
|
|||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||
(
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
+ get_collaboration_stop_instruction()
|
||||
+ " You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
+ "For your reference, the current date is {current_date}. {instrument_context}"
|
||||
),
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
import json
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_collaboration_stop_instruction,
|
||||
get_global_news,
|
||||
get_language_instruction,
|
||||
get_news,
|
||||
|
|
@ -30,14 +31,15 @@ def create_news_analyst(llm):
|
|||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||
(
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
+ get_collaboration_stop_instruction()
|
||||
+ " You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
+ "For your reference, the current date is {current_date}. {instrument_context}"
|
||||
),
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_collaboration_stop_instruction,
|
||||
get_language_instruction,
|
||||
get_news,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -24,14 +29,15 @@ def create_social_media_analyst(llm):
|
|||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||
(
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
+ get_collaboration_stop_instruction()
|
||||
+ " You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
+ "For your reference, the current date is {current_date}. {instrument_context}"
|
||||
),
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,18 +1,33 @@
|
|||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_debate_brief,
|
||||
build_instrument_context,
|
||||
extract_feedback_snapshot,
|
||||
get_language_instruction,
|
||||
get_localized_final_proposal_instruction,
|
||||
get_localized_rating_scale,
|
||||
get_snapshot_template,
|
||||
get_snapshot_writing_instruction,
|
||||
localize_label,
|
||||
localize_rating_term,
|
||||
localize_role_name,
|
||||
truncate_for_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_portfolio_manager(llm, memory):
|
||||
def portfolio_manager_node(state) -> dict:
|
||||
|
||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||
|
||||
history = state["risk_debate_state"]["history"]
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
market_research_report = state["market_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
trader_plan = state["investment_plan"]
|
||||
market_research_report = truncate_for_prompt(state["market_report"])
|
||||
news_report = truncate_for_prompt(state["news_report"])
|
||||
fundamentals_report = truncate_for_prompt(state["fundamentals_report"])
|
||||
sentiment_report = truncate_for_prompt(state["sentiment_report"])
|
||||
trader_plan = truncate_for_prompt(state["investment_plan"])
|
||||
aggressive_snapshot = risk_debate_state.get("aggressive_snapshot", "")
|
||||
conservative_snapshot = risk_debate_state.get("conservative_snapshot", "")
|
||||
neutral_snapshot = risk_debate_state.get("neutral_snapshot", "")
|
||||
debate_brief = risk_debate_state.get("debate_brief", "")
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
|
@ -27,32 +42,49 @@ def create_portfolio_manager(llm, memory):
|
|||
|
||||
---
|
||||
|
||||
**Rating Scale** (use exactly one):
|
||||
- **Buy**: Strong conviction to enter or add to position
|
||||
- **Overweight**: Favorable outlook, gradually increase exposure
|
||||
- **Hold**: Maintain current position, no action needed
|
||||
- **Underweight**: Reduce exposure, take partial profits
|
||||
- **Sell**: Exit position or avoid entry
|
||||
{get_localized_rating_scale()}
|
||||
|
||||
**Context:**
|
||||
- Trader's proposed plan: **{trader_plan}**
|
||||
- Lessons from past decisions: **{past_memory_str}**
|
||||
|
||||
**Required Output Structure:**
|
||||
1. **Rating**: State one of Buy / Overweight / Hold / Underweight / Sell.
|
||||
2. **Executive Summary**: A concise action plan covering entry strategy, position sizing, key risk levels, and time horizon.
|
||||
3. **Investment Thesis**: Detailed reasoning anchored in the analysts' debate and past reflections.
|
||||
1. **{localize_label("Rating", "评级")}**: State one of {localize_rating_term("Buy")} / {localize_rating_term("Overweight")} / {localize_rating_term("Hold")} / {localize_rating_term("Underweight")} / {localize_rating_term("Sell")}.
|
||||
2. **{localize_label("Executive Summary", "执行摘要")}**: A concise action plan covering entry strategy, position sizing, key risk levels, and time horizon.
|
||||
3. **{localize_label("Investment Thesis", "投资逻辑")}**: Detailed reasoning anchored in the analysts' debate and past reflections.
|
||||
|
||||
---
|
||||
|
||||
**Risk Analysts Debate History:**
|
||||
{history}
|
||||
**{localize_label("Rolling Risk Debate Brief", "滚动风险辩论摘要")}:**
|
||||
{debate_brief}
|
||||
|
||||
**{localize_label("Aggressive Snapshot", f"{localize_role_name('Aggressive Analyst')} 最新快照")}:**
|
||||
{aggressive_snapshot}
|
||||
|
||||
**{localize_label("Conservative Snapshot", f"{localize_role_name('Conservative Analyst')} 最新快照")}:**
|
||||
{conservative_snapshot}
|
||||
|
||||
**{localize_label("Neutral Snapshot", f"{localize_role_name('Neutral Analyst')} 最新快照")}:**
|
||||
{neutral_snapshot}
|
||||
|
||||
---
|
||||
|
||||
Be decisive and ground every conclusion in specific evidence from the analysts.{get_language_instruction()}"""
|
||||
Be decisive and ground every conclusion in specific evidence from the analysts. {get_localized_final_proposal_instruction()}
|
||||
Append a feedback block in this exact format:
|
||||
{get_snapshot_template()}
|
||||
{get_snapshot_writing_instruction()}{get_language_instruction()}"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
judge_snapshot = extract_feedback_snapshot(response.content)
|
||||
updated_brief = build_debate_brief(
|
||||
{
|
||||
"Aggressive Analyst": aggressive_snapshot,
|
||||
"Conservative Analyst": conservative_snapshot,
|
||||
"Neutral Analyst": neutral_snapshot,
|
||||
"Portfolio Manager": judge_snapshot,
|
||||
},
|
||||
latest_speaker="Portfolio Manager",
|
||||
)
|
||||
|
||||
new_risk_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
|
|
@ -60,10 +92,14 @@ Be decisive and ground every conclusion in specific evidence from the analysts.{
|
|||
"aggressive_history": risk_debate_state["aggressive_history"],
|
||||
"conservative_history": risk_debate_state["conservative_history"],
|
||||
"neutral_history": risk_debate_state["neutral_history"],
|
||||
"debate_brief": updated_brief,
|
||||
"latest_speaker": "Judge",
|
||||
"current_aggressive_response": risk_debate_state["current_aggressive_response"],
|
||||
"current_conservative_response": risk_debate_state["current_conservative_response"],
|
||||
"current_neutral_response": risk_debate_state["current_neutral_response"],
|
||||
"aggressive_snapshot": aggressive_snapshot,
|
||||
"conservative_snapshot": conservative_snapshot,
|
||||
"neutral_snapshot": neutral_snapshot,
|
||||
"count": risk_debate_state["count"],
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,32 @@
|
|||
import time
|
||||
import json
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_debate_brief,
|
||||
build_instrument_context,
|
||||
extract_feedback_snapshot,
|
||||
get_language_instruction,
|
||||
get_snapshot_template,
|
||||
get_snapshot_writing_instruction,
|
||||
localize_label,
|
||||
localize_rating_term,
|
||||
localize_role_name,
|
||||
truncate_for_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_research_manager(llm, memory):
|
||||
def research_manager_node(state) -> dict:
|
||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||
history = state["investment_debate_state"].get("history", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
market_research_report = truncate_for_prompt(state["market_report"])
|
||||
sentiment_report = truncate_for_prompt(state["sentiment_report"])
|
||||
news_report = truncate_for_prompt(state["news_report"])
|
||||
fundamentals_report = truncate_for_prompt(state["fundamentals_report"])
|
||||
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
bull_snapshot = investment_debate_state.get("bull_snapshot", "")
|
||||
bear_snapshot = investment_debate_state.get("bear_snapshot", "")
|
||||
debate_brief = investment_debate_state.get("debate_brief", "")
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
|
@ -22,9 +35,9 @@ def create_research_manager(llm, memory):
|
|||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
||||
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the {localize_role_name("Bear Analyst")}, the {localize_role_name("Bull Analyst")}, or choose {localize_rating_term("Hold")} only if it is strongly justified based on the arguments presented.
|
||||
|
||||
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
|
||||
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—{localize_rating_term("Buy")}, {localize_rating_term("Sell")}, or {localize_rating_term("Hold")}—must be clear and actionable. Avoid defaulting to {localize_rating_term("Hold")} simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
|
||||
|
||||
Additionally, develop a detailed investment plan for the trader. This should include:
|
||||
|
||||
|
|
@ -32,16 +45,35 @@ Your Recommendation: A decisive stance supported by the most convincing argument
|
|||
Rationale: An explanation of why these arguments lead to your conclusion.
|
||||
Strategic Actions: Concrete steps for implementing the recommendation.
|
||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
||||
After your analysis, append a feedback block in this exact format:
|
||||
{get_snapshot_template()}
|
||||
{get_snapshot_writing_instruction()}
|
||||
|
||||
Here are your past reflections on mistakes:
|
||||
\"{past_memory_str}\"
|
||||
|
||||
{instrument_context}
|
||||
|
||||
Here is the debate:
|
||||
Debate History:
|
||||
{history}"""
|
||||
Here is the latest debate context:
|
||||
{localize_label("Rolling debate brief:", "滚动辩论摘要:")}
|
||||
{debate_brief}
|
||||
|
||||
{localize_label("Bull latest snapshot:", f"{localize_role_name('Bull Analyst')} 最新快照:")}
|
||||
{bull_snapshot}
|
||||
|
||||
{localize_label("Bear latest snapshot:", f"{localize_role_name('Bear Analyst')} 最新快照:")}
|
||||
{bear_snapshot}{get_language_instruction()}
|
||||
"""
|
||||
response = llm.invoke(prompt)
|
||||
judge_snapshot = extract_feedback_snapshot(response.content)
|
||||
updated_brief = build_debate_brief(
|
||||
{
|
||||
"Bull Analyst": bull_snapshot,
|
||||
"Bear Analyst": bear_snapshot,
|
||||
"Research Manager": judge_snapshot,
|
||||
},
|
||||
latest_speaker="Research Manager",
|
||||
)
|
||||
|
||||
new_investment_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
|
|
@ -49,6 +81,10 @@ Debate History:
|
|||
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||
"current_response": response.content,
|
||||
"bull_snapshot": bull_snapshot,
|
||||
"bear_snapshot": bear_snapshot,
|
||||
"debate_brief": updated_brief,
|
||||
"latest_speaker": "Research Manager",
|
||||
"count": investment_debate_state["count"],
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,32 @@
|
|||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_debate_brief,
|
||||
extract_feedback_snapshot,
|
||||
get_language_instruction,
|
||||
get_localized_final_proposal_instruction,
|
||||
get_snapshot_template,
|
||||
get_snapshot_writing_instruction,
|
||||
localize_role_name,
|
||||
normalize_chinese_role_terms,
|
||||
strip_feedback_snapshot,
|
||||
truncate_for_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_bear_researcher(llm, memory):
|
||||
def bear_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
history = investment_debate_state.get("history", "")
|
||||
bear_history = investment_debate_state.get("bear_history", "")
|
||||
|
||||
current_response = investment_debate_state.get("current_response", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
bull_snapshot = investment_debate_state.get("bull_snapshot", "")
|
||||
bear_snapshot = investment_debate_state.get("bear_snapshot", "")
|
||||
debate_brief = investment_debate_state.get("debate_brief", "")
|
||||
market_research_report = truncate_for_prompt(state["market_report"])
|
||||
sentiment_report = truncate_for_prompt(state["sentiment_report"])
|
||||
news_report = truncate_for_prompt(state["news_report"])
|
||||
fundamentals_report = truncate_for_prompt(state["fundamentals_report"])
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
|
@ -38,21 +51,42 @@ Market research report: {market_research_report}
|
|||
Social media sentiment report: {sentiment_report}
|
||||
Latest world affairs news: {news_report}
|
||||
Company fundamentals report: {fundamentals_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bull argument: {current_response}
|
||||
Rolling debate brief: {debate_brief}
|
||||
Your latest feedback snapshot: {bear_snapshot}
|
||||
Latest bull feedback snapshot: {bull_snapshot}
|
||||
Last bull argument body: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
When writing in Chinese, use the exact role names "{localize_role_name('Bear Analyst')}" and "{localize_role_name('Bull Analyst')}". Do not use variants like "熊派分析师" or "牛派分析师".
|
||||
Your main argument body must be written entirely in Chinese. {get_localized_final_proposal_instruction()}
|
||||
After your normal argument, append an exact block using this template:
|
||||
{get_snapshot_template()}
|
||||
{get_snapshot_writing_instruction()}{get_language_instruction()}
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Bear Analyst: {response.content}"
|
||||
raw_content = normalize_chinese_role_terms(response.content)
|
||||
argument_body = strip_feedback_snapshot(raw_content)
|
||||
argument = f"{localize_role_name('Bear Analyst')}: {argument_body}"
|
||||
new_bear_snapshot = extract_feedback_snapshot(raw_content)
|
||||
new_debate_brief = build_debate_brief(
|
||||
{
|
||||
"Bull Analyst": bull_snapshot,
|
||||
"Bear Analyst": new_bear_snapshot,
|
||||
},
|
||||
latest_speaker="Bear Analyst",
|
||||
)
|
||||
|
||||
new_investment_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"history": investment_debate_state.get("history", "") + "\n" + argument,
|
||||
"bear_history": bear_history + "\n" + argument,
|
||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||
"current_response": argument,
|
||||
"bull_snapshot": bull_snapshot,
|
||||
"bear_snapshot": new_bear_snapshot,
|
||||
"debate_brief": new_debate_brief,
|
||||
"latest_speaker": "Bear Analyst",
|
||||
"judge_decision": investment_debate_state.get("judge_decision", ""),
|
||||
"count": investment_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,32 @@
|
|||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_debate_brief,
|
||||
extract_feedback_snapshot,
|
||||
get_language_instruction,
|
||||
get_localized_final_proposal_instruction,
|
||||
get_snapshot_template,
|
||||
get_snapshot_writing_instruction,
|
||||
localize_role_name,
|
||||
normalize_chinese_role_terms,
|
||||
strip_feedback_snapshot,
|
||||
truncate_for_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_bull_researcher(llm, memory):
|
||||
def bull_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
history = investment_debate_state.get("history", "")
|
||||
bull_history = investment_debate_state.get("bull_history", "")
|
||||
|
||||
current_response = investment_debate_state.get("current_response", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
bull_snapshot = investment_debate_state.get("bull_snapshot", "")
|
||||
bear_snapshot = investment_debate_state.get("bear_snapshot", "")
|
||||
debate_brief = investment_debate_state.get("debate_brief", "")
|
||||
market_research_report = truncate_for_prompt(state["market_report"])
|
||||
sentiment_report = truncate_for_prompt(state["sentiment_report"])
|
||||
news_report = truncate_for_prompt(state["news_report"])
|
||||
fundamentals_report = truncate_for_prompt(state["fundamentals_report"])
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
|
@ -36,21 +49,42 @@ Market research report: {market_research_report}
|
|||
Social media sentiment report: {sentiment_report}
|
||||
Latest world affairs news: {news_report}
|
||||
Company fundamentals report: {fundamentals_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bear argument: {current_response}
|
||||
Rolling debate brief: {debate_brief}
|
||||
Your latest feedback snapshot: {bull_snapshot}
|
||||
Latest bear feedback snapshot: {bear_snapshot}
|
||||
Last bear argument body: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
When writing in Chinese, use the exact role names "{localize_role_name('Bull Analyst')}" and "{localize_role_name('Bear Analyst')}". Do not use variants like "牛派分析师" or "熊派分析师".
|
||||
Your main argument body must be written entirely in Chinese. {get_localized_final_proposal_instruction()}
|
||||
After your normal argument, append an exact block using this template:
|
||||
{get_snapshot_template()}
|
||||
{get_snapshot_writing_instruction()}{get_language_instruction()}
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Bull Analyst: {response.content}"
|
||||
raw_content = normalize_chinese_role_terms(response.content)
|
||||
argument_body = strip_feedback_snapshot(raw_content)
|
||||
argument = f"{localize_role_name('Bull Analyst')}: {argument_body}"
|
||||
new_bull_snapshot = extract_feedback_snapshot(raw_content)
|
||||
new_debate_brief = build_debate_brief(
|
||||
{
|
||||
"Bull Analyst": new_bull_snapshot,
|
||||
"Bear Analyst": bear_snapshot,
|
||||
},
|
||||
latest_speaker="Bull Analyst",
|
||||
)
|
||||
|
||||
new_investment_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"history": investment_debate_state.get("history", "") + "\n" + argument,
|
||||
"bull_history": bull_history + "\n" + argument,
|
||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||
"current_response": argument,
|
||||
"bull_snapshot": new_bull_snapshot,
|
||||
"bear_snapshot": bear_snapshot,
|
||||
"debate_brief": new_debate_brief,
|
||||
"latest_speaker": "Bull Analyst",
|
||||
"judge_decision": investment_debate_state.get("judge_decision", ""),
|
||||
"count": investment_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,22 +1,34 @@
|
|||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_debate_brief,
|
||||
extract_feedback_snapshot,
|
||||
get_language_instruction,
|
||||
get_snapshot_template,
|
||||
get_snapshot_writing_instruction,
|
||||
localize_role_name,
|
||||
strip_feedback_snapshot,
|
||||
truncate_for_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_aggressive_debator(llm):
|
||||
def aggressive_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
aggressive_history = risk_debate_state.get("aggressive_history", "")
|
||||
|
||||
current_conservative_response = risk_debate_state.get("current_conservative_response", "")
|
||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||
aggressive_snapshot = risk_debate_state.get("aggressive_snapshot", "")
|
||||
conservative_snapshot = risk_debate_state.get("conservative_snapshot", "")
|
||||
neutral_snapshot = risk_debate_state.get("neutral_snapshot", "")
|
||||
debate_brief = risk_debate_state.get("debate_brief", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
market_research_report = truncate_for_prompt(state["market_report"])
|
||||
sentiment_report = truncate_for_prompt(state["sentiment_report"])
|
||||
news_report = truncate_for_prompt(state["news_report"])
|
||||
fundamentals_report = truncate_for_prompt(state["fundamentals_report"])
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
trader_decision = truncate_for_prompt(state["trader_investment_plan"])
|
||||
|
||||
prompt = f"""As the Aggressive Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
|
||||
|
||||
|
|
@ -28,16 +40,35 @@ Market Research Report: {market_research_report}
|
|||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_conservative_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
Rolling risk debate brief: {debate_brief}
|
||||
Your latest feedback snapshot: {aggressive_snapshot}
|
||||
Latest conservative feedback snapshot: {conservative_snapshot}
|
||||
Latest neutral feedback snapshot: {neutral_snapshot}
|
||||
Last conservative argument body: {current_conservative_response}
|
||||
Last neutral argument body: {current_neutral_response}
|
||||
If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
|
||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
|
||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting.
|
||||
After your normal argument, append an exact block using this template:
|
||||
{get_snapshot_template()}
|
||||
{get_snapshot_writing_instruction()}{get_language_instruction()}"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Aggressive Analyst: {response.content}"
|
||||
raw_content = response.content
|
||||
argument_body = strip_feedback_snapshot(raw_content)
|
||||
argument = f"{localize_role_name('Aggressive Analyst')}: {argument_body}"
|
||||
new_aggressive_snapshot = extract_feedback_snapshot(raw_content)
|
||||
new_debate_brief = build_debate_brief(
|
||||
{
|
||||
"Aggressive Analyst": new_aggressive_snapshot,
|
||||
"Conservative Analyst": conservative_snapshot,
|
||||
"Neutral Analyst": neutral_snapshot,
|
||||
},
|
||||
latest_speaker="Aggressive Analyst",
|
||||
)
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"history": risk_debate_state.get("history", "") + "\n" + argument,
|
||||
"aggressive_history": aggressive_history + "\n" + argument,
|
||||
"conservative_history": risk_debate_state.get("conservative_history", ""),
|
||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
|
|
@ -47,6 +78,11 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
|
|||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", ""
|
||||
),
|
||||
"aggressive_snapshot": new_aggressive_snapshot,
|
||||
"conservative_snapshot": conservative_snapshot,
|
||||
"neutral_snapshot": neutral_snapshot,
|
||||
"debate_brief": new_debate_brief,
|
||||
"judge_decision": risk_debate_state.get("judge_decision", ""),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,23 +1,35 @@
|
|||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_debate_brief,
|
||||
extract_feedback_snapshot,
|
||||
get_language_instruction,
|
||||
get_snapshot_template,
|
||||
get_snapshot_writing_instruction,
|
||||
localize_role_name,
|
||||
strip_feedback_snapshot,
|
||||
truncate_for_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_conservative_debator(llm):
|
||||
def conservative_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
conservative_history = risk_debate_state.get("conservative_history", "")
|
||||
|
||||
current_aggressive_response = risk_debate_state.get("current_aggressive_response", "")
|
||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||
aggressive_snapshot = risk_debate_state.get("aggressive_snapshot", "")
|
||||
conservative_snapshot = risk_debate_state.get("conservative_snapshot", "")
|
||||
neutral_snapshot = risk_debate_state.get("neutral_snapshot", "")
|
||||
debate_brief = risk_debate_state.get("debate_brief", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
market_research_report = truncate_for_prompt(state["market_report"])
|
||||
sentiment_report = truncate_for_prompt(state["sentiment_report"])
|
||||
news_report = truncate_for_prompt(state["news_report"])
|
||||
fundamentals_report = truncate_for_prompt(state["fundamentals_report"])
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
trader_decision = truncate_for_prompt(state["trader_investment_plan"])
|
||||
|
||||
prompt = f"""As the Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
|
||||
|
||||
|
|
@ -29,16 +41,35 @@ Market Research Report: {market_research_report}
|
|||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
Rolling risk debate brief: {debate_brief}
|
||||
Your latest feedback snapshot: {conservative_snapshot}
|
||||
Latest aggressive feedback snapshot: {aggressive_snapshot}
|
||||
Latest neutral feedback snapshot: {neutral_snapshot}
|
||||
Last aggressive argument body: {current_aggressive_response}
|
||||
Last neutral argument body: {current_neutral_response}
|
||||
If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
|
||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
|
||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting.
|
||||
After your normal argument, append an exact block using this template:
|
||||
{get_snapshot_template()}
|
||||
{get_snapshot_writing_instruction()}{get_language_instruction()}"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Conservative Analyst: {response.content}"
|
||||
raw_content = response.content
|
||||
argument_body = strip_feedback_snapshot(raw_content)
|
||||
argument = f"{localize_role_name('Conservative Analyst')}: {argument_body}"
|
||||
new_conservative_snapshot = extract_feedback_snapshot(raw_content)
|
||||
new_debate_brief = build_debate_brief(
|
||||
{
|
||||
"Aggressive Analyst": aggressive_snapshot,
|
||||
"Conservative Analyst": new_conservative_snapshot,
|
||||
"Neutral Analyst": neutral_snapshot,
|
||||
},
|
||||
latest_speaker="Conservative Analyst",
|
||||
)
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"history": risk_debate_state.get("history", "") + "\n" + argument,
|
||||
"aggressive_history": risk_debate_state.get("aggressive_history", ""),
|
||||
"conservative_history": conservative_history + "\n" + argument,
|
||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
|
|
@ -50,6 +81,11 @@ Engage by questioning their optimism and emphasizing the potential downsides the
|
|||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", ""
|
||||
),
|
||||
"aggressive_snapshot": aggressive_snapshot,
|
||||
"conservative_snapshot": new_conservative_snapshot,
|
||||
"neutral_snapshot": neutral_snapshot,
|
||||
"debate_brief": new_debate_brief,
|
||||
"judge_decision": risk_debate_state.get("judge_decision", ""),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,22 +1,34 @@
|
|||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_debate_brief,
|
||||
extract_feedback_snapshot,
|
||||
get_language_instruction,
|
||||
get_snapshot_template,
|
||||
get_snapshot_writing_instruction,
|
||||
localize_role_name,
|
||||
strip_feedback_snapshot,
|
||||
truncate_for_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_neutral_debator(llm):
|
||||
def neutral_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
neutral_history = risk_debate_state.get("neutral_history", "")
|
||||
|
||||
current_aggressive_response = risk_debate_state.get("current_aggressive_response", "")
|
||||
current_conservative_response = risk_debate_state.get("current_conservative_response", "")
|
||||
aggressive_snapshot = risk_debate_state.get("aggressive_snapshot", "")
|
||||
conservative_snapshot = risk_debate_state.get("conservative_snapshot", "")
|
||||
neutral_snapshot = risk_debate_state.get("neutral_snapshot", "")
|
||||
debate_brief = risk_debate_state.get("debate_brief", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
market_research_report = truncate_for_prompt(state["market_report"])
|
||||
sentiment_report = truncate_for_prompt(state["sentiment_report"])
|
||||
news_report = truncate_for_prompt(state["news_report"])
|
||||
fundamentals_report = truncate_for_prompt(state["fundamentals_report"])
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
trader_decision = truncate_for_prompt(state["trader_investment_plan"])
|
||||
|
||||
prompt = f"""As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
|
||||
|
||||
|
|
@ -28,16 +40,35 @@ Market Research Report: {market_research_report}
|
|||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the conservative analyst: {current_conservative_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
Rolling risk debate brief: {debate_brief}
|
||||
Your latest feedback snapshot: {neutral_snapshot}
|
||||
Latest aggressive feedback snapshot: {aggressive_snapshot}
|
||||
Latest conservative feedback snapshot: {conservative_snapshot}
|
||||
Last aggressive argument body: {current_aggressive_response}
|
||||
Last conservative argument body: {current_conservative_response}
|
||||
If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
|
||||
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
|
||||
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting.
|
||||
After your normal argument, append an exact block using this template:
|
||||
{get_snapshot_template()}
|
||||
{get_snapshot_writing_instruction()}{get_language_instruction()}"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Neutral Analyst: {response.content}"
|
||||
raw_content = response.content
|
||||
argument_body = strip_feedback_snapshot(raw_content)
|
||||
argument = f"{localize_role_name('Neutral Analyst')}: {argument_body}"
|
||||
new_neutral_snapshot = extract_feedback_snapshot(raw_content)
|
||||
new_debate_brief = build_debate_brief(
|
||||
{
|
||||
"Aggressive Analyst": aggressive_snapshot,
|
||||
"Conservative Analyst": conservative_snapshot,
|
||||
"Neutral Analyst": new_neutral_snapshot,
|
||||
},
|
||||
latest_speaker="Neutral Analyst",
|
||||
)
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"history": risk_debate_state.get("history", "") + "\n" + argument,
|
||||
"aggressive_history": risk_debate_state.get("aggressive_history", ""),
|
||||
"conservative_history": risk_debate_state.get("conservative_history", ""),
|
||||
"neutral_history": neutral_history + "\n" + argument,
|
||||
|
|
@ -47,6 +78,11 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
|
|||
),
|
||||
"current_conservative_response": risk_debate_state.get("current_conservative_response", ""),
|
||||
"current_neutral_response": argument,
|
||||
"aggressive_snapshot": aggressive_snapshot,
|
||||
"conservative_snapshot": conservative_snapshot,
|
||||
"neutral_snapshot": new_neutral_snapshot,
|
||||
"debate_brief": new_debate_brief,
|
||||
"judge_decision": risk_debate_state.get("judge_decision", ""),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,18 +2,23 @@ import functools
|
|||
import time
|
||||
import json
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_language_instruction,
|
||||
get_localized_final_proposal_instruction,
|
||||
truncate_for_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_trader(llm, memory):
|
||||
def trader_node(state, name):
|
||||
company_name = state["company_of_interest"]
|
||||
instrument_context = build_instrument_context(company_name)
|
||||
investment_plan = state["investment_plan"]
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
investment_plan = truncate_for_prompt(state["investment_plan"])
|
||||
market_research_report = truncate_for_prompt(state["market_report"])
|
||||
sentiment_report = truncate_for_prompt(state["sentiment_report"])
|
||||
news_report = truncate_for_prompt(state["news_report"])
|
||||
fundamentals_report = truncate_for_prompt(state["fundamentals_report"])
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
|
@ -33,7 +38,7 @@ def create_trader(llm, memory):
|
|||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Apply lessons from past decisions to strengthen your analysis. Here are reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
|
||||
"content": f"""You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. {get_localized_final_proposal_instruction()} Apply lessons from past decisions to strengthen your analysis. Here are reflections from similar situations you traded in and the lessons learned: {past_memory_str}{get_language_instruction()}""",
|
||||
},
|
||||
context,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -17,6 +17,10 @@ class InvestDebateState(TypedDict):
|
|||
] # Bullish Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
current_response: Annotated[str, "Latest response"] # Last response
|
||||
bull_snapshot: Annotated[str, "Latest bull feedback snapshot"]
|
||||
bear_snapshot: Annotated[str, "Latest bear feedback snapshot"]
|
||||
debate_brief: Annotated[str, "Compact latest debate brief"]
|
||||
latest_speaker: Annotated[str, "Speaker that updated the brief last"]
|
||||
judge_decision: Annotated[str, "Final judge decision"] # Last response
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
|
||||
|
|
@ -33,6 +37,7 @@ class RiskDebateState(TypedDict):
|
|||
str, "Neutral Agent's Conversation history"
|
||||
] # Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
debate_brief: Annotated[str, "Compact latest risk debate brief"]
|
||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
||||
current_aggressive_response: Annotated[
|
||||
str, "Latest response by the aggressive analyst"
|
||||
|
|
@ -43,6 +48,9 @@ class RiskDebateState(TypedDict):
|
|||
current_neutral_response: Annotated[
|
||||
str, "Latest response by the neutral analyst"
|
||||
] # Last response
|
||||
aggressive_snapshot: Annotated[str, "Latest aggressive feedback snapshot"]
|
||||
conservative_snapshot: Annotated[str, "Latest conservative feedback snapshot"]
|
||||
neutral_snapshot: Annotated[str, "Latest neutral feedback snapshot"]
|
||||
judge_decision: Annotated[str, "Judge's decision"]
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
import re
|
||||
|
||||
# Import tools from separate utility files
|
||||
from tradingagents.agents.utils.core_stock_tools import (
|
||||
|
|
@ -34,6 +35,450 @@ def get_language_instruction() -> str:
|
|||
return f" Write your entire response in {lang}."
|
||||
|
||||
|
||||
def get_output_language() -> str:
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
return get_config().get("output_language", "English")
|
||||
|
||||
|
||||
def _is_chinese_output() -> bool:
|
||||
return get_output_language().strip().lower() in {"chinese", "中文", "zh", "zh-cn", "zh-hans"}
|
||||
|
||||
|
||||
def truncate_for_prompt(
|
||||
text: str,
|
||||
limit_key: str = "report_context_char_limit",
|
||||
default_limit: int = 16000,
|
||||
) -> str:
|
||||
"""Trim long context text to keep prompts within a stable token budget."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
cfg = get_config()
|
||||
limit = int(cfg.get(limit_key, default_limit) or default_limit)
|
||||
if limit <= 0 or len(text) <= limit:
|
||||
return text
|
||||
|
||||
omitted = len(text) - limit
|
||||
return f"[Content trimmed, omitted {omitted} characters]\n{text[-limit:]}"
|
||||
|
||||
|
||||
def get_snapshot_template() -> str:
|
||||
if _is_chinese_output():
|
||||
return """反馈快照:
|
||||
- 当前观点:
|
||||
- 发生了什么变化:
|
||||
- 为什么变化:
|
||||
- 关键反驳:
|
||||
- 下一轮教训:"""
|
||||
|
||||
return """FEEDBACK SNAPSHOT:
|
||||
- Current thesis:
|
||||
- What changed:
|
||||
- Why it changed:
|
||||
- Key rebuttal:
|
||||
- Lesson for next round:"""
|
||||
|
||||
|
||||
def get_snapshot_writing_instruction() -> str:
|
||||
if _is_chinese_output():
|
||||
return (
|
||||
"反馈快照中的每一项都必须填写具体内容,直接总结本轮新增观点、证据、反驳和下一轮要验证的点。"
|
||||
"禁止填写“未明确说明”“暂无”“同上”“无变化”这类占位语。"
|
||||
)
|
||||
return (
|
||||
"Every field in the feedback snapshot must contain concrete content grounded in this round's argument, "
|
||||
"including what changed, why it changed, the key rebuttal, and what to verify next round. "
|
||||
"Do not use placeholders like 'not specified', 'none', 'same as above', or 'no change'."
|
||||
)
|
||||
|
||||
|
||||
def localize_label(english: str, chinese: str) -> str:
|
||||
return chinese if _is_chinese_output() else english
|
||||
|
||||
|
||||
def localize_role_name(role: str) -> str:
|
||||
mapping = {
|
||||
"Bull Analyst": "多头分析师",
|
||||
"Bear Analyst": "空头分析师",
|
||||
"Aggressive Analyst": "激进分析师",
|
||||
"Conservative Analyst": "保守分析师",
|
||||
"Neutral Analyst": "中性分析师",
|
||||
"Portfolio Manager": "投资组合经理",
|
||||
"Research Manager": "研究经理",
|
||||
"Trader": "交易员",
|
||||
"Judge": "裁决者",
|
||||
}
|
||||
return mapping.get(role, role) if _is_chinese_output() else role
|
||||
|
||||
|
||||
def normalize_chinese_role_terms(text: str) -> str:
|
||||
"""Normalize user-facing Chinese role terms to a single preferred wording."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
replacements = {
|
||||
"熊派分析师": "空头分析师",
|
||||
"熊派投资者": "空头投资者",
|
||||
"熊观点": "空头观点",
|
||||
"熊派": "空头",
|
||||
"牛派分析师": "多头分析师",
|
||||
"牛派投资者": "多头投资者",
|
||||
"牛观点": "多头观点",
|
||||
"牛派": "多头",
|
||||
}
|
||||
normalized = text
|
||||
for src, dst in replacements.items():
|
||||
normalized = normalized.replace(src, dst)
|
||||
return normalized
|
||||
|
||||
|
||||
def localize_rating_term(term: str) -> str:
|
||||
mapping = {
|
||||
"Buy": "买入",
|
||||
"Overweight": "增持",
|
||||
"Hold": "持有",
|
||||
"Underweight": "减持",
|
||||
"Sell": "卖出",
|
||||
"BUY": "买入",
|
||||
"HOLD": "持有",
|
||||
"SELL": "卖出",
|
||||
}
|
||||
return mapping.get(term, term) if _is_chinese_output() else term
|
||||
|
||||
|
||||
def get_localized_rating_scale() -> str:
|
||||
if _is_chinese_output():
|
||||
return """**评级体系**(只能选择一个):
|
||||
- **买入**: 对开仓或加仓有很强信心
|
||||
- **增持**: 前景偏积极,建议逐步提高仓位
|
||||
- **持有**: 维持当前仓位,暂不动作
|
||||
- **减持**: 降低敞口,分批止盈或收缩仓位
|
||||
- **卖出**: 退出仓位或避免入场"""
|
||||
|
||||
return """**Rating Scale** (use exactly one):
|
||||
- **Buy**: Strong conviction to enter or add to position
|
||||
- **Overweight**: Favorable outlook, gradually increase exposure
|
||||
- **Hold**: Maintain current position, no action needed
|
||||
- **Underweight**: Reduce exposure, take partial profits
|
||||
- **Sell**: Exit position or avoid entry"""
|
||||
|
||||
|
||||
def get_localized_final_proposal_instruction() -> str:
|
||||
if _is_chinese_output():
|
||||
return (
|
||||
"End with a firm decision and present the user-facing conclusion as "
|
||||
"'最终交易建议: **买入/持有/卖出**'. For machine compatibility, you may optionally append a separate final line "
|
||||
"using the internal token 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' only when needed."
|
||||
)
|
||||
return (
|
||||
"End with a firm decision and always conclude your response with "
|
||||
"'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation."
|
||||
)
|
||||
|
||||
|
||||
def get_collaboration_stop_instruction() -> str:
|
||||
if _is_chinese_output():
|
||||
return (
|
||||
" If you or another assistant has already reached a final deliverable, prefer the user-facing line "
|
||||
"'最终交易建议: **买入/持有/卖出**'. Only when a machine-readable stop signal is necessary, append an extra final line "
|
||||
"with the internal token 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**'."
|
||||
)
|
||||
return (
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
)
|
||||
|
||||
|
||||
SNAPSHOT_MARKERS = ("FEEDBACK SNAPSHOT:", "反馈快照:")
|
||||
SNAPSHOT_TEMPLATE = get_snapshot_template()
|
||||
|
||||
|
||||
def _condense_excerpt(text: str, limit: int = 120) -> str:
|
||||
compact = re.sub(r"\s+", " ", text).strip()
|
||||
if len(compact) <= limit:
|
||||
return compact
|
||||
return compact[: limit - 3].rstrip() + "..."
|
||||
|
||||
|
||||
def _get_rating_patterns() -> list[tuple[str, tuple[str, ...]]]:
|
||||
return [
|
||||
("买入", ("最终交易建议: **买入**", "评级: 买入", "建议买入", "维持买入", "转为买入", "买入")),
|
||||
("增持", ("最终交易建议: **增持**", "评级: 增持", "建议增持", "维持增持", "转为增持", "增持")),
|
||||
("持有", ("最终交易建议: **持有**", "评级: 持有", "建议持有", "维持持有", "转为持有", "持有")),
|
||||
("减持", ("最终交易建议: **减持**", "评级: 减持", "建议减持", "维持减持", "转为减持", "减持")),
|
||||
("卖出", ("最终交易建议: **卖出**", "评级: 卖出", "建议卖出", "维持卖出", "转为卖出", "卖出")),
|
||||
]
|
||||
|
||||
|
||||
def _detect_chinese_rating(text: str) -> str:
|
||||
content = normalize_chinese_role_terms(text or "")
|
||||
if not content.strip():
|
||||
return "持有"
|
||||
|
||||
for rating, patterns in _get_rating_patterns():
|
||||
for pattern in patterns:
|
||||
if pattern in content:
|
||||
return rating
|
||||
|
||||
heuristic_patterns = [
|
||||
("卖出", ("清仓", "退出", "避免入场", "止损离场", "果断卖出")),
|
||||
("减持", ("降低仓位", "分批止盈", "降低敞口", "部分卖出", "先减仓")),
|
||||
("增持", ("加仓", "提高仓位", "逢低布局", "继续增持", "扩大仓位")),
|
||||
("买入", ("买入机会", "积极布局", "值得买入", "坚定看多", "继续买入")),
|
||||
("持有", ("继续观察", "暂不动作", "维持仓位", "等待确认", "持仓观望")),
|
||||
]
|
||||
for rating, patterns in heuristic_patterns:
|
||||
if any(pattern in content for pattern in patterns):
|
||||
return rating
|
||||
|
||||
return "持有"
|
||||
|
||||
|
||||
def _detect_english_rating(text: str) -> str:
|
||||
content = (text or "").lower()
|
||||
if not content.strip():
|
||||
return "HOLD"
|
||||
|
||||
explicit_patterns = [
|
||||
("SELL", ("final transaction proposal: **sell**", "rating: sell", "recommend sell")),
|
||||
("UNDERWEIGHT", ("final transaction proposal: **underweight**", "rating: underweight", "recommend underweight")),
|
||||
("HOLD", ("final transaction proposal: **hold**", "rating: hold", "recommend hold")),
|
||||
("OVERWEIGHT", ("final transaction proposal: **overweight**", "rating: overweight", "recommend overweight")),
|
||||
("BUY", ("final transaction proposal: **buy**", "rating: buy", "recommend buy")),
|
||||
]
|
||||
for rating, patterns in explicit_patterns:
|
||||
if any(pattern in content for pattern in patterns):
|
||||
return rating
|
||||
|
||||
heuristic_patterns = [
|
||||
("SELL", ("exit position", "avoid entry", "sell the stock", "close the position")),
|
||||
("UNDERWEIGHT", ("reduce exposure", "trim the position", "take partial profits")),
|
||||
("OVERWEIGHT", ("add to position", "increase exposure", "build the position")),
|
||||
("BUY", ("buy the stock", "enter the position", "strong upside")),
|
||||
("HOLD", ("maintain the position", "wait for confirmation", "stay on hold")),
|
||||
]
|
||||
for rating, patterns in heuristic_patterns:
|
||||
if any(pattern in content for pattern in patterns):
|
||||
return rating
|
||||
|
||||
return "HOLD"
|
||||
|
||||
|
||||
def _extract_sentences(text: str) -> list[str]:
|
||||
compact = re.sub(r"\s+", " ", text).strip()
|
||||
if not compact:
|
||||
return []
|
||||
parts = re.split(r"(?<=[。!?!?\.])\s+|\n+", compact)
|
||||
return [part.strip() for part in parts if part.strip()]
|
||||
|
||||
|
||||
def _contains_placeholder_snapshot(snapshot: str) -> bool:
|
||||
placeholders = (
|
||||
"未明确说明",
|
||||
"暂无",
|
||||
"同上",
|
||||
"无变化",
|
||||
"Not explicitly stated",
|
||||
"None yet",
|
||||
"same as above",
|
||||
"no change",
|
||||
)
|
||||
return any(token in snapshot for token in placeholders)
|
||||
|
||||
|
||||
def _snapshot_field_labels() -> list[str]:
|
||||
if _is_chinese_output():
|
||||
return ["当前观点", "发生了什么变化", "为什么变化", "关键反驳", "下一轮教训"]
|
||||
return [
|
||||
"Current thesis",
|
||||
"What changed",
|
||||
"Why it changed",
|
||||
"Key rebuttal",
|
||||
"Lesson for next round",
|
||||
]
|
||||
|
||||
|
||||
def _snapshot_field_aliases() -> dict[str, tuple[str, ...]]:
|
||||
return {
|
||||
"current_thesis": ("当前观点", "Current thesis"),
|
||||
"what_changed": ("发生了什么变化", "What changed"),
|
||||
"why_changed": ("为什么变化", "Why it changed"),
|
||||
"key_rebuttal": ("关键反驳", "Key rebuttal"),
|
||||
"lesson_next_round": ("下一轮教训", "Lesson for next round"),
|
||||
}
|
||||
|
||||
|
||||
def _parse_snapshot_fields(snapshot: str) -> dict[str, str]:
|
||||
fields = {key: "" for key in _snapshot_field_aliases()}
|
||||
if not snapshot:
|
||||
return fields
|
||||
|
||||
for line in snapshot.splitlines():
|
||||
stripped = line.strip()
|
||||
matched = False
|
||||
for field_key, aliases in _snapshot_field_aliases().items():
|
||||
for label in aliases:
|
||||
prefix = f"- {label}:"
|
||||
if stripped.startswith(prefix):
|
||||
fields[field_key] = stripped[len(prefix):].strip()
|
||||
matched = True
|
||||
break
|
||||
if matched:
|
||||
break
|
||||
return fields
|
||||
|
||||
|
||||
def _snapshot_has_missing_fields(snapshot: str) -> bool:
|
||||
fields = _parse_snapshot_fields(snapshot)
|
||||
for value in fields.values():
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return True
|
||||
if normalized in {"。", ".", "...", "……", "-", "--"}:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _merge_snapshot_with_inferred(snapshot: str, inferred_snapshot: str) -> str:
|
||||
explicit = _parse_snapshot_fields(snapshot)
|
||||
inferred = _parse_snapshot_fields(inferred_snapshot)
|
||||
|
||||
lines = [SNAPSHOT_MARKERS[1] if _is_chinese_output() else SNAPSHOT_MARKERS[0]]
|
||||
display_labels = _snapshot_field_labels()
|
||||
for field_key, label in zip(_snapshot_field_aliases().keys(), display_labels):
|
||||
value = explicit.get(field_key, "").strip()
|
||||
if not value or value in {"。", ".", "...", "……", "-", "--"}:
|
||||
value = inferred.get(field_key, "").strip()
|
||||
lines.append(f"- {label}: {value}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def is_feedback_snapshot_inferred(text: str) -> bool:
|
||||
"""Return True when the displayed snapshot will be inferred from the body."""
|
||||
if not text or not text.strip():
|
||||
return True
|
||||
|
||||
for marker in SNAPSHOT_MARKERS:
|
||||
idx = text.rfind(marker)
|
||||
if idx != -1:
|
||||
snapshot = text[idx:].strip()
|
||||
return _contains_placeholder_snapshot(snapshot) or _snapshot_has_missing_fields(snapshot)
|
||||
return True
|
||||
|
||||
|
||||
def _infer_feedback_snapshot_from_body(text: str) -> str:
|
||||
body = normalize_chinese_role_terms(strip_feedback_snapshot(text))
|
||||
sentences = _extract_sentences(body)
|
||||
first = _condense_excerpt(sentences[0], 120) if sentences else _condense_excerpt(body, 120)
|
||||
second = _condense_excerpt(sentences[1], 120) if len(sentences) > 1 else first
|
||||
third = _condense_excerpt(sentences[2], 120) if len(sentences) > 2 else second
|
||||
|
||||
if _is_chinese_output():
|
||||
rating = _detect_chinese_rating(text)
|
||||
current = rating
|
||||
changed = (
|
||||
second if len(sentences) > 1 else f"本轮围绕“{rating}”补充了更明确的交易依据、风险边界和执行条件。"
|
||||
)
|
||||
why = (
|
||||
third if len(sentences) > 2 else f"变化来自本轮新增的数据证据、市场信号和对手论点带来的判断修正。"
|
||||
)
|
||||
rebuttal_source = next(
|
||||
(s for s in sentences if any(word in s for word in ("但", "然而", "不过", "反驳", "忽略", "高估"))),
|
||||
second or f"本轮核心反驳集中在对手忽略了影响“{rating}”判断的关键数据或风险约束。",
|
||||
)
|
||||
lesson_source = next(
|
||||
(s for s in sentences if any(word in s for word in ("需要", "继续", "监控", "跟踪", "等待", "验证", "警惕"))),
|
||||
f"下一轮需要继续验证支持“{rating}”结论的关键数据、风险触发条件和执行时点。",
|
||||
)
|
||||
return (
|
||||
"反馈快照:\n"
|
||||
f"- 当前观点: {current}\n"
|
||||
f"- 发生了什么变化: {changed}\n"
|
||||
f"- 为什么变化: {why}\n"
|
||||
f"- 关键反驳: {rebuttal_source}\n"
|
||||
f"- 下一轮教训: {lesson_source}"
|
||||
)
|
||||
|
||||
rating = _detect_english_rating(text)
|
||||
current = rating
|
||||
changed = second if len(sentences) > 1 else f"This round added clearer trading evidence, risk boundaries, and execution conditions behind the {rating} call."
|
||||
why = third if len(sentences) > 2 else "The update came from new evidence, market signals, and adjustments prompted by the opponent's latest claims."
|
||||
rebuttal_source = next(
|
||||
(s for s in sentences if any(word in s.lower() for word in ("but", "however", "rebut", "weakness", "risk", "miss"))),
|
||||
second or f"The key rebuttal is that the opposing case missed the main evidence or risk controls behind the {rating} stance.",
|
||||
)
|
||||
lesson_source = next(
|
||||
(s for s in sentences if any(word in s.lower() for word in ("monitor", "watch", "verify", "track", "wait", "risk"))),
|
||||
f"Next round should verify the core data assumptions, risk triggers, and timing conditions behind the {rating} stance.",
|
||||
)
|
||||
return (
|
||||
"FEEDBACK SNAPSHOT:\n"
|
||||
f"- Current thesis: {current}\n"
|
||||
f"- What changed: {changed}\n"
|
||||
f"- Why it changed: {why}\n"
|
||||
f"- Key rebuttal: {rebuttal_source}\n"
|
||||
f"- Lesson for next round: {lesson_source}"
|
||||
)
|
||||
|
||||
|
||||
def extract_feedback_snapshot(text: str) -> str:
|
||||
"""Extract the structured feedback snapshot block from an agent response."""
|
||||
if not text:
|
||||
if _is_chinese_output():
|
||||
return "反馈快照:\n- 当前观点: 暂无。\n- 发生了什么变化: 暂无。\n- 为什么变化: 暂无。\n- 关键反驳: 暂无。\n- 下一轮教训: 暂无。"
|
||||
return "FEEDBACK SNAPSHOT:\n- Current thesis: None yet.\n- What changed: None yet.\n- Why it changed: None yet.\n- Key rebuttal: None yet.\n- Lesson for next round: None yet."
|
||||
|
||||
for marker in SNAPSHOT_MARKERS:
|
||||
idx = text.rfind(marker)
|
||||
if idx != -1:
|
||||
snapshot = text[idx:].strip()
|
||||
if _contains_placeholder_snapshot(snapshot):
|
||||
return _infer_feedback_snapshot_from_body(text)
|
||||
normalized_snapshot = normalize_chinese_role_terms(snapshot)
|
||||
if _snapshot_has_missing_fields(normalized_snapshot):
|
||||
inferred_snapshot = _infer_feedback_snapshot_from_body(text)
|
||||
return _merge_snapshot_with_inferred(normalized_snapshot, inferred_snapshot)
|
||||
return normalized_snapshot
|
||||
|
||||
return _infer_feedback_snapshot_from_body(text)
|
||||
|
||||
|
||||
def strip_feedback_snapshot(text: str) -> str:
|
||||
"""Remove the feedback snapshot block from the visible argument body."""
|
||||
if not text:
|
||||
return ""
|
||||
best_idx = -1
|
||||
for marker in SNAPSHOT_MARKERS:
|
||||
idx = text.rfind(marker)
|
||||
if idx > best_idx:
|
||||
best_idx = idx
|
||||
if best_idx == -1:
|
||||
return text.strip()
|
||||
return text[:best_idx].strip()
|
||||
|
||||
|
||||
def build_debate_brief(snapshots: dict[str, str], latest_speaker: str = "") -> str:
|
||||
"""Build a compact cross-agent brief from the latest structured snapshots."""
|
||||
sections = []
|
||||
if latest_speaker:
|
||||
if _is_chinese_output():
|
||||
sections.append(f"最新更新来自: {localize_role_name(latest_speaker)}")
|
||||
else:
|
||||
sections.append(f"Latest update came from: {latest_speaker}")
|
||||
|
||||
for speaker, snapshot in snapshots.items():
|
||||
if snapshot:
|
||||
if _is_chinese_output():
|
||||
sections.append(f"{localize_role_name(speaker)} 最新快照:\n{snapshot}")
|
||||
else:
|
||||
sections.append(f"{speaker} latest snapshot:\n{snapshot}")
|
||||
|
||||
return "\n\n".join(sections).strip()
|
||||
|
||||
|
||||
def build_instrument_context(ticker: str) -> str:
|
||||
"""Describe the exact instrument so agents preserve exchange-qualified tickers."""
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -23,6 +23,8 @@ class FinancialSituationMemory:
|
|||
self.documents: List[str] = []
|
||||
self.recommendations: List[str] = []
|
||||
self.bm25 = None
|
||||
cfg = config or {}
|
||||
self.min_similarity = float(cfg.get("memory_min_similarity", 0.15))
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
"""Tokenize text for BM25 indexing.
|
||||
|
|
@ -78,15 +80,19 @@ class FinancialSituationMemory:
|
|||
|
||||
# Build results
|
||||
results = []
|
||||
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
|
||||
query_token_set = set(query_tokens)
|
||||
|
||||
for idx in top_indices:
|
||||
# Normalize score to 0-1 range for consistency
|
||||
normalized_score = scores[idx] / max_score if max_score > 0 else 0
|
||||
doc_token_set = set(self._tokenize(self.documents[idx]))
|
||||
overlap_count = len(query_token_set & doc_token_set)
|
||||
overlap_ratio = overlap_count / max(1, len(query_token_set))
|
||||
|
||||
if overlap_ratio < self.min_similarity:
|
||||
continue
|
||||
results.append({
|
||||
"matched_situation": self.documents[idx],
|
||||
"recommendation": self.recommendations[idx],
|
||||
"similarity_score": normalized_score,
|
||||
"similarity_score": overlap_ratio,
|
||||
})
|
||||
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -5,13 +5,15 @@ import json
|
|||
from datetime import datetime
|
||||
from io import StringIO
|
||||
|
||||
from .exceptions import DataVendorUnavailable
|
||||
|
||||
API_BASE_URL = "https://www.alphavantage.co/query"
|
||||
|
||||
def get_api_key() -> str:
|
||||
"""Retrieve the API key for Alpha Vantage from environment variables."""
|
||||
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
|
||||
raise DataVendorUnavailable("ALPHA_VANTAGE_API_KEY environment variable is not set.")
|
||||
return api_key
|
||||
|
||||
def format_datetime_for_api(date_input) -> str:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,130 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import requests
|
||||
|
||||
from .exceptions import DataVendorUnavailable
|
||||
|
||||
|
||||
BRAVE_SEARCH_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"
|
||||
REQUEST_TIMEOUT = 12
|
||||
|
||||
|
||||
def _parse_date(date_str: str) -> datetime:
|
||||
return datetime.strptime(date_str, "%Y-%m-%d")
|
||||
|
||||
|
||||
def _get_api_key() -> str:
|
||||
api_key = os.getenv("BRAVE_SEARCH_API_KEY") or os.getenv("BRAVE_API_KEY")
|
||||
if not api_key:
|
||||
raise DataVendorUnavailable(
|
||||
"BRAVE_SEARCH_API_KEY is not set. Configure it or use fallback vendor."
|
||||
)
|
||||
return api_key
|
||||
|
||||
|
||||
def _freshness_from_days(days: int) -> str:
|
||||
if days <= 1:
|
||||
return "pd"
|
||||
if days <= 7:
|
||||
return "pw"
|
||||
if days <= 31:
|
||||
return "pm"
|
||||
return "py"
|
||||
|
||||
|
||||
def _search_brave(query: str, count: int, freshness: str) -> list[dict]:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"X-Subscription-Token": _get_api_key(),
|
||||
}
|
||||
params = {
|
||||
"q": query,
|
||||
"count": max(1, min(count, 20)),
|
||||
"freshness": freshness,
|
||||
"search_lang": "en",
|
||||
"country": "US",
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
BRAVE_SEARCH_ENDPOINT,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as exc:
|
||||
raise DataVendorUnavailable(f"Brave Search request failed: {exc}") from exc
|
||||
|
||||
payload = response.json()
|
||||
return payload.get("web", {}).get("results", [])
|
||||
|
||||
|
||||
def _format_news_block(title: str, start_date: str, end_date: str, results: list[dict]) -> str:
|
||||
if not results:
|
||||
return f"No news found for {title} between {start_date} and {end_date}."
|
||||
|
||||
blocks = []
|
||||
for item in results:
|
||||
headline = item.get("title") or "No title"
|
||||
description = item.get("description") or ""
|
||||
url = item.get("url") or ""
|
||||
source = item.get("profile", {}).get("name") or "Unknown"
|
||||
age = item.get("age") or ""
|
||||
|
||||
text = f"### {headline} (source: {source})"
|
||||
if age:
|
||||
text += f"\nPublished: {age}"
|
||||
if description:
|
||||
text += f"\n{description}"
|
||||
if url:
|
||||
text += f"\nLink: {url}"
|
||||
blocks.append(text)
|
||||
|
||||
return f"## {title}, from {start_date} to {end_date}:\n\n" + "\n\n".join(blocks)
|
||||
|
||||
|
||||
def get_news(ticker: str, start_date: str, end_date: str) -> str:
|
||||
start_dt = _parse_date(start_date)
|
||||
end_dt = _parse_date(end_date)
|
||||
day_window = max(1, (end_dt - start_dt).days)
|
||||
freshness = _freshness_from_days(day_window)
|
||||
|
||||
query = f"{ticker} stock news earnings guidance sentiment"
|
||||
results = _search_brave(query=query, count=20, freshness=freshness)
|
||||
return _format_news_block(f"{ticker} News", start_date, end_date, results)
|
||||
|
||||
|
||||
def get_global_news(curr_date: str, look_back_days: int = 7, limit: int = 10) -> str:
|
||||
end_dt = _parse_date(curr_date)
|
||||
start_dt = end_dt - timedelta(days=look_back_days)
|
||||
start_date = start_dt.strftime("%Y-%m-%d")
|
||||
|
||||
freshness = _freshness_from_days(max(1, look_back_days))
|
||||
queries = [
|
||||
"US stock market macro news",
|
||||
"Federal Reserve rates inflation outlook",
|
||||
"global markets risk sentiment",
|
||||
"equity market volatility earnings outlook",
|
||||
]
|
||||
|
||||
merged = []
|
||||
seen_urls = set()
|
||||
per_query = max(3, min(limit, 8))
|
||||
|
||||
for query in queries:
|
||||
for item in _search_brave(query=query, count=per_query, freshness=freshness):
|
||||
url = item.get("url")
|
||||
if not url or url in seen_urls:
|
||||
continue
|
||||
seen_urls.add(url)
|
||||
merged.append(item)
|
||||
if len(merged) >= limit:
|
||||
break
|
||||
if len(merged) >= limit:
|
||||
break
|
||||
|
||||
return _format_news_block("Global Market News", start_date, curr_date, merged)
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
class DataVendorUnavailable(Exception):
|
||||
"""Raised when a vendor cannot serve a request and fallback should be attempted."""
|
||||
|
|
@ -11,18 +11,18 @@ from .y_finance import (
|
|||
get_insider_transactions as get_yfinance_insider_transactions,
|
||||
)
|
||||
from .yfinance_news import get_news_yfinance, get_global_news_yfinance
|
||||
from .alpha_vantage import (
|
||||
get_stock as get_alpha_vantage_stock,
|
||||
get_indicator as get_alpha_vantage_indicator,
|
||||
get_fundamentals as get_alpha_vantage_fundamentals,
|
||||
get_balance_sheet as get_alpha_vantage_balance_sheet,
|
||||
get_cashflow as get_alpha_vantage_cashflow,
|
||||
get_income_statement as get_alpha_vantage_income_statement,
|
||||
get_insider_transactions as get_alpha_vantage_insider_transactions,
|
||||
get_news as get_alpha_vantage_news,
|
||||
get_global_news as get_alpha_vantage_global_news,
|
||||
from .brave_news import get_news as get_brave_news, get_global_news as get_brave_global_news
|
||||
from .opencli_news import get_news as get_opencli_news, get_global_news as get_opencli_global_news
|
||||
from .tushare import (
|
||||
get_stock as get_tushare_stock,
|
||||
get_indicator as get_tushare_indicator,
|
||||
get_fundamentals as get_tushare_fundamentals,
|
||||
get_balance_sheet as get_tushare_balance_sheet,
|
||||
get_cashflow as get_tushare_cashflow,
|
||||
get_income_statement as get_tushare_income_statement,
|
||||
get_insider_transactions as get_tushare_insider_transactions,
|
||||
)
|
||||
from .alpha_vantage_common import AlphaVantageRateLimitError
|
||||
from .exceptions import DataVendorUnavailable
|
||||
|
||||
# Configuration and routing logic
|
||||
from .config import get_config
|
||||
|
|
@ -62,49 +62,53 @@ TOOLS_CATEGORIES = {
|
|||
|
||||
VENDOR_LIST = [
|
||||
"yfinance",
|
||||
"alpha_vantage",
|
||||
"tushare",
|
||||
"brave",
|
||||
"opencli",
|
||||
]
|
||||
|
||||
# Mapping of methods to their vendor-specific implementations
|
||||
VENDOR_METHODS = {
|
||||
# core_stock_apis
|
||||
"get_stock_data": {
|
||||
"alpha_vantage": get_alpha_vantage_stock,
|
||||
"tushare": get_tushare_stock,
|
||||
"yfinance": get_YFin_data_online,
|
||||
},
|
||||
# technical_indicators
|
||||
"get_indicators": {
|
||||
"alpha_vantage": get_alpha_vantage_indicator,
|
||||
"tushare": get_tushare_indicator,
|
||||
"yfinance": get_stock_stats_indicators_window,
|
||||
},
|
||||
# fundamental_data
|
||||
"get_fundamentals": {
|
||||
"alpha_vantage": get_alpha_vantage_fundamentals,
|
||||
"tushare": get_tushare_fundamentals,
|
||||
"yfinance": get_yfinance_fundamentals,
|
||||
},
|
||||
"get_balance_sheet": {
|
||||
"alpha_vantage": get_alpha_vantage_balance_sheet,
|
||||
"tushare": get_tushare_balance_sheet,
|
||||
"yfinance": get_yfinance_balance_sheet,
|
||||
},
|
||||
"get_cashflow": {
|
||||
"alpha_vantage": get_alpha_vantage_cashflow,
|
||||
"tushare": get_tushare_cashflow,
|
||||
"yfinance": get_yfinance_cashflow,
|
||||
},
|
||||
"get_income_statement": {
|
||||
"alpha_vantage": get_alpha_vantage_income_statement,
|
||||
"tushare": get_tushare_income_statement,
|
||||
"yfinance": get_yfinance_income_statement,
|
||||
},
|
||||
# news_data
|
||||
"get_news": {
|
||||
"alpha_vantage": get_alpha_vantage_news,
|
||||
"opencli": get_opencli_news,
|
||||
"brave": get_brave_news,
|
||||
"yfinance": get_news_yfinance,
|
||||
},
|
||||
"get_global_news": {
|
||||
"opencli": get_opencli_global_news,
|
||||
"brave": get_brave_global_news,
|
||||
"yfinance": get_global_news_yfinance,
|
||||
"alpha_vantage": get_alpha_vantage_global_news,
|
||||
},
|
||||
"get_insider_transactions": {
|
||||
"alpha_vantage": get_alpha_vantage_insider_transactions,
|
||||
"tushare": get_tushare_insider_transactions,
|
||||
"yfinance": get_yfinance_insider_transactions,
|
||||
},
|
||||
}
|
||||
|
|
@ -136,6 +140,7 @@ def route_to_vendor(method: str, *args, **kwargs):
|
|||
category = get_category_for_method(method)
|
||||
vendor_config = get_vendor(category, method)
|
||||
primary_vendors = [v.strip() for v in vendor_config.split(',')]
|
||||
last_error = None
|
||||
|
||||
if method not in VENDOR_METHODS:
|
||||
raise ValueError(f"Method '{method}' not supported")
|
||||
|
|
@ -156,7 +161,11 @@ def route_to_vendor(method: str, *args, **kwargs):
|
|||
|
||||
try:
|
||||
return impl_func(*args, **kwargs)
|
||||
except AlphaVantageRateLimitError:
|
||||
continue # Only rate limits trigger fallback
|
||||
except DataVendorUnavailable as exc:
|
||||
last_error = exc
|
||||
continue # Try next vendor in fallback chain
|
||||
|
||||
raise RuntimeError(f"No available vendor for '{method}'")
|
||||
if last_error is not None:
|
||||
raise RuntimeError(str(last_error)) from last_error
|
||||
|
||||
raise RuntimeError(f"No available vendor for '{method}'")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,438 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from .exceptions import DataVendorUnavailable
|
||||
|
||||
|
||||
def _parse_date(date_str: str) -> datetime:
|
||||
return datetime.strptime(date_str, "%Y-%m-%d")
|
||||
|
||||
|
||||
def _ensure_opencli() -> str:
|
||||
binary = shutil.which("opencli-rs")
|
||||
if not binary:
|
||||
raise DataVendorUnavailable("opencli-rs is not installed or not on PATH.")
|
||||
return binary
|
||||
|
||||
|
||||
def _run_opencli(args: list[str]) -> list[dict]:
|
||||
binary = _ensure_opencli()
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[binary, *args],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
except (OSError, subprocess.SubprocessError) as exc:
|
||||
raise DataVendorUnavailable(f"opencli-rs execution failed: {exc}") from exc
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = (result.stderr or result.stdout or "").strip()
|
||||
raise DataVendorUnavailable(f"opencli-rs command failed: {stderr}")
|
||||
|
||||
try:
|
||||
payload = json.loads(result.stdout)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise DataVendorUnavailable("opencli-rs returned non-JSON output.") from exc
|
||||
|
||||
if not isinstance(payload, list):
|
||||
raise DataVendorUnavailable("opencli-rs returned an unexpected payload format.")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _safe_run_opencli(args: list[str]) -> tuple[list[dict], str | None]:
|
||||
try:
|
||||
return _run_opencli(args), None
|
||||
except DataVendorUnavailable as exc:
|
||||
return [], str(exc)
|
||||
|
||||
|
||||
def _format_block(title: str, records: list[str]) -> str:
|
||||
if not records:
|
||||
return f"### {title}\nNo results."
|
||||
return f"### {title}\n" + "\n\n".join(records)
|
||||
|
||||
|
||||
def _dedupe_records(items: list[dict], keys: tuple[str, ...]) -> list[dict]:
|
||||
seen: set[str] = set()
|
||||
output: list[dict] = []
|
||||
for item in items:
|
||||
identity = " | ".join(str(item.get(key, "")).strip() for key in keys).strip()
|
||||
if not identity or identity in seen:
|
||||
continue
|
||||
seen.add(identity)
|
||||
output.append(item)
|
||||
return output
|
||||
|
||||
|
||||
def _clean_symbol(symbol: str) -> str:
|
||||
return symbol.strip().upper()
|
||||
|
||||
|
||||
def _symbol_without_suffix(symbol: str) -> str:
|
||||
clean = _clean_symbol(symbol)
|
||||
return clean.split(".", 1)[0]
|
||||
|
||||
|
||||
def _resolve_company_aliases(ticker: str) -> list[str]:
|
||||
aliases: list[str] = []
|
||||
|
||||
try:
|
||||
from .tushare import _classify_market, _get_pro_client, _normalize_ts_code
|
||||
|
||||
ts_code = _normalize_ts_code(ticker)
|
||||
market = _classify_market(ts_code)
|
||||
pro = _get_pro_client()
|
||||
|
||||
if market == "a_share":
|
||||
basic = pro.stock_basic(ts_code=ts_code, fields="ts_code,name")
|
||||
elif market == "hk":
|
||||
basic = pro.hk_basic(ts_code=ts_code)
|
||||
else:
|
||||
basic = pro.us_basic(ts_code=ts_code)
|
||||
|
||||
if basic is not None and not basic.empty:
|
||||
row = basic.iloc[0]
|
||||
for field in ("name", "fullname", "enname"):
|
||||
value = row.get(field)
|
||||
if value:
|
||||
aliases.append(str(value).strip())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
aliases.extend([_clean_symbol(ticker), _symbol_without_suffix(ticker)])
|
||||
|
||||
expanded_aliases: list[str] = []
|
||||
for alias in aliases:
|
||||
alias = alias.strip()
|
||||
if not alias:
|
||||
continue
|
||||
expanded_aliases.append(alias)
|
||||
if alias.endswith("股份有限公司"):
|
||||
short_alias = alias[: -len("股份有限公司")].strip()
|
||||
if short_alias:
|
||||
expanded_aliases.append(short_alias)
|
||||
if alias.endswith("有限公司"):
|
||||
short_alias = alias[: -len("有限公司")].strip()
|
||||
if short_alias:
|
||||
expanded_aliases.append(short_alias)
|
||||
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
for alias in expanded_aliases:
|
||||
if alias not in seen:
|
||||
seen.add(alias)
|
||||
result.append(alias)
|
||||
return result
|
||||
|
||||
|
||||
def _build_google_queries(ticker: str) -> list[str]:
|
||||
aliases = _resolve_company_aliases(ticker)
|
||||
queries: list[str] = []
|
||||
for alias in aliases:
|
||||
queries.append(f"{alias} stock")
|
||||
queries.append(alias)
|
||||
return queries
|
||||
|
||||
|
||||
def _collect_google_news(ticker: str, limit: int = 8) -> tuple[list[dict], list[str]]:
|
||||
items: list[dict] = []
|
||||
errors: list[str] = []
|
||||
|
||||
for query in _build_google_queries(ticker):
|
||||
payload, error = _safe_run_opencli(
|
||||
["google", "news", query, "--limit", str(limit), "--format", "json"]
|
||||
)
|
||||
if error:
|
||||
errors.append(f"{query}: {error}")
|
||||
continue
|
||||
items.extend(payload)
|
||||
if len(_dedupe_records(items, ("url", "title"))) >= limit:
|
||||
break
|
||||
|
||||
return _dedupe_records(items, ("url", "title"))[:limit], errors
|
||||
|
||||
|
||||
def _collect_google_search_results(ticker: str, limit: int = 8) -> tuple[list[dict], list[str]]:
|
||||
items: list[dict] = []
|
||||
errors: list[str] = []
|
||||
|
||||
for query in _build_google_queries(ticker):
|
||||
payload, error = _safe_run_opencli(
|
||||
["google", "search", query, "--lang", "zh", "--limit", str(limit), "--format", "json"]
|
||||
)
|
||||
if error:
|
||||
errors.append(f"{query}: {error}")
|
||||
continue
|
||||
items.extend(payload)
|
||||
if len(_dedupe_records(items, ("url", "title"))) >= limit:
|
||||
break
|
||||
|
||||
return _dedupe_records(items, ("url", "title"))[:limit], errors
|
||||
|
||||
|
||||
def _collect_xueqiu_results(ticker: str, limit: int = 8) -> tuple[list[dict], list[str]]:
|
||||
items: list[dict] = []
|
||||
errors: list[str] = []
|
||||
|
||||
for keyword in _resolve_company_aliases(ticker):
|
||||
payload, error = _safe_run_opencli(
|
||||
["xueqiu", "search", keyword, "--limit", str(limit), "--format", "json"]
|
||||
)
|
||||
if error:
|
||||
errors.append(f"{keyword}: {error}")
|
||||
continue
|
||||
items.extend(payload)
|
||||
if len(_dedupe_records(items, ("symbol", "name"))) >= limit:
|
||||
break
|
||||
|
||||
return _dedupe_records(items, ("symbol", "name"))[:limit], errors
|
||||
|
||||
|
||||
def _collect_weibo_results(ticker: str, limit: int = 8) -> tuple[list[dict], list[str]]:
|
||||
items: list[dict] = []
|
||||
errors: list[str] = []
|
||||
|
||||
for keyword in _resolve_company_aliases(ticker):
|
||||
payload, error = _safe_run_opencli(
|
||||
["weibo", "search", keyword, "--limit", str(limit), "--format", "json"]
|
||||
)
|
||||
if error:
|
||||
errors.append(f"{keyword}: {error}")
|
||||
continue
|
||||
items.extend(payload)
|
||||
if len(_dedupe_records(items, ("url", "text", "word"))) >= limit:
|
||||
break
|
||||
|
||||
return _dedupe_records(items, ("url", "text", "word"))[:limit], errors
|
||||
|
||||
|
||||
def _collect_xiaohongshu_results(ticker: str, limit: int = 8) -> tuple[list[dict], list[str]]:
|
||||
items: list[dict] = []
|
||||
errors: list[str] = []
|
||||
|
||||
for keyword in _resolve_company_aliases(ticker):
|
||||
payload, error = _safe_run_opencli(
|
||||
["xiaohongshu", "search", keyword, "--limit", str(limit), "--format", "json"]
|
||||
)
|
||||
if error:
|
||||
errors.append(f"{keyword}: {error}")
|
||||
continue
|
||||
items.extend(payload)
|
||||
if len(_dedupe_records(items, ("id", "note_id", "url", "title"))) >= limit:
|
||||
break
|
||||
|
||||
return _dedupe_records(items, ("id", "note_id", "url", "title"))[:limit], errors
|
||||
|
||||
|
||||
def _collect_sinafinance_results(ticker: str, limit: int = 8) -> tuple[list[dict], list[str]]:
|
||||
aliases = _resolve_company_aliases(ticker)
|
||||
payload, error = _safe_run_opencli(
|
||||
["sinafinance", "news", "--type", "1", "--limit", "50", "--format", "json"]
|
||||
)
|
||||
if error:
|
||||
return [], [error]
|
||||
|
||||
filtered: list[dict] = []
|
||||
for item in payload:
|
||||
haystack = " ".join(
|
||||
str(item.get(field, "")).strip()
|
||||
for field in ("content", "title", "symbol", "name")
|
||||
)
|
||||
if any(alias and alias in haystack for alias in aliases):
|
||||
filtered.append(item)
|
||||
|
||||
return _dedupe_records(filtered, ("time", "content", "title"))[:limit], []
|
||||
|
||||
|
||||
def get_news(ticker: str, start_date: str, end_date: str) -> str:
|
||||
_parse_date(start_date)
|
||||
_parse_date(end_date)
|
||||
|
||||
sections: list[str] = []
|
||||
errors: list[str] = []
|
||||
|
||||
xueqiu_items, xueqiu_errors = _collect_xueqiu_results(ticker, limit=6)
|
||||
errors.extend(xueqiu_errors)
|
||||
if xueqiu_items:
|
||||
sections.append(
|
||||
_format_block(
|
||||
"Xueqiu Search",
|
||||
[
|
||||
(
|
||||
f"- {item.get('name', item.get('symbol', 'Unknown'))} "
|
||||
f"(symbol: {item.get('symbol', 'Unknown')})"
|
||||
)
|
||||
for item in xueqiu_items
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
weibo_items, weibo_errors = _collect_weibo_results(ticker, limit=6)
|
||||
errors.extend(weibo_errors)
|
||||
if weibo_items:
|
||||
sections.append(
|
||||
_format_block(
|
||||
"Weibo Search",
|
||||
[
|
||||
(
|
||||
f"- {item.get('text', item.get('word', 'No text'))}\n"
|
||||
f" Link: {item.get('url', '')}"
|
||||
)
|
||||
for item in weibo_items
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
xiaohongshu_items, xiaohongshu_errors = _collect_xiaohongshu_results(ticker, limit=6)
|
||||
errors.extend(xiaohongshu_errors)
|
||||
if xiaohongshu_items:
|
||||
sections.append(
|
||||
_format_block(
|
||||
"Xiaohongshu Search",
|
||||
[
|
||||
(
|
||||
f"- {item.get('title', item.get('desc', 'No title'))}\n"
|
||||
f" Link: {item.get('url', '')}"
|
||||
)
|
||||
for item in xiaohongshu_items
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
sina_items, sina_errors = _collect_sinafinance_results(ticker, limit=6)
|
||||
errors.extend(sina_errors)
|
||||
if sina_items:
|
||||
sections.append(
|
||||
_format_block(
|
||||
"Sina Finance A-Share Flash",
|
||||
[
|
||||
(
|
||||
f"- {item.get('content', item.get('title', 'No content'))} "
|
||||
f"(time: {item.get('time', 'Unknown')}, views: {item.get('views', 'Unknown')})"
|
||||
)
|
||||
for item in sina_items
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
google_news_items, google_news_errors = _collect_google_news(ticker, limit=6)
|
||||
errors.extend(google_news_errors)
|
||||
if google_news_items:
|
||||
sections.append(
|
||||
_format_block(
|
||||
"Google News",
|
||||
[
|
||||
(
|
||||
f"- {item.get('title', 'No title')} "
|
||||
f"(source: {item.get('source', 'Unknown')}, date: {item.get('date', 'Unknown')})\n"
|
||||
f" Link: {item.get('url', '')}"
|
||||
)
|
||||
for item in google_news_items
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
google_search_items, google_search_errors = _collect_google_search_results(ticker, limit=6)
|
||||
errors.extend(google_search_errors)
|
||||
if google_search_items:
|
||||
sections.append(
|
||||
_format_block(
|
||||
"Google Search (ZH)",
|
||||
[
|
||||
(
|
||||
f"- {item.get('title', 'No title')}\n"
|
||||
f" Link: {item.get('url', '')}"
|
||||
)
|
||||
for item in google_search_items
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if not sections:
|
||||
aliases = ", ".join(_resolve_company_aliases(ticker))
|
||||
detail = (
|
||||
f"No relevant news found via opencli-rs for {ticker} "
|
||||
f"between {start_date} and {end_date}. "
|
||||
f"Queries tried: {aliases or ticker}."
|
||||
)
|
||||
if errors:
|
||||
detail += f" Source errors: {'; '.join(errors[:3])}."
|
||||
return detail
|
||||
|
||||
return f"## {ticker} News and Social Signals, from {start_date} to {end_date}:\n\n" + "\n\n".join(sections)
|
||||
|
||||
|
||||
def get_global_news(curr_date: str, look_back_days: int = 7, limit: int = 10) -> str:
|
||||
end_dt = _parse_date(curr_date)
|
||||
start_date = (end_dt - timedelta(days=look_back_days)).strftime("%Y-%m-%d")
|
||||
|
||||
sections = []
|
||||
|
||||
google_items = _run_opencli(["google", "news", "--limit", str(limit), "--format", "json"])
|
||||
sections.append(
|
||||
_format_block(
|
||||
"Google News Top Stories",
|
||||
[
|
||||
(
|
||||
f"- {item.get('title', 'No title')} "
|
||||
f"(source: {item.get('source', 'Unknown')}, date: {item.get('date', 'Unknown')})\n"
|
||||
f" Link: {item.get('url', '')}"
|
||||
)
|
||||
for item in google_items[:limit]
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
sina_items = _run_opencli(["sinafinance", "news", "--limit", str(limit), "--format", "json"])
|
||||
sections.append(
|
||||
_format_block(
|
||||
"Sina Finance Flash News",
|
||||
[
|
||||
(
|
||||
f"- {item.get('content', 'No content')} "
|
||||
f"(time: {item.get('time', 'Unknown')}, views: {item.get('views', 'Unknown')})"
|
||||
)
|
||||
for item in sina_items[:limit]
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
xueqiu_hot = _run_opencli(["xueqiu", "hot", "--limit", str(min(limit, 8)), "--format", "json"])
|
||||
sections.append(
|
||||
_format_block(
|
||||
"Xueqiu Hot Discussions",
|
||||
[
|
||||
(
|
||||
f"- {item.get('text', 'No text')} "
|
||||
f"(author: {item.get('author', 'Unknown')}, likes: {item.get('likes', 'Unknown')})\n"
|
||||
f" Link: {item.get('url', '')}"
|
||||
)
|
||||
for item in xueqiu_hot[:limit]
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
weibo_hot = _run_opencli(["weibo", "hot", "--limit", str(min(limit, 8)), "--format", "json"])
|
||||
sections.append(
|
||||
_format_block(
|
||||
"Weibo Hot Topics",
|
||||
[
|
||||
(
|
||||
f"- {item.get('word', 'No topic')} "
|
||||
f"(category: {item.get('category', 'Unknown')}, heat: {item.get('hot_value', 'Unknown')})\n"
|
||||
f" Link: {item.get('url', '')}"
|
||||
)
|
||||
for item in weibo_hot[:limit]
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return f"## Global Market News and Social Signals, from {start_date} to {curr_date}:\n\n" + "\n\n".join(sections)
|
||||
|
|
@ -0,0 +1,575 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache
|
||||
from typing import Callable
|
||||
|
||||
import pandas as pd
|
||||
from stockstats import wrap
|
||||
|
||||
from .exceptions import DataVendorUnavailable
|
||||
|
||||
|
||||
_SUPPORTED_EXCHANGES = {"SH", "SZ", "BJ", "HK"}
|
||||
_SUFFIX_MAP = {
|
||||
"SH": "SH",
|
||||
"SS": "SH",
|
||||
"SSE": "SH",
|
||||
"SZ": "SZ",
|
||||
"SZSE": "SZ",
|
||||
"BJ": "BJ",
|
||||
"BSE": "BJ",
|
||||
"HK": "HK",
|
||||
"HKG": "HK",
|
||||
"SEHK": "HK",
|
||||
}
|
||||
|
||||
_A_SHARE_EXCHANGES = {"SH", "SZ", "BJ"}
|
||||
|
||||
|
||||
def _parse_date(date_str: str) -> datetime:
|
||||
return datetime.strptime(date_str, "%Y-%m-%d")
|
||||
|
||||
|
||||
def _to_api_date(date_str: str) -> str:
|
||||
return _parse_date(date_str).strftime("%Y%m%d")
|
||||
|
||||
|
||||
def _classify_market(ts_code: str) -> str:
|
||||
if "." in ts_code:
|
||||
suffix = ts_code.rsplit(".", 1)[1]
|
||||
if suffix in _A_SHARE_EXCHANGES:
|
||||
return "a_share"
|
||||
if suffix == "HK":
|
||||
return "hk"
|
||||
return "us"
|
||||
|
||||
|
||||
def _normalize_ts_code(symbol: str) -> str:
|
||||
raw = symbol.strip().upper()
|
||||
|
||||
if "." in raw:
|
||||
code, suffix = raw.split(".", 1)
|
||||
suffix = _SUFFIX_MAP.get(suffix, suffix)
|
||||
if suffix in _A_SHARE_EXCHANGES and code.isdigit():
|
||||
return f"{code.zfill(6)}.{suffix}"
|
||||
if suffix == "HK" and code.isdigit():
|
||||
return f"{code.zfill(5)}.HK"
|
||||
raise DataVendorUnavailable(
|
||||
f"Tushare currently supports A-share, Hong Kong, and US tickers only, got '{symbol}'."
|
||||
)
|
||||
|
||||
if raw.isdigit() and len(raw) <= 6:
|
||||
code = raw.zfill(6)
|
||||
if code.startswith(("6", "9", "5")):
|
||||
return f"{code}.SH"
|
||||
if code.startswith(("0", "2", "3")):
|
||||
return f"{code}.SZ"
|
||||
if code.startswith(("4", "8")):
|
||||
return f"{code}.BJ"
|
||||
return f"{raw.zfill(5)}.HK"
|
||||
|
||||
if raw.replace("-", "").isalnum():
|
||||
return raw
|
||||
|
||||
raise DataVendorUnavailable(
|
||||
f"Cannot map ticker '{symbol}' to a supported Tushare market automatically."
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_pro_client():
|
||||
token = (
|
||||
os.getenv("TUSHARE_TOKEN")
|
||||
or os.getenv("TUSHARE_API_TOKEN")
|
||||
or os.getenv("TS_TOKEN")
|
||||
)
|
||||
if not token:
|
||||
raise DataVendorUnavailable(
|
||||
"TUSHARE_TOKEN is not set. Configure token or use fallback vendor."
|
||||
)
|
||||
|
||||
try:
|
||||
import tushare as ts
|
||||
except ImportError as exc:
|
||||
raise DataVendorUnavailable(
|
||||
"tushare package is not installed. Install it to enable tushare vendor."
|
||||
) from exc
|
||||
|
||||
try:
|
||||
ts.set_token(token)
|
||||
return ts.pro_api(token)
|
||||
except Exception as exc:
|
||||
raise DataVendorUnavailable(f"Failed to initialize tushare client: {exc}") from exc
|
||||
|
||||
|
||||
def _to_csv_with_header(df: pd.DataFrame, title: str) -> str:
|
||||
if df is None or df.empty:
|
||||
return f"No {title.lower()} data found."
|
||||
|
||||
header = f"# {title}\n"
|
||||
header += f"# Total records: {len(df)}\n"
|
||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
return header + df.to_csv(index=False)
|
||||
|
||||
|
||||
def _filter_statement(df: pd.DataFrame, freq: str, curr_date: str | None) -> pd.DataFrame:
|
||||
if df is None or df.empty:
|
||||
return df
|
||||
|
||||
output = df.copy()
|
||||
|
||||
if curr_date and "end_date" in output.columns:
|
||||
cutoff = _to_api_date(curr_date)
|
||||
output = output[output["end_date"].astype(str) <= cutoff]
|
||||
|
||||
if freq.lower() == "annual" and "end_date" in output.columns:
|
||||
output = output[output["end_date"].astype(str).str.endswith("1231")]
|
||||
|
||||
sort_col = "end_date" if "end_date" in output.columns else output.columns[0]
|
||||
output = output.sort_values(sort_col, ascending=False).head(8)
|
||||
return output
|
||||
|
||||
|
||||
def _fetch_price_data(pro, ts_code: str, start_api: str, end_api: str) -> pd.DataFrame:
|
||||
market = _classify_market(ts_code)
|
||||
if market == "a_share":
|
||||
return pro.daily(ts_code=ts_code, start_date=start_api, end_date=end_api)
|
||||
if market == "hk":
|
||||
return pro.hk_daily(ts_code=ts_code, start_date=start_api, end_date=end_api)
|
||||
return pro.us_daily(ts_code=ts_code, start_date=start_api, end_date=end_api)
|
||||
|
||||
|
||||
def get_stock(symbol: str, start_date: str, end_date: str) -> str:
|
||||
pro = _get_pro_client()
|
||||
ts_code = _normalize_ts_code(symbol)
|
||||
|
||||
start_api = _to_api_date(start_date)
|
||||
end_api = _to_api_date(end_date)
|
||||
|
||||
data = _fetch_price_data(pro, ts_code, start_api, end_api)
|
||||
if data is None or data.empty:
|
||||
return f"No stock data found for '{ts_code}' between {start_date} and {end_date}."
|
||||
|
||||
rename_map = {
|
||||
"trade_date": "Date",
|
||||
"open": "Open",
|
||||
"high": "High",
|
||||
"low": "Low",
|
||||
"close": "Close",
|
||||
"vol": "Volume",
|
||||
"amount": "Amount",
|
||||
"pct_chg": "PctChg",
|
||||
"pre_close": "PrevClose",
|
||||
"change": "Change",
|
||||
}
|
||||
|
||||
output = data.rename(columns=rename_map)
|
||||
if "Date" in output.columns:
|
||||
output["Date"] = pd.to_datetime(output["Date"], format="%Y%m%d").dt.strftime(
|
||||
"%Y-%m-%d"
|
||||
)
|
||||
output = output.sort_values("Date", ascending=True)
|
||||
|
||||
preferred_cols = [
|
||||
"Date",
|
||||
"Open",
|
||||
"High",
|
||||
"Low",
|
||||
"Close",
|
||||
"PrevClose",
|
||||
"Change",
|
||||
"PctChg",
|
||||
"Volume",
|
||||
"Amount",
|
||||
]
|
||||
existing_cols = [c for c in preferred_cols if c in output.columns]
|
||||
output = output[existing_cols]
|
||||
|
||||
return _to_csv_with_header(
|
||||
output,
|
||||
f"Tushare stock data for {ts_code} from {start_date} to {end_date}",
|
||||
)
|
||||
|
||||
|
||||
def _load_price_frame(symbol: str, curr_date: str, look_back_days: int = 260) -> pd.DataFrame:
|
||||
pro = _get_pro_client()
|
||||
ts_code = _normalize_ts_code(symbol)
|
||||
end_dt = _parse_date(curr_date)
|
||||
start_dt = end_dt - timedelta(days=look_back_days)
|
||||
data = _fetch_price_data(
|
||||
pro,
|
||||
ts_code,
|
||||
start_dt.strftime("%Y%m%d"),
|
||||
end_dt.strftime("%Y%m%d"),
|
||||
)
|
||||
if data is None or data.empty:
|
||||
raise DataVendorUnavailable(
|
||||
f"No tushare price data found for '{ts_code}' before {curr_date}."
|
||||
)
|
||||
|
||||
df = data.rename(
|
||||
columns={
|
||||
"trade_date": "Date",
|
||||
"open": "Open",
|
||||
"high": "High",
|
||||
"low": "Low",
|
||||
"close": "Close",
|
||||
"vol": "Volume",
|
||||
}
|
||||
).copy()
|
||||
df["Date"] = pd.to_datetime(df["Date"], format="%Y%m%d")
|
||||
df = df.sort_values("Date", ascending=True)
|
||||
return df[["Date", "Open", "High", "Low", "Close", "Volume"]]
|
||||
|
||||
|
||||
def get_indicator(
|
||||
symbol: str,
|
||||
indicator: str,
|
||||
curr_date: str,
|
||||
look_back_days: int,
|
||||
) -> str:
|
||||
descriptions = {
|
||||
"close_50_sma": "50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.",
|
||||
"close_200_sma": "200 SMA: A long-term trend benchmark. Usage: Confirm overall market trend and identify golden/death cross setups. Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.",
|
||||
"close_10_ema": "10 EMA: A responsive short-term average. Usage: Capture quick shifts in momentum and potential entry points. Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.",
|
||||
"macd": "MACD: Computes momentum via differences of EMAs. Usage: Look for crossovers and divergence as signals of trend changes. Tips: Confirm with other indicators in low-volatility or sideways markets.",
|
||||
"macds": "MACD Signal: An EMA smoothing of the MACD line. Usage: Use crossovers with the MACD line to trigger trades. Tips: Should be part of a broader strategy to avoid false positives.",
|
||||
"macdh": "MACD Histogram: Shows the gap between the MACD line and its signal. Usage: Visualize momentum strength and spot divergence early. Tips: Can be volatile; complement with additional filters in fast-moving markets.",
|
||||
"rsi": "RSI: Measures momentum to flag overbought/oversold conditions. Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.",
|
||||
"boll": "Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. Usage: Acts as a dynamic benchmark for price movement. Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.",
|
||||
"boll_ub": "Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends.",
|
||||
"boll_lb": "Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals.",
|
||||
"atr": "ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy.",
|
||||
"vwma": "VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.",
|
||||
"mfi": "MFI: Uses both price and volume to measure buying and selling pressure. Usage: Identify overbought (>80) or oversold (<20) conditions and confirm trends or reversals.",
|
||||
}
|
||||
if indicator not in descriptions:
|
||||
raise ValueError(
|
||||
f"Indicator {indicator} is not supported. Please choose from: {list(descriptions.keys())}"
|
||||
)
|
||||
|
||||
current_dt = _parse_date(curr_date)
|
||||
start_dt = current_dt - timedelta(days=look_back_days)
|
||||
stats_df = wrap(_load_price_frame(symbol, curr_date))
|
||||
stats_df["Date"] = stats_df["Date"].dt.strftime("%Y-%m-%d")
|
||||
stats_df[indicator]
|
||||
|
||||
lines = []
|
||||
probe_dt = current_dt
|
||||
while probe_dt >= start_dt:
|
||||
date_str = probe_dt.strftime("%Y-%m-%d")
|
||||
row = stats_df[stats_df["Date"] == date_str]
|
||||
if row.empty:
|
||||
lines.append(f"{date_str}: N/A: Not a trading day (weekend or holiday)")
|
||||
else:
|
||||
value = row.iloc[0][indicator]
|
||||
if pd.isna(value):
|
||||
lines.append(f"{date_str}: N/A")
|
||||
else:
|
||||
lines.append(f"{date_str}: {value}")
|
||||
probe_dt -= timedelta(days=1)
|
||||
|
||||
return (
|
||||
f"## {indicator} values from {start_dt.strftime('%Y-%m-%d')} to {curr_date}:\n\n"
|
||||
+ "\n".join(lines)
|
||||
+ "\n\n"
|
||||
+ descriptions[indicator]
|
||||
)
|
||||
|
||||
|
||||
def get_fundamentals(ticker: str, curr_date: str | None = None) -> str:
|
||||
pro = _get_pro_client()
|
||||
ts_code = _normalize_ts_code(ticker)
|
||||
market = _classify_market(ts_code)
|
||||
|
||||
if curr_date:
|
||||
curr_dt = _parse_date(curr_date)
|
||||
else:
|
||||
curr_dt = datetime.now()
|
||||
curr_date = curr_dt.strftime("%Y-%m-%d")
|
||||
|
||||
end_api = curr_dt.strftime("%Y%m%d")
|
||||
start_api_40d = (curr_dt - timedelta(days=40)).strftime("%Y%m%d")
|
||||
start_api_400d = (curr_dt - timedelta(days=400)).strftime("%Y%m%d")
|
||||
|
||||
if market == "a_share":
|
||||
basic = pro.stock_basic(
|
||||
ts_code=ts_code,
|
||||
fields="ts_code,symbol,name,area,industry,market,list_date,list_status",
|
||||
)
|
||||
latest_price = pro.daily_basic(
|
||||
ts_code=ts_code,
|
||||
start_date=start_api_40d,
|
||||
end_date=end_api,
|
||||
)
|
||||
fina_indicator = pro.fina_indicator(
|
||||
ts_code=ts_code,
|
||||
start_date=start_api_400d,
|
||||
end_date=end_api,
|
||||
)
|
||||
elif market == "hk":
|
||||
basic = pro.hk_basic(ts_code=ts_code)
|
||||
latest_price = pro.hk_daily(ts_code=ts_code, start_date=start_api_40d, end_date=end_api)
|
||||
fina_indicator = None
|
||||
else:
|
||||
basic = pro.us_basic(ts_code=ts_code)
|
||||
latest_price = pro.us_daily(ts_code=ts_code, start_date=start_api_40d, end_date=end_api)
|
||||
fina_indicator = None
|
||||
|
||||
lines = [
|
||||
f"Ticker: {ts_code}",
|
||||
f"Market: {market}",
|
||||
f"Reference date: {curr_date}",
|
||||
]
|
||||
|
||||
if basic is not None and not basic.empty:
|
||||
row = basic.iloc[0]
|
||||
if market == "a_share":
|
||||
field_map = {
|
||||
"name": "Name",
|
||||
"area": "Area",
|
||||
"industry": "Industry",
|
||||
"market": "Market",
|
||||
"list_date": "List Date",
|
||||
"list_status": "List Status",
|
||||
}
|
||||
elif market == "hk":
|
||||
field_map = {
|
||||
"name": "Name",
|
||||
"fullname": "Full Name",
|
||||
"enname": "English Name",
|
||||
"market": "Market",
|
||||
"curr_type": "Currency",
|
||||
"list_date": "List Date",
|
||||
"list_status": "List Status",
|
||||
}
|
||||
else:
|
||||
field_map = {
|
||||
"name": "Name",
|
||||
"enname": "English Name",
|
||||
"classify": "Classify",
|
||||
"list_date": "List Date",
|
||||
"delist_date": "Delist Date",
|
||||
}
|
||||
for field, label in field_map.items():
|
||||
value = row.get(field)
|
||||
if pd.notna(value):
|
||||
lines.append(f"{label}: {value}")
|
||||
|
||||
if latest_price is not None and not latest_price.empty:
|
||||
row = latest_price.sort_values("trade_date", ascending=False).iloc[0]
|
||||
if market == "a_share":
|
||||
field_map = {
|
||||
"trade_date": "Latest Trade Date",
|
||||
"close": "Close",
|
||||
"turnover_rate": "Turnover Rate",
|
||||
"pe": "PE",
|
||||
"pb": "PB",
|
||||
"ps": "PS",
|
||||
"dv_ratio": "Dividend Yield Ratio",
|
||||
"total_mv": "Total Market Value",
|
||||
"circ_mv": "Circulating Market Value",
|
||||
}
|
||||
else:
|
||||
field_map = {
|
||||
"trade_date": "Latest Trade Date",
|
||||
"close": "Close",
|
||||
"open": "Open",
|
||||
"high": "High",
|
||||
"low": "Low",
|
||||
"pre_close": "Prev Close",
|
||||
"change": "Change",
|
||||
"pct_chg": "Pct Change",
|
||||
"vol": "Volume",
|
||||
"amount": "Amount",
|
||||
}
|
||||
for field, label in field_map.items():
|
||||
value = row.get(field)
|
||||
if pd.notna(value):
|
||||
lines.append(f"{label}: {value}")
|
||||
|
||||
if fina_indicator is not None and not fina_indicator.empty:
|
||||
row = fina_indicator.sort_values("end_date", ascending=False).iloc[0]
|
||||
field_map = {
|
||||
"end_date": "Latest Financial Period",
|
||||
"roe": "ROE",
|
||||
"roa": "ROA",
|
||||
"grossprofit_margin": "Gross Margin",
|
||||
"netprofit_margin": "Net Margin",
|
||||
"debt_to_assets": "Debt to Assets",
|
||||
"ocf_to_or": "OCF to Revenue",
|
||||
}
|
||||
for field, label in field_map.items():
|
||||
value = row.get(field)
|
||||
if pd.notna(value):
|
||||
lines.append(f"{label}: {value}")
|
||||
elif market == "hk":
|
||||
income = pro.hk_income(ts_code=ts_code, end_date=end_api)
|
||||
if income is not None and not income.empty:
|
||||
latest_end = income["end_date"].astype(str).max()
|
||||
lines.append(f"Latest Financial Period: {latest_end}")
|
||||
sample = income[income["end_date"].astype(str) == latest_end].head(12)
|
||||
for _, rec in sample.iterrows():
|
||||
lines.append(f"{rec.get('ind_name')}: {rec.get('ind_value')}")
|
||||
else:
|
||||
income = pro.us_income(ts_code=ts_code, end_date=end_api)
|
||||
if income is not None and not income.empty:
|
||||
latest_end = income["end_date"].astype(str).max()
|
||||
lines.append(f"Latest Financial Period: {latest_end}")
|
||||
sample = income[income["end_date"].astype(str) == latest_end].head(12)
|
||||
for _, rec in sample.iterrows():
|
||||
lines.append(f"{rec.get('ind_name')}: {rec.get('ind_value')}")
|
||||
|
||||
header = f"# Tushare fundamentals for {ts_code}\n"
|
||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
return header + "\n".join(lines)
|
||||
|
||||
|
||||
def _statement_common(
|
||||
ticker: str,
|
||||
freq: str,
|
||||
curr_date: str | None,
|
||||
fetcher: Callable,
|
||||
title: str,
|
||||
) -> str:
|
||||
pro = _get_pro_client()
|
||||
ts_code = _normalize_ts_code(ticker)
|
||||
market = _classify_market(ts_code)
|
||||
data = fetcher(pro, ts_code, market)
|
||||
filtered = _filter_statement(data, freq, curr_date)
|
||||
return _to_csv_with_header(filtered, f"Tushare {title} for {ts_code} ({freq})")
|
||||
|
||||
|
||||
def get_balance_sheet(
|
||||
ticker: str,
|
||||
freq: str = "quarterly",
|
||||
curr_date: str | None = None,
|
||||
) -> str:
|
||||
return _statement_common(
|
||||
ticker,
|
||||
freq,
|
||||
curr_date,
|
||||
lambda pro, ts_code, market: (
|
||||
pro.balancesheet(ts_code=ts_code)
|
||||
if market == "a_share"
|
||||
else pro.hk_balancesheet(ts_code=ts_code)
|
||||
if market == "hk"
|
||||
else pro.us_balancesheet(ts_code=ts_code)
|
||||
),
|
||||
"balance sheet",
|
||||
)
|
||||
|
||||
|
||||
def get_cashflow(
|
||||
ticker: str,
|
||||
freq: str = "quarterly",
|
||||
curr_date: str | None = None,
|
||||
) -> str:
|
||||
return _statement_common(
|
||||
ticker,
|
||||
freq,
|
||||
curr_date,
|
||||
lambda pro, ts_code, market: (
|
||||
pro.cashflow(ts_code=ts_code)
|
||||
if market == "a_share"
|
||||
else pro.hk_cashflow(ts_code=ts_code)
|
||||
if market == "hk"
|
||||
else pro.us_cashflow(ts_code=ts_code)
|
||||
),
|
||||
"cashflow",
|
||||
)
|
||||
|
||||
|
||||
def get_income_statement(
|
||||
ticker: str,
|
||||
freq: str = "quarterly",
|
||||
curr_date: str | None = None,
|
||||
) -> str:
|
||||
return _statement_common(
|
||||
ticker,
|
||||
freq,
|
||||
curr_date,
|
||||
lambda pro, ts_code, market: (
|
||||
pro.income(ts_code=ts_code)
|
||||
if market == "a_share"
|
||||
else pro.hk_income(ts_code=ts_code)
|
||||
if market == "hk"
|
||||
else pro.us_income(ts_code=ts_code)
|
||||
),
|
||||
"income statement",
|
||||
)
|
||||
|
||||
|
||||
def get_insider_transactions(ticker: str) -> str:
|
||||
pro = _get_pro_client()
|
||||
ts_code = _normalize_ts_code(ticker)
|
||||
market = _classify_market(ts_code)
|
||||
|
||||
if market != "a_share":
|
||||
raise DataVendorUnavailable(
|
||||
f"Tushare insider transactions currently support A-share tickers only, got '{ts_code}'."
|
||||
)
|
||||
|
||||
end_dt = datetime.now()
|
||||
start_dt = end_dt - timedelta(days=365)
|
||||
|
||||
try:
|
||||
data = pro.stk_holdertrade(
|
||||
ts_code=ts_code,
|
||||
start_date=start_dt.strftime("%Y%m%d"),
|
||||
end_date=end_dt.strftime("%Y%m%d"),
|
||||
)
|
||||
except Exception as exc:
|
||||
raise DataVendorUnavailable(
|
||||
f"Failed to retrieve tushare insider transactions for '{ts_code}': {exc}"
|
||||
) from exc
|
||||
|
||||
if data is None or data.empty:
|
||||
return f"No tushare insider transactions found for '{ts_code}'."
|
||||
|
||||
output = data.rename(
|
||||
columns={
|
||||
"ann_date": "AnnouncementDate",
|
||||
"holder_name": "HolderName",
|
||||
"holder_type": "HolderType",
|
||||
"in_de": "Direction",
|
||||
"change_vol": "ChangeVolume",
|
||||
"change_ratio": "ChangeRatio",
|
||||
"after_share": "AfterShareholding",
|
||||
"after_ratio": "AfterRatio",
|
||||
"avg_price": "AveragePrice",
|
||||
"total_share": "TotalShareholding",
|
||||
"begin_date": "StartDate",
|
||||
"close_date": "EndDate",
|
||||
}
|
||||
).copy()
|
||||
|
||||
for col in ("AnnouncementDate", "StartDate", "EndDate"):
|
||||
if col in output.columns:
|
||||
output[col] = pd.to_datetime(
|
||||
output[col], format="%Y%m%d", errors="coerce"
|
||||
).dt.strftime("%Y-%m-%d")
|
||||
|
||||
preferred_cols = [
|
||||
"AnnouncementDate",
|
||||
"HolderName",
|
||||
"HolderType",
|
||||
"Direction",
|
||||
"ChangeVolume",
|
||||
"ChangeRatio",
|
||||
"AfterShareholding",
|
||||
"AfterRatio",
|
||||
"AveragePrice",
|
||||
"TotalShareholding",
|
||||
"StartDate",
|
||||
"EndDate",
|
||||
]
|
||||
existing_cols = [col for col in preferred_cols if col in output.columns]
|
||||
if existing_cols:
|
||||
output = output[existing_cols]
|
||||
|
||||
sort_col = "AnnouncementDate" if "AnnouncementDate" in output.columns else output.columns[0]
|
||||
output = output.sort_values(sort_col, ascending=False)
|
||||
return _to_csv_with_header(output, f"Tushare insider transactions for {ts_code}")
|
||||
|
|
@ -2,6 +2,7 @@ from typing import Annotated
|
|||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
import yfinance as yf
|
||||
import pandas as pd
|
||||
import os
|
||||
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date
|
||||
|
||||
|
|
@ -418,4 +419,4 @@ def get_insider_transactions(
|
|||
return header + csv_string
|
||||
|
||||
except Exception as e:
|
||||
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
|
||||
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
|
||||
|
|
|
|||
|
|
@ -23,16 +23,27 @@ DEFAULT_CONFIG = {
|
|||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
"max_recur_limit": 100,
|
||||
"report_context_char_limit": 16000,
|
||||
"debate_history_char_limit": 12000,
|
||||
"memory_min_similarity": 0.15,
|
||||
# Data vendor configuration
|
||||
# Category-level configuration (default for all tools in category)
|
||||
"data_vendors": {
|
||||
"core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance
|
||||
"technical_indicators": "yfinance", # Options: alpha_vantage, yfinance
|
||||
"fundamental_data": "yfinance", # Options: alpha_vantage, yfinance
|
||||
"news_data": "yfinance", # Options: alpha_vantage, yfinance
|
||||
"core_stock_apis": "tushare,yfinance", # Options: tushare, yfinance
|
||||
"technical_indicators": "tushare,yfinance", # Options: tushare, yfinance
|
||||
"fundamental_data": "tushare,yfinance", # Options: tushare, yfinance
|
||||
"news_data": "opencli,brave,yfinance", # Options: opencli, brave, yfinance
|
||||
},
|
||||
# Tool-level configuration (takes precedence over category-level)
|
||||
"tool_vendors": {
|
||||
# Example: "get_stock_data": "alpha_vantage", # Override category default
|
||||
"get_stock_data": "tushare",
|
||||
"get_indicators": "tushare",
|
||||
"get_fundamentals": "tushare",
|
||||
"get_balance_sheet": "tushare",
|
||||
"get_cashflow": "tushare",
|
||||
"get_income_statement": "tushare",
|
||||
"get_news": "opencli",
|
||||
"get_global_news": "opencli",
|
||||
"get_insider_transactions": "tushare,yfinance",
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -50,7 +50,8 @@ class ConditionalLogic:
|
|||
state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds
|
||||
): # 3 rounds of back-and-forth between 2 agents
|
||||
return "Research Manager"
|
||||
if state["investment_debate_state"]["current_response"].startswith("Bull"):
|
||||
latest_speaker = state["investment_debate_state"].get("latest_speaker", "")
|
||||
if latest_speaker.startswith("Bull"):
|
||||
return "Bear Researcher"
|
||||
return "Bull Researcher"
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,10 @@ class Propagator:
|
|||
"bear_history": "",
|
||||
"history": "",
|
||||
"current_response": "",
|
||||
"bull_snapshot": "",
|
||||
"bear_snapshot": "",
|
||||
"debate_brief": "",
|
||||
"latest_speaker": "",
|
||||
"judge_decision": "",
|
||||
"count": 0,
|
||||
}
|
||||
|
|
@ -39,10 +43,14 @@ class Propagator:
|
|||
"conservative_history": "",
|
||||
"neutral_history": "",
|
||||
"history": "",
|
||||
"debate_brief": "",
|
||||
"latest_speaker": "",
|
||||
"current_aggressive_response": "",
|
||||
"current_conservative_response": "",
|
||||
"current_neutral_response": "",
|
||||
"aggressive_snapshot": "",
|
||||
"conservative_snapshot": "",
|
||||
"neutral_snapshot": "",
|
||||
"judge_decision": "",
|
||||
"count": 0,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,10 +20,16 @@ class SignalProcessor:
|
|||
Returns:
|
||||
Extracted rating (BUY, OVERWEIGHT, HOLD, UNDERWEIGHT, or SELL)
|
||||
"""
|
||||
normalized = self._normalize_known_rating(full_signal)
|
||||
if normalized:
|
||||
return normalized
|
||||
|
||||
messages = [
|
||||
(
|
||||
"system",
|
||||
"You are an efficient assistant that extracts the trading decision from analyst reports. "
|
||||
"The report may express the final recommendation in English or Chinese. "
|
||||
"Map Chinese ratings as follows: 买入=BUY, 增持=OVERWEIGHT, 持有=HOLD, 减持=UNDERWEIGHT, 卖出=SELL. "
|
||||
"Extract the rating as exactly one of: BUY, OVERWEIGHT, HOLD, UNDERWEIGHT, SELL. "
|
||||
"Output only the single rating word, nothing else.",
|
||||
),
|
||||
|
|
@ -31,3 +37,30 @@ class SignalProcessor:
|
|||
]
|
||||
|
||||
return self.quick_thinking_llm.invoke(messages).content
|
||||
|
||||
@staticmethod
|
||||
def _normalize_known_rating(full_signal: str) -> str | None:
|
||||
text = (full_signal or "").upper()
|
||||
english_markers = {
|
||||
"FINAL TRANSACTION PROPOSAL: **BUY**": "BUY",
|
||||
"FINAL TRANSACTION PROPOSAL: **OVERWEIGHT**": "OVERWEIGHT",
|
||||
"FINAL TRANSACTION PROPOSAL: **HOLD**": "HOLD",
|
||||
"FINAL TRANSACTION PROPOSAL: **UNDERWEIGHT**": "UNDERWEIGHT",
|
||||
"FINAL TRANSACTION PROPOSAL: **SELL**": "SELL",
|
||||
}
|
||||
for marker, rating in english_markers.items():
|
||||
if marker in text:
|
||||
return rating
|
||||
|
||||
chinese_markers = {
|
||||
"最终交易建议: **买入**": "BUY",
|
||||
"最终交易建议: **增持**": "OVERWEIGHT",
|
||||
"最终交易建议: **持有**": "HOLD",
|
||||
"最终交易建议: **减持**": "UNDERWEIGHT",
|
||||
"最终交易建议: **卖出**": "SELL",
|
||||
}
|
||||
for marker, rating in chinese_markers.items():
|
||||
if marker in full_signal:
|
||||
return rating
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -121,7 +121,9 @@ class TradingAgentsGraph:
|
|||
self.conditional_logic,
|
||||
)
|
||||
|
||||
self.propagator = Propagator()
|
||||
self.propagator = Propagator(
|
||||
max_recur_limit=self.config.get("max_recur_limit", 100)
|
||||
)
|
||||
self.reflector = Reflector(self.quick_thinking_llm)
|
||||
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
|
||||
|
||||
|
|
|
|||
|
|
@ -75,14 +75,16 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
|||
},
|
||||
"ollama": {
|
||||
"quick": [
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
("Qwen3.5-27B (llama.cpp local)", "Qwen3.5-27B"),
|
||||
("Qwen3.5-35B-3A (llama.cpp local)", "Qwen3.5-35B-3A"),
|
||||
("Qwen3.5-35B-A3B (llama.cpp local)", "Qwen3.5-35B-A3B"),
|
||||
("Qwen3.5-122B (llama.cpp local)", "Qwen3.5-122B"),
|
||||
],
|
||||
"deep": [
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
("Qwen3.5-122B (llama.cpp local)", "Qwen3.5-122B"),
|
||||
("Qwen3.5-35B-A3B (llama.cpp local)", "Qwen3.5-35B-A3B"),
|
||||
("Qwen3.5-35B-3A (llama.cpp local)", "Qwen3.5-35B-3A"),
|
||||
("Qwen3.5-27B (llama.cpp local)", "Qwen3.5-27B"),
|
||||
],
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,8 +58,8 @@ class OpenAIClient(BaseLLMClient):
|
|||
|
||||
# Provider-specific base URL and auth
|
||||
if self.provider in _PROVIDER_CONFIG:
|
||||
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
||||
llm_kwargs["base_url"] = base_url
|
||||
default_base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
||||
llm_kwargs["base_url"] = self.base_url or default_base_url
|
||||
if api_key_env:
|
||||
api_key = os.environ.get(api_key_env)
|
||||
if api_key:
|
||||
|
|
|
|||
Loading…
Reference in New Issue