feat: optimize `get_lang` and `get_prompts` functions in localization module
This commit is contained in:
parent
53884b11a8
commit
219d7e4824
134
cli/main.py
134
cli/main.py
|
|
@ -32,7 +32,7 @@ lang = get_lang()
|
|||
|
||||
app = typer.Typer(
|
||||
name="TradingAgents",
|
||||
help=lang["welcome"] + ": " + lang["framework_subtitle"],
|
||||
help=get_lang("welcome") + ": " + get_lang("framework_subtitle"),
|
||||
add_completion=True, # Enable shell completion
|
||||
)
|
||||
|
||||
|
|
@ -46,22 +46,22 @@ class MessageBuffer:
|
|||
self.final_report = None # Store the complete final report
|
||||
self.agent_status = {
|
||||
# Analyst Team
|
||||
lang["market_analyst"]: lang["pending"],
|
||||
lang["social_analyst"]: lang["pending"],
|
||||
lang["news_analyst"]: lang["pending"],
|
||||
lang["fundamentals_analyst"]: lang["pending"],
|
||||
get_lang("market_analyst"): get_lang("pending"),
|
||||
get_lang("social_analyst"): get_lang("pending"),
|
||||
get_lang("news_analyst"): get_lang("pending"),
|
||||
get_lang("fundamentals_analyst"): get_lang("pending"),
|
||||
# Research Team
|
||||
lang["bull_researcher"]: lang["pending"],
|
||||
lang["bear_researcher"]: lang["pending"],
|
||||
lang["research_manager"]: lang["pending"],
|
||||
get_lang("bull_researcher"): get_lang("pending"),
|
||||
get_lang("bear_researcher"): get_lang("pending"),
|
||||
get_lang("research_manager"): get_lang("pending"),
|
||||
# Trading Team
|
||||
lang["trader"]: lang["pending"],
|
||||
get_lang("trader"): get_lang("pending"),
|
||||
# Risk Management Team
|
||||
lang["risky_analyst"]: lang["pending"],
|
||||
lang["neutral_analyst"]: lang["pending"],
|
||||
lang["safe_analyst"]: lang["pending"],
|
||||
get_lang("risky_analyst"): get_lang("pending"),
|
||||
get_lang("neutral_analyst"): get_lang("pending"),
|
||||
get_lang("safe_analyst"): get_lang("pending"),
|
||||
# Portfolio Management Team
|
||||
lang["portfolio_manager"]: lang["pending"],
|
||||
get_lang("portfolio_manager"): get_lang("pending"),
|
||||
}
|
||||
self.current_agent = None
|
||||
self.report_sections = {
|
||||
|
|
@ -106,13 +106,13 @@ class MessageBuffer:
|
|||
if latest_section and latest_content:
|
||||
# Format the current section for display
|
||||
section_titles = {
|
||||
"market_report": lang["market_report"],
|
||||
"sentiment_report": lang["sentiment_report"],
|
||||
"news_report": lang["news_report"],
|
||||
"fundamentals_report": lang["fundamentals_report"],
|
||||
"investment_plan": lang["investment_plan"],
|
||||
"trader_investment_plan": lang["trader_investment_plan"],
|
||||
"final_trade_decision": lang["final_trade_decision"],
|
||||
"market_report": get_lang("market_report"),
|
||||
"sentiment_report": get_lang("sentiment_report"),
|
||||
"news_report": get_lang("news_report"),
|
||||
"fundamentals_report": get_lang("fundamentals_report"),
|
||||
"investment_plan": get_lang("investment_plan"),
|
||||
"trader_investment_plan": get_lang("trader_investment_plan"),
|
||||
"final_trade_decision": get_lang("final_trade_decision"),
|
||||
}
|
||||
self.current_report = (
|
||||
f"### {section_titles[latest_section]}\n{latest_content}"
|
||||
|
|
@ -134,37 +134,37 @@ class MessageBuffer:
|
|||
"fundamentals_report",
|
||||
]
|
||||
):
|
||||
report_parts.append(f"## {lang['analyst_team']}")
|
||||
report_parts.append(f"## {get_lang("analyst_team")}")
|
||||
if self.report_sections["market_report"]:
|
||||
report_parts.append(
|
||||
f"### {lang['market_report']}\n{self.report_sections['market_report']}"
|
||||
f"### {get_lang("market_report")}\n{self.report_sections['market_report']}"
|
||||
)
|
||||
if self.report_sections["sentiment_report"]:
|
||||
report_parts.append(
|
||||
f"### {lang['sentiment_report']}\n{self.report_sections['sentiment_report']}"
|
||||
f"### {get_lang("sentiment_report")}\n{self.report_sections['sentiment_report']}"
|
||||
)
|
||||
if self.report_sections["news_report"]:
|
||||
report_parts.append(
|
||||
f"### {lang['news_report']}\n{self.report_sections['news_report']}"
|
||||
f"### {get_lang("news_report")}\n{self.report_sections['news_report']}"
|
||||
)
|
||||
if self.report_sections["fundamentals_report"]:
|
||||
report_parts.append(
|
||||
f"### {lang['fundamentals_report']}\n{self.report_sections['fundamentals_report']}"
|
||||
f"### {get_lang("fundamentals_report")}\n{self.report_sections['fundamentals_report']}"
|
||||
)
|
||||
|
||||
# Research Team Reports
|
||||
if self.report_sections["investment_plan"]:
|
||||
report_parts.append(f"## {lang['research_team']}")
|
||||
report_parts.append(f"## {get_lang("research_team")}")
|
||||
report_parts.append(f"{self.report_sections['investment_plan']}")
|
||||
|
||||
# Trading Team Reports
|
||||
if self.report_sections["trader_investment_plan"]:
|
||||
report_parts.append(f"## {lang['trading_team']}")
|
||||
report_parts.append(f"## {get_lang("trading_team")}")
|
||||
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
|
||||
|
||||
# Portfolio Management Decision
|
||||
if self.report_sections["final_trade_decision"]:
|
||||
report_parts.append(f"## {lang['portfolio_management']}")
|
||||
report_parts.append(f"## {get_lang("portfolio_management")}")
|
||||
report_parts.append(f"{self.report_sections['final_trade_decision']}")
|
||||
|
||||
self.final_report = "\n\n".join(report_parts) if report_parts else None
|
||||
|
|
@ -193,9 +193,9 @@ def update_display(layout, spinner_text=None):
|
|||
# Header with welcome message
|
||||
layout["header"].update(
|
||||
Panel(
|
||||
f"[bold green]{lang['welcome']}[/bold green]\n"
|
||||
f"[bold green]{get_lang("welcome")}[/bold green]\n"
|
||||
"[dim]© [Tauric Research](https://github.com/TauricResearch)[/dim]",
|
||||
title=lang["welcome"],
|
||||
title=get_lang("welcome"),
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
expand=True,
|
||||
|
|
@ -212,22 +212,22 @@ def update_display(layout, spinner_text=None):
|
|||
padding=(0, 2), # Add horizontal padding
|
||||
expand=True, # Make table expand to fill available space
|
||||
)
|
||||
progress_table.add_column(lang["team"], style="cyan", justify="center", width=20)
|
||||
progress_table.add_column(lang["agent"], style="green", justify="center", width=20)
|
||||
progress_table.add_column(lang["status"], style="yellow", justify="center", width=20)
|
||||
progress_table.add_column(get_lang("team"), style="cyan", justify="center", width=20)
|
||||
progress_table.add_column(get_lang("agent"), style="green", justify="center", width=20)
|
||||
progress_table.add_column(get_lang("status"), style="yellow", justify="center", width=20)
|
||||
|
||||
# Group agents by team
|
||||
teams = {
|
||||
lang["analyst_team"]: [
|
||||
lang["market_analyst"],
|
||||
lang["social_analyst"],
|
||||
lang["news_analyst"],
|
||||
lang["fundamentals_analyst"],
|
||||
get_lang("analyst_team"): [
|
||||
get_lang("market_analyst"),
|
||||
get_lang("social_analyst"),
|
||||
get_lang("news_analyst"),
|
||||
get_lang("fundamentals_analyst"),
|
||||
],
|
||||
lang["research_team"]: [lang["bull_researcher"], lang["bear_researcher"], lang["research_manager"]],
|
||||
lang["trading_team"]: [lang["trader"]],
|
||||
lang["risk_management"]: [lang["risky_analyst"], lang["neutral_analyst"], lang["safe_analyst"]],
|
||||
lang["portfolio_management"]: [lang["portfolio_manager"]],
|
||||
get_lang("research_team"): [get_lang("bull_researcher"), get_lang("bear_researcher"), get_lang("research_manager")],
|
||||
get_lang("trading_team"): [get_lang("trader")],
|
||||
get_lang("risk_management"): [get_lang("risky_analyst"), get_lang("neutral_analyst"), get_lang("safe_analyst")],
|
||||
get_lang("portfolio_management"): [get_lang("portfolio_manager")],
|
||||
}
|
||||
|
||||
for team, agents in teams.items():
|
||||
|
|
@ -236,14 +236,14 @@ def update_display(layout, spinner_text=None):
|
|||
status = message_buffer.agent_status[first_agent]
|
||||
if status == "in_progress":
|
||||
spinner = Spinner(
|
||||
"dots", text=f"[blue]{lang['in_progress']}[/blue]", style="bold cyan"
|
||||
"dots", text=f"[blue]{get_lang("in_progress")}[/blue]", style="bold cyan"
|
||||
)
|
||||
status_cell = spinner
|
||||
else:
|
||||
status_color = {
|
||||
lang["pending"]: "yellow",
|
||||
lang["completed"]: "green",
|
||||
lang["error"]: "red",
|
||||
get_lang("pending"): "yellow",
|
||||
get_lang("completed"): "green",
|
||||
get_lang("error"): "red",
|
||||
}.get(status, "white")
|
||||
status_cell = f"[{status_color}]{status}[/{status_color}]"
|
||||
progress_table.add_row(team, first_agent, status_cell)
|
||||
|
|
@ -253,14 +253,14 @@ def update_display(layout, spinner_text=None):
|
|||
status = message_buffer.agent_status[agent]
|
||||
if status == "in_progress":
|
||||
spinner = Spinner(
|
||||
"dots", text=f"[blue]{lang['in_progress']}[/blue]", style="bold cyan"
|
||||
"dots", text=f"[blue]{get_lang("in_progress")}[/blue]", style="bold cyan"
|
||||
)
|
||||
status_cell = spinner
|
||||
else:
|
||||
status_color = {
|
||||
lang["pending"]: "yellow",
|
||||
lang["completed"]: "green",
|
||||
lang["error"]: "red",
|
||||
get_lang("pending"): "yellow",
|
||||
get_lang("completed"): "green",
|
||||
get_lang("error"): "red",
|
||||
}.get(status, "white")
|
||||
status_cell = f"[{status_color}]{status}[/{status_color}]"
|
||||
progress_table.add_row("", agent, status_cell)
|
||||
|
|
@ -269,7 +269,7 @@ def update_display(layout, spinner_text=None):
|
|||
progress_table.add_row("─" * 20, "─" * 20, "─" * 20, style="dim")
|
||||
|
||||
layout["progress"].update(
|
||||
Panel(progress_table, title=lang["progress"], border_style="cyan", padding=(1, 2))
|
||||
Panel(progress_table, title=get_lang("progress"), border_style="cyan", padding=(1, 2))
|
||||
)
|
||||
|
||||
# Messages panel showing recent messages and tool calls
|
||||
|
|
@ -284,7 +284,7 @@ def update_display(layout, spinner_text=None):
|
|||
)
|
||||
messages_table.add_column("Time", style="cyan", width=8, justify="center")
|
||||
messages_table.add_column("Type", style="green", width=10, justify="center")
|
||||
messages_table.add_column(lang["content"] if "content" in lang else "Content", style="white", no_wrap=False, ratio=1)
|
||||
messages_table.add_column(get_lang("content") if "content" in lang else "Content", style="white", no_wrap=False, ratio=1)
|
||||
|
||||
# Combine tool calls and messages
|
||||
all_messages = []
|
||||
|
|
@ -346,7 +346,7 @@ def update_display(layout, spinner_text=None):
|
|||
)
|
||||
|
||||
layout["messages"].update(
|
||||
Panel(messages_table, title=lang["messages_tools"], border_style="blue")
|
||||
Panel(messages_table, title=get_lang("messages_tools"), border_style="blue")
|
||||
)
|
||||
|
||||
# Analysis panel showing current report
|
||||
|
|
@ -354,7 +354,7 @@ def update_display(layout, spinner_text=None):
|
|||
layout["analysis"].update(
|
||||
Panel(
|
||||
Markdown(message_buffer.current_report),
|
||||
title=lang["current_report"],
|
||||
title=get_lang("current_report"),
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
|
@ -363,7 +363,7 @@ def update_display(layout, spinner_text=None):
|
|||
layout["analysis"].update(
|
||||
Panel(
|
||||
"[italic]Waiting for analysis report...[/italic]",
|
||||
title=lang["current_report"],
|
||||
title=get_lang("current_report"),
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
|
@ -395,9 +395,9 @@ def get_user_selections():
|
|||
|
||||
# Create welcome box content
|
||||
welcome_content = f"{welcome_ascii}\n"
|
||||
welcome_content += f"[bold green]TradingAgents: {lang['framework_subtitle']} - CLI[/bold green]\n\n"
|
||||
welcome_content += f"[bold]{lang['workflow_steps_title']}[/bold]\n"
|
||||
welcome_content += lang['workflow_steps']
|
||||
welcome_content += f"[bold green]TradingAgents: {get_lang("framework_subtitle")} - CLI[/bold green]\n\n"
|
||||
welcome_content += f"[bold]{get_lang("workflow_steps_title")}[/bold]\n"
|
||||
welcome_content += get_lang("workflow_steps")
|
||||
welcome_content += (
|
||||
"[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]"
|
||||
)
|
||||
|
|
@ -407,8 +407,8 @@ def get_user_selections():
|
|||
welcome_content,
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
title= lang["welcome"],
|
||||
subtitle=lang["framework_subtitle"],
|
||||
title= get_lang("welcome"),
|
||||
subtitle=get_lang("framework_subtitle"),
|
||||
)
|
||||
console.print(Align.center(welcome_box))
|
||||
console.print() # Add a blank line after the welcome box
|
||||
|
|
@ -424,7 +424,7 @@ def get_user_selections():
|
|||
# Step 1: Ticker symbol
|
||||
console.print(
|
||||
create_question_box(
|
||||
lang["step1_title"], lang["step1_prompt"], lang["default_ticker"]
|
||||
get_lang("step1_title"), get_lang("step1_prompt"), get_lang("default_ticker")
|
||||
)
|
||||
)
|
||||
selected_ticker = get_ticker()
|
||||
|
|
@ -433,8 +433,8 @@ def get_user_selections():
|
|||
default_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
console.print(
|
||||
create_question_box(
|
||||
lang["step2_title"],
|
||||
lang["step2_prompt"],
|
||||
get_lang("step2_title"),
|
||||
get_lang("step2_prompt"),
|
||||
default_date,
|
||||
)
|
||||
)
|
||||
|
|
@ -443,7 +443,7 @@ def get_user_selections():
|
|||
# Step 3: Select analysts
|
||||
console.print(
|
||||
create_question_box(
|
||||
lang["step3_title"], lang["step3_prompt"]
|
||||
get_lang("step3_title"), get_lang("step3_prompt")
|
||||
)
|
||||
)
|
||||
selected_analysts = select_analysts()
|
||||
|
|
@ -454,7 +454,7 @@ def get_user_selections():
|
|||
# Step 4: Research depth
|
||||
console.print(
|
||||
create_question_box(
|
||||
lang["step4_title"], lang["step4_prompt"]
|
||||
get_lang("step4_title"), get_lang("step4_prompt")
|
||||
)
|
||||
)
|
||||
selected_research_depth = select_research_depth()
|
||||
|
|
@ -462,7 +462,7 @@ def get_user_selections():
|
|||
# Step 5: OpenAI backend
|
||||
console.print(
|
||||
create_question_box(
|
||||
lang["step5_title"], lang["step5_prompt"]
|
||||
get_lang("step5_title"), get_lang("step5_prompt")
|
||||
)
|
||||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
|
|
@ -470,7 +470,7 @@ def get_user_selections():
|
|||
# Step 6: Thinking agents
|
||||
console.print(
|
||||
create_question_box(
|
||||
lang["step6_title"], lang["step6_prompt"]
|
||||
get_lang("step6_title"), get_lang("step6_prompt")
|
||||
)
|
||||
)
|
||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
||||
|
|
|
|||
24
cli/utils.py
24
cli/utils.py
|
|
@ -17,8 +17,8 @@ ANALYST_ORDER = [
|
|||
def get_ticker() -> str:
|
||||
"""Prompt the user to enter a ticker symbol."""
|
||||
ticker = questionary.text(
|
||||
lang["step1_prompt"],
|
||||
validate=lambda x: len(x.strip()) > 0 or lang["ticker_validate"],
|
||||
get_lang("step1_prompt"),
|
||||
validate=lambda x: len(x.strip()) > 0 or get_lang("ticker_validate"),
|
||||
style=questionary.Style(
|
||||
[
|
||||
("text", "fg:green"),
|
||||
|
|
@ -49,8 +49,8 @@ def get_analysis_date() -> str:
|
|||
return False
|
||||
|
||||
date = questionary.text(
|
||||
lang["step2_prompt"],
|
||||
validate=lambda x: validate_date(x.strip()) or lang["date_validate"],
|
||||
get_lang("step2_prompt"),
|
||||
validate=lambda x: validate_date(x.strip()) or get_lang("date_validate"),
|
||||
style=questionary.Style(
|
||||
[
|
||||
("text", "fg:green"),
|
||||
|
|
@ -69,12 +69,12 @@ def get_analysis_date() -> str:
|
|||
def select_analysts() -> List[AnalystType]:
|
||||
"""Select analysts using an interactive checkbox."""
|
||||
choices = questionary.checkbox(
|
||||
lang["step3_prompt"],
|
||||
get_lang("step3_prompt"),
|
||||
choices=[
|
||||
questionary.Choice(lang.get(display.replace(" ", "_").lower(), display), value=value) for display, value in ANALYST_ORDER
|
||||
],
|
||||
instruction=lang["analyst_instruction"],
|
||||
validate=lambda x: len(x) > 0 or lang["analyst_validate"],
|
||||
instruction=get_lang("analyst_instruction"),
|
||||
validate=lambda x: len(x) > 0 or get_lang("analyst_validate"),
|
||||
style=questionary.Style(
|
||||
[
|
||||
("checkbox-selected", "fg:green"),
|
||||
|
|
@ -97,17 +97,17 @@ def select_research_depth() -> int:
|
|||
|
||||
# Define research depth options with their corresponding values
|
||||
DEPTH_OPTIONS = [
|
||||
(lang["depth_shallow"], 1),
|
||||
(lang["depth_medium"], 3),
|
||||
(lang["depth_deep"], 5),
|
||||
(get_lang("depth_shallow"), 1),
|
||||
(get_lang("depth_medium"), 3),
|
||||
(get_lang("depth_deep"), 5),
|
||||
]
|
||||
|
||||
choice = questionary.select(
|
||||
lang["step4_prompt"],
|
||||
get_lang("step4_prompt"),
|
||||
choices=[
|
||||
questionary.Choice(display, value=value) for display, value in DEPTH_OPTIONS
|
||||
],
|
||||
instruction=lang["depth_instruction"],
|
||||
instruction=get_lang("depth_instruction"),
|
||||
style=questionary.Style(
|
||||
[
|
||||
("selected", "fg:yellow noinherit"),
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_fundamentals_analyst(llm, toolkit):
|
||||
def fundamentals_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
|
|
@ -23,14 +21,14 @@ def create_fundamentals_analyst(llm, toolkit):
|
|||
]
|
||||
|
||||
system_message = (
|
||||
prompts["analysts"]["fundamentals_analyst"]["system_message"]
|
||||
get_prompts("analysts", "fundamentals_analyst", "system_message")
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
prompts["analysts"]["template"]
|
||||
get_prompts("analysts", "template")
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_market_analyst(llm, toolkit):
|
||||
|
||||
def market_analyst_node(state):
|
||||
|
|
@ -24,14 +22,14 @@ def create_market_analyst(llm, toolkit):
|
|||
]
|
||||
|
||||
system_message = (
|
||||
prompts["analysts"]["market_analyst"]["system_message"]
|
||||
get_prompts("analysts", "market_analyst", "system_message")
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
prompts["analysts"]["template"]
|
||||
get_prompts("analysts", "template")
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_news_analyst(llm, toolkit):
|
||||
def news_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
|
|
@ -20,14 +18,14 @@ def create_news_analyst(llm, toolkit):
|
|||
]
|
||||
|
||||
system_message = (
|
||||
prompts["analysts"]["news_analyst"]["system_message"]
|
||||
get_prompts("analysts", "news_analyst", "system_message")
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
prompts["analysts"]["template"]
|
||||
get_prompts("analysts", "template")
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_social_media_analyst(llm, toolkit):
|
||||
def social_media_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
|
|
@ -19,14 +17,14 @@ def create_social_media_analyst(llm, toolkit):
|
|||
]
|
||||
|
||||
system_message = (
|
||||
prompts["analysts"]["social_media_analyst"]["system_message"]
|
||||
get_prompts("analysts", "social_media_analyst", "system_message")
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
prompts["analysts"]["template"]
|
||||
get_prompts("analysts", "template")
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_research_manager(llm, memory):
|
||||
def research_manager_node(state) -> dict:
|
||||
history = state["investment_debate_state"].get("history", "")
|
||||
|
|
@ -21,7 +19,7 @@ def create_research_manager(llm, memory):
|
|||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = prompts["managers"]["research_manager"] \
|
||||
prompt = get_prompts("managers", "research_manager") \
|
||||
.replace("{past_memory_str}", past_memory_str) \
|
||||
.replace("{history}", history)
|
||||
response = llm.invoke(prompt)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_risk_manager(llm, memory):
|
||||
def risk_manager_node(state) -> dict:
|
||||
|
||||
|
|
@ -24,7 +22,7 @@ def create_risk_manager(llm, memory):
|
|||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = prompts["managers"]["risk_manager"] \
|
||||
prompt = get_prompts("managers", "risk_manager") \
|
||||
.replace("{trader_plan}", trader_plan) \
|
||||
.replace("{past_memory_str}", past_memory_str) \
|
||||
.replace("{history}", history)
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_bear_researcher(llm, memory):
|
||||
def bear_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
|
|
@ -24,7 +22,7 @@ def create_bear_researcher(llm, memory):
|
|||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = prompts["researchers"]["bear_researcher"] \
|
||||
prompt = get_prompts("researchers", "bear_researcher") \
|
||||
.replace("{market_research_report}", market_research_report) \
|
||||
.replace("{sentiment_report}", sentiment_report) \
|
||||
.replace("{news_report}", news_report) \
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_bull_researcher(llm, memory):
|
||||
def bull_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
|
|
@ -24,7 +22,7 @@ def create_bull_researcher(llm, memory):
|
|||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = prompts["researchers"]["bull_researcher"] \
|
||||
prompt = get_prompts("researchers", "bull_researcher") \
|
||||
.replace("{market_research_report}", market_research_report) \
|
||||
.replace("{sentiment_report}", sentiment_report) \
|
||||
.replace("{news_report}", news_report) \
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_risky_debator(llm):
|
||||
def risky_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
|
|
@ -20,7 +18,7 @@ def create_risky_debator(llm):
|
|||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = prompts["risk_mgmt"]["aggressive_debator"] \
|
||||
prompt = get_prompts("risk_mgmt", "aggressive_debator") \
|
||||
.replace("{trader_decision}", trader_decision) \
|
||||
.replace("{market_research_report}", market_research_report) \
|
||||
.replace("{sentiment_report}", sentiment_report) \
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_safe_debator(llm):
|
||||
def safe_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
|
|
@ -21,7 +19,7 @@ def create_safe_debator(llm):
|
|||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = prompts["risk_mgmt"]["conservative_debator"] \
|
||||
prompt = get_prompts("risk_mgmt", "conservative_debator") \
|
||||
.replace("{trader_decision}", trader_decision) \
|
||||
.replace("{market_research_report}", market_research_report) \
|
||||
.replace("{sentiment_report}", sentiment_report) \
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_neutral_debator(llm):
|
||||
def neutral_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
|
|
@ -20,7 +18,7 @@ def create_neutral_debator(llm):
|
|||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = prompts["risk_mgmt"]["neutral_debator"] \
|
||||
prompt = get_prompts("risk_mgmt", "neutral_debator") \
|
||||
.replace("{trader_decision}", trader_decision) \
|
||||
.replace("{market_research_report}", market_research_report) \
|
||||
.replace("{sentiment_report}", sentiment_report) \
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ import time
|
|||
import json
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
def create_trader(llm, memory):
|
||||
def trader_node(state, name):
|
||||
company_name = state["company_of_interest"]
|
||||
|
|
@ -23,7 +21,7 @@ def create_trader(llm, memory):
|
|||
|
||||
context = {
|
||||
"role": "user",
|
||||
"content": prompts["trader"]["user_message"] \
|
||||
"content": get_prompts("trader", "user_message") \
|
||||
.replace("{company_name}", company_name) \
|
||||
.replace("{investment_plan}", investment_plan)
|
||||
}
|
||||
|
|
@ -31,7 +29,7 @@ def create_trader(llm, memory):
|
|||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": prompts["trader"]["system_message"] \
|
||||
"content": get_prompts("trader", "system_message") \
|
||||
.replace("{past_memory_str}", past_memory_str),
|
||||
},
|
||||
context,
|
||||
|
|
|
|||
|
|
@ -4,8 +4,6 @@ from typing import Dict, Any
|
|||
from langchain_openai import ChatOpenAI
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
class Reflector:
|
||||
"""Handles reflection on decisions and updating memory."""
|
||||
|
||||
|
|
@ -16,7 +14,7 @@ class Reflector:
|
|||
|
||||
def _get_reflection_prompt(self) -> str:
|
||||
"""Get the system prompt for reflection."""
|
||||
return prompts["reflection"]["system_message"]
|
||||
return get_prompts("reflection", "system_message")
|
||||
|
||||
def _extract_current_situation(self, current_state: Dict[str, Any]) -> str:
|
||||
"""Extract the current market situation from the state."""
|
||||
|
|
@ -35,7 +33,7 @@ class Reflector:
|
|||
("system", self.reflection_system_prompt),
|
||||
(
|
||||
"human",
|
||||
prompts["reflection"]["user_message"] \
|
||||
get_prompts("reflection", "user_message") \
|
||||
.replace("{returns_losses}", returns_losses) \
|
||||
.replace("{report}", report) \
|
||||
.replace("{situation}", situation)
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@
|
|||
from langchain_openai import ChatOpenAI
|
||||
from tradingagents.i18n import get_prompts
|
||||
|
||||
prompts = get_prompts()
|
||||
|
||||
class SignalProcessor:
|
||||
"""Processes trading signals to extract actionable decisions."""
|
||||
|
||||
|
|
@ -25,7 +23,7 @@ class SignalProcessor:
|
|||
messages = [
|
||||
(
|
||||
"system",
|
||||
prompts["signal_processor"]["system_message"],
|
||||
get_prompts("signal_processor", "system_message"),
|
||||
),
|
||||
("human", full_signal),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,22 +1,37 @@
|
|||
from functools import reduce
|
||||
import importlib
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
def get_lang():
|
||||
def get_value(dictionary: dict, *keys, default=None):
|
||||
"""
|
||||
Get values from a dictionary using a list of keys.
|
||||
If a key is not found, it returns the default value.
|
||||
"""
|
||||
try:
|
||||
return reduce((lambda d, key: d[key]), keys, dictionary)
|
||||
except (KeyError, TypeError):
|
||||
return default
|
||||
|
||||
def get_lang(*keys, default="") -> str | dict:
|
||||
lang_code = DEFAULT_CONFIG.get("language", "zh")
|
||||
try:
|
||||
lang_module = importlib.import_module(f"tradingagents.i18n.{lang_code}")
|
||||
return lang_module.LANG
|
||||
if not keys:
|
||||
return lang_module.LANG
|
||||
return get_value(lang_module.LANG, *keys, default=default)
|
||||
except Exception:
|
||||
# fallback to zh
|
||||
from .interface.zh import LANG
|
||||
return LANG
|
||||
if not keys:
|
||||
return LANG
|
||||
return get_value(LANG, *keys, default=default)
|
||||
|
||||
def get_prompts():
|
||||
def get_prompts(*keys, default="") -> str:
|
||||
lang_code = DEFAULT_CONFIG.get("language", "zh")
|
||||
try:
|
||||
lang_module = importlib.import_module(f"tradingagents.i18n.{lang_code}")
|
||||
return lang_module.PROMPTS
|
||||
return get_value(lang_module.PROMPTS, *keys, default=default)
|
||||
except Exception:
|
||||
# fallback to zh
|
||||
from .prompts.zh import PROMPTS
|
||||
return PROMPTS
|
||||
return get_value(PROMPTS, *keys, default=default)
|
||||
Loading…
Reference in New Issue