diff --git a/cli/main.py b/cli/main.py index d41013eb..305f2643 100644 --- a/cli/main.py +++ b/cli/main.py @@ -32,7 +32,7 @@ lang = get_lang() app = typer.Typer( name="TradingAgents", - help=lang["welcome"] + ": " + lang["framework_subtitle"], + help=get_lang("welcome") + ": " + get_lang("framework_subtitle"), add_completion=True, # Enable shell completion ) @@ -46,22 +46,22 @@ class MessageBuffer: self.final_report = None # Store the complete final report self.agent_status = { # Analyst Team - lang["market_analyst"]: lang["pending"], - lang["social_analyst"]: lang["pending"], - lang["news_analyst"]: lang["pending"], - lang["fundamentals_analyst"]: lang["pending"], + get_lang("market_analyst"): get_lang("pending"), + get_lang("social_analyst"): get_lang("pending"), + get_lang("news_analyst"): get_lang("pending"), + get_lang("fundamentals_analyst"): get_lang("pending"), # Research Team - lang["bull_researcher"]: lang["pending"], - lang["bear_researcher"]: lang["pending"], - lang["research_manager"]: lang["pending"], + get_lang("bull_researcher"): get_lang("pending"), + get_lang("bear_researcher"): get_lang("pending"), + get_lang("research_manager"): get_lang("pending"), # Trading Team - lang["trader"]: lang["pending"], + get_lang("trader"): get_lang("pending"), # Risk Management Team - lang["risky_analyst"]: lang["pending"], - lang["neutral_analyst"]: lang["pending"], - lang["safe_analyst"]: lang["pending"], + get_lang("risky_analyst"): get_lang("pending"), + get_lang("neutral_analyst"): get_lang("pending"), + get_lang("safe_analyst"): get_lang("pending"), # Portfolio Management Team - lang["portfolio_manager"]: lang["pending"], + get_lang("portfolio_manager"): get_lang("pending"), } self.current_agent = None self.report_sections = { @@ -106,13 +106,13 @@ class MessageBuffer: if latest_section and latest_content: # Format the current section for display section_titles = { - "market_report": lang["market_report"], - "sentiment_report": lang["sentiment_report"], - "news_report": lang["news_report"], - "fundamentals_report": lang["fundamentals_report"], - "investment_plan": lang["investment_plan"], - "trader_investment_plan": lang["trader_investment_plan"], - "final_trade_decision": lang["final_trade_decision"], + "market_report": get_lang("market_report"), + "sentiment_report": get_lang("sentiment_report"), + "news_report": get_lang("news_report"), + "fundamentals_report": get_lang("fundamentals_report"), + "investment_plan": get_lang("investment_plan"), + "trader_investment_plan": get_lang("trader_investment_plan"), + "final_trade_decision": get_lang("final_trade_decision"), } self.current_report = ( f"### {section_titles[latest_section]}\n{latest_content}" @@ -134,37 +134,37 @@ class MessageBuffer: "fundamentals_report", ] ): - report_parts.append(f"## {lang['analyst_team']}") + report_parts.append(f"## {get_lang("analyst_team")}") if self.report_sections["market_report"]: report_parts.append( - f"### {lang['market_report']}\n{self.report_sections['market_report']}" + f"### {get_lang("market_report")}\n{self.report_sections['market_report']}" ) if self.report_sections["sentiment_report"]: report_parts.append( - f"### {lang['sentiment_report']}\n{self.report_sections['sentiment_report']}" + f"### {get_lang("sentiment_report")}\n{self.report_sections['sentiment_report']}" ) if self.report_sections["news_report"]: report_parts.append( - f"### {lang['news_report']}\n{self.report_sections['news_report']}" + f"### {get_lang("news_report")}\n{self.report_sections['news_report']}" ) if self.report_sections["fundamentals_report"]: report_parts.append( - f"### {lang['fundamentals_report']}\n{self.report_sections['fundamentals_report']}" + f"### {get_lang("fundamentals_report")}\n{self.report_sections['fundamentals_report']}" ) # Research Team Reports if self.report_sections["investment_plan"]: - report_parts.append(f"## {lang['research_team']}") + report_parts.append(f"## {get_lang("research_team")}") report_parts.append(f"{self.report_sections['investment_plan']}") # Trading Team Reports if self.report_sections["trader_investment_plan"]: - report_parts.append(f"## {lang['trading_team']}") + report_parts.append(f"## {get_lang("trading_team")}") report_parts.append(f"{self.report_sections['trader_investment_plan']}") # Portfolio Management Decision if self.report_sections["final_trade_decision"]: - report_parts.append(f"## {lang['portfolio_management']}") + report_parts.append(f"## {get_lang("portfolio_management")}") report_parts.append(f"{self.report_sections['final_trade_decision']}") self.final_report = "\n\n".join(report_parts) if report_parts else None @@ -193,9 +193,9 @@ def update_display(layout, spinner_text=None): # Header with welcome message layout["header"].update( Panel( - f"[bold green]{lang['welcome']}[/bold green]\n" + f"[bold green]{get_lang("welcome")}[/bold green]\n" "[dim]© [Tauric Research](https://github.com/TauricResearch)[/dim]", - title=lang["welcome"], + title=get_lang("welcome"), border_style="green", padding=(1, 2), expand=True, @@ -212,22 +212,22 @@ def update_display(layout, spinner_text=None): padding=(0, 2), # Add horizontal padding expand=True, # Make table expand to fill available space ) - progress_table.add_column(lang["team"], style="cyan", justify="center", width=20) - progress_table.add_column(lang["agent"], style="green", justify="center", width=20) - progress_table.add_column(lang["status"], style="yellow", justify="center", width=20) + progress_table.add_column(get_lang("team"), style="cyan", justify="center", width=20) + progress_table.add_column(get_lang("agent"), style="green", justify="center", width=20) + progress_table.add_column(get_lang("status"), style="yellow", justify="center", width=20) # Group agents by team teams = { - lang["analyst_team"]: [ - lang["market_analyst"], - lang["social_analyst"], - lang["news_analyst"], - lang["fundamentals_analyst"], + get_lang("analyst_team"): [ + get_lang("market_analyst"), + get_lang("social_analyst"), + get_lang("news_analyst"), + get_lang("fundamentals_analyst"), ], - lang["research_team"]: [lang["bull_researcher"], lang["bear_researcher"], lang["research_manager"]], - lang["trading_team"]: [lang["trader"]], - lang["risk_management"]: [lang["risky_analyst"], lang["neutral_analyst"], lang["safe_analyst"]], - lang["portfolio_management"]: [lang["portfolio_manager"]], + get_lang("research_team"): [get_lang("bull_researcher"), get_lang("bear_researcher"), get_lang("research_manager")], + get_lang("trading_team"): [get_lang("trader")], + get_lang("risk_management"): [get_lang("risky_analyst"), get_lang("neutral_analyst"), get_lang("safe_analyst")], + get_lang("portfolio_management"): [get_lang("portfolio_manager")], } for team, agents in teams.items(): @@ -236,14 +236,14 @@ def update_display(layout, spinner_text=None): status = message_buffer.agent_status[first_agent] if status == "in_progress": spinner = Spinner( - "dots", text=f"[blue]{lang['in_progress']}[/blue]", style="bold cyan" + "dots", text=f"[blue]{get_lang("in_progress")}[/blue]", style="bold cyan" ) status_cell = spinner else: status_color = { - lang["pending"]: "yellow", - lang["completed"]: "green", - lang["error"]: "red", + get_lang("pending"): "yellow", + get_lang("completed"): "green", + get_lang("error"): "red", }.get(status, "white") status_cell = f"[{status_color}]{status}[/{status_color}]" progress_table.add_row(team, first_agent, status_cell) @@ -253,14 +253,14 @@ def update_display(layout, spinner_text=None): status = message_buffer.agent_status[agent] if status == "in_progress": spinner = Spinner( - "dots", text=f"[blue]{lang['in_progress']}[/blue]", style="bold cyan" + "dots", text=f"[blue]{get_lang("in_progress")}[/blue]", style="bold cyan" ) status_cell = spinner else: status_color = { - lang["pending"]: "yellow", - lang["completed"]: "green", - lang["error"]: "red", + get_lang("pending"): "yellow", + get_lang("completed"): "green", + get_lang("error"): "red", }.get(status, "white") status_cell = f"[{status_color}]{status}[/{status_color}]" progress_table.add_row("", agent, status_cell) @@ -269,7 +269,7 @@ def update_display(layout, spinner_text=None): progress_table.add_row("─" * 20, "─" * 20, "─" * 20, style="dim") layout["progress"].update( - Panel(progress_table, title=lang["progress"], border_style="cyan", padding=(1, 2)) + Panel(progress_table, title=get_lang("progress"), border_style="cyan", padding=(1, 2)) ) # Messages panel showing recent messages and tool calls @@ -284,7 +284,7 @@ def update_display(layout, spinner_text=None): ) messages_table.add_column("Time", style="cyan", width=8, justify="center") messages_table.add_column("Type", style="green", width=10, justify="center") - messages_table.add_column(lang["content"] if "content" in lang else "Content", style="white", no_wrap=False, ratio=1) + messages_table.add_column(get_lang("content") if "content" in lang else "Content", style="white", no_wrap=False, ratio=1) # Combine tool calls and messages all_messages = [] @@ -346,7 +346,7 @@ def update_display(layout, spinner_text=None): ) layout["messages"].update( - Panel(messages_table, title=lang["messages_tools"], border_style="blue") + Panel(messages_table, title=get_lang("messages_tools"), border_style="blue") ) # Analysis panel showing current report @@ -354,7 +354,7 @@ def update_display(layout, spinner_text=None): layout["analysis"].update( Panel( Markdown(message_buffer.current_report), - title=lang["current_report"], + title=get_lang("current_report"), border_style="green", padding=(1, 2), ) @@ -363,7 +363,7 @@ def update_display(layout, spinner_text=None): layout["analysis"].update( Panel( "[italic]Waiting for analysis report...[/italic]", - title=lang["current_report"], + title=get_lang("current_report"), border_style="green", padding=(1, 2), ) @@ -395,9 +395,9 @@ def get_user_selections(): # Create welcome box content welcome_content = f"{welcome_ascii}\n" - welcome_content += f"[bold green]TradingAgents: {lang['framework_subtitle']} - CLI[/bold green]\n\n" - welcome_content += f"[bold]{lang['workflow_steps_title']}[/bold]\n" - welcome_content += lang['workflow_steps'] + welcome_content += f"[bold green]TradingAgents: {get_lang("framework_subtitle")} - CLI[/bold green]\n\n" + welcome_content += f"[bold]{get_lang("workflow_steps_title")}[/bold]\n" + welcome_content += get_lang("workflow_steps") welcome_content += ( "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]" ) @@ -407,8 +407,8 @@ def get_user_selections(): welcome_content, border_style="green", padding=(1, 2), - title= lang["welcome"], - subtitle=lang["framework_subtitle"], + title= get_lang("welcome"), + subtitle=get_lang("framework_subtitle"), ) console.print(Align.center(welcome_box)) console.print() # Add a blank line after the welcome box @@ -424,7 +424,7 @@ def get_user_selections(): # Step 1: Ticker symbol console.print( create_question_box( - lang["step1_title"], lang["step1_prompt"], lang["default_ticker"] + get_lang("step1_title"), get_lang("step1_prompt"), get_lang("default_ticker") ) ) selected_ticker = get_ticker() @@ -433,8 +433,8 @@ def get_user_selections(): default_date = datetime.datetime.now().strftime("%Y-%m-%d") console.print( create_question_box( - lang["step2_title"], - lang["step2_prompt"], + get_lang("step2_title"), + get_lang("step2_prompt"), default_date, ) ) @@ -443,7 +443,7 @@ def get_user_selections(): # Step 3: Select analysts console.print( create_question_box( - lang["step3_title"], lang["step3_prompt"] + get_lang("step3_title"), get_lang("step3_prompt") ) ) selected_analysts = select_analysts() @@ -454,7 +454,7 @@ def get_user_selections(): # Step 4: Research depth console.print( create_question_box( - lang["step4_title"], lang["step4_prompt"] + get_lang("step4_title"), get_lang("step4_prompt") ) ) selected_research_depth = select_research_depth() @@ -462,7 +462,7 @@ def get_user_selections(): # Step 5: OpenAI backend console.print( create_question_box( - lang["step5_title"], lang["step5_prompt"] + get_lang("step5_title"), get_lang("step5_prompt") ) ) selected_llm_provider, backend_url = select_llm_provider() @@ -470,7 +470,7 @@ def get_user_selections(): # Step 6: Thinking agents console.print( create_question_box( - lang["step6_title"], lang["step6_prompt"] + get_lang("step6_title"), get_lang("step6_prompt") ) ) selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) diff --git a/cli/utils.py b/cli/utils.py index 0e786ad4..ebc4143e 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -17,8 +17,8 @@ ANALYST_ORDER = [ def get_ticker() -> str: """Prompt the user to enter a ticker symbol.""" ticker = questionary.text( - lang["step1_prompt"], - validate=lambda x: len(x.strip()) > 0 or lang["ticker_validate"], + get_lang("step1_prompt"), + validate=lambda x: len(x.strip()) > 0 or get_lang("ticker_validate"), style=questionary.Style( [ ("text", "fg:green"), @@ -49,8 +49,8 @@ def get_analysis_date() -> str: return False date = questionary.text( - lang["step2_prompt"], - validate=lambda x: validate_date(x.strip()) or lang["date_validate"], + get_lang("step2_prompt"), + validate=lambda x: validate_date(x.strip()) or get_lang("date_validate"), style=questionary.Style( [ ("text", "fg:green"), @@ -69,12 +69,12 @@ def get_analysis_date() -> str: def select_analysts() -> List[AnalystType]: """Select analysts using an interactive checkbox.""" choices = questionary.checkbox( - lang["step3_prompt"], + get_lang("step3_prompt"), choices=[ questionary.Choice(lang.get(display.replace(" ", "_").lower(), display), value=value) for display, value in ANALYST_ORDER ], - instruction=lang["analyst_instruction"], - validate=lambda x: len(x) > 0 or lang["analyst_validate"], + instruction=get_lang("analyst_instruction"), + validate=lambda x: len(x) > 0 or get_lang("analyst_validate"), style=questionary.Style( [ ("checkbox-selected", "fg:green"), @@ -97,17 +97,17 @@ def select_research_depth() -> int: # Define research depth options with their corresponding values DEPTH_OPTIONS = [ - (lang["depth_shallow"], 1), - (lang["depth_medium"], 3), - (lang["depth_deep"], 5), + (get_lang("depth_shallow"), 1), + (get_lang("depth_medium"), 3), + (get_lang("depth_deep"), 5), ] choice = questionary.select( - lang["step4_prompt"], + get_lang("step4_prompt"), choices=[ questionary.Choice(display, value=value) for display, value in DEPTH_OPTIONS ], - instruction=lang["depth_instruction"], + instruction=get_lang("depth_instruction"), style=questionary.Style( [ ("selected", "fg:yellow noinherit"), diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index dc81c304..e3ed51c2 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -3,8 +3,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_fundamentals_analyst(llm, toolkit): def fundamentals_analyst_node(state): current_date = state["trade_date"] @@ -23,14 +21,14 @@ def create_fundamentals_analyst(llm, toolkit): ] system_message = ( - prompts["analysts"]["fundamentals_analyst"]["system_message"] + get_prompts("analysts", "fundamentals_analyst", "system_message") ) prompt = ChatPromptTemplate.from_messages( [ ( "system", - prompts["analysts"]["template"] + get_prompts("analysts", "template") ), MessagesPlaceholder(variable_name="messages"), ] diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index 73eb82cc..9c240ea4 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -3,8 +3,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_market_analyst(llm, toolkit): def market_analyst_node(state): @@ -24,14 +22,14 @@ def create_market_analyst(llm, toolkit): ] system_message = ( - prompts["analysts"]["market_analyst"]["system_message"] + get_prompts("analysts", "market_analyst", "system_message") ) prompt = ChatPromptTemplate.from_messages( [ ( "system", - prompts["analysts"]["template"] + get_prompts("analysts", "template") ), MessagesPlaceholder(variable_name="messages"), ] diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index af3574be..94e71bba 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -3,8 +3,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_news_analyst(llm, toolkit): def news_analyst_node(state): current_date = state["trade_date"] @@ -20,14 +18,14 @@ def create_news_analyst(llm, toolkit): ] system_message = ( - prompts["analysts"]["news_analyst"]["system_message"] + get_prompts("analysts", "news_analyst", "system_message") ) prompt = ChatPromptTemplate.from_messages( [ ( "system", - prompts["analysts"]["template"] + get_prompts("analysts", "template") ), MessagesPlaceholder(variable_name="messages"), ] diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 0419082d..b0d72c7e 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -3,8 +3,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_social_media_analyst(llm, toolkit): def social_media_analyst_node(state): current_date = state["trade_date"] @@ -19,14 +17,14 @@ def create_social_media_analyst(llm, toolkit): ] system_message = ( - prompts["analysts"]["social_media_analyst"]["system_message"] + get_prompts("analysts", "social_media_analyst", "system_message") ) prompt = ChatPromptTemplate.from_messages( [ ( "system", - prompts["analysts"]["template"] + get_prompts("analysts", "template") ), MessagesPlaceholder(variable_name="messages"), ] diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 3ff27747..7f2bd409 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -2,8 +2,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_research_manager(llm, memory): def research_manager_node(state) -> dict: history = state["investment_debate_state"].get("history", "") @@ -21,7 +19,7 @@ def create_research_manager(llm, memory): for i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" - prompt = prompts["managers"]["research_manager"] \ + prompt = get_prompts("managers", "research_manager") \ .replace("{past_memory_str}", past_memory_str) \ .replace("{history}", history) response = llm.invoke(prompt) diff --git a/tradingagents/agents/managers/risk_manager.py b/tradingagents/agents/managers/risk_manager.py index a4a658e0..b3340eec 100644 --- a/tradingagents/agents/managers/risk_manager.py +++ b/tradingagents/agents/managers/risk_manager.py @@ -2,8 +2,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_risk_manager(llm, memory): def risk_manager_node(state) -> dict: @@ -24,7 +22,7 @@ def create_risk_manager(llm, memory): for i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" - prompt = prompts["managers"]["risk_manager"] \ + prompt = get_prompts("managers", "risk_manager") \ .replace("{trader_plan}", trader_plan) \ .replace("{past_memory_str}", past_memory_str) \ .replace("{history}", history) diff --git a/tradingagents/agents/researchers/bear_researcher.py b/tradingagents/agents/researchers/bear_researcher.py index cc8b27db..2967a9b7 100644 --- a/tradingagents/agents/researchers/bear_researcher.py +++ b/tradingagents/agents/researchers/bear_researcher.py @@ -3,8 +3,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_bear_researcher(llm, memory): def bear_node(state) -> dict: investment_debate_state = state["investment_debate_state"] @@ -24,7 +22,7 @@ def create_bear_researcher(llm, memory): for i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" - prompt = prompts["researchers"]["bear_researcher"] \ + prompt = get_prompts("researchers", "bear_researcher") \ .replace("{market_research_report}", market_research_report) \ .replace("{sentiment_report}", sentiment_report) \ .replace("{news_report}", news_report) \ diff --git a/tradingagents/agents/researchers/bull_researcher.py b/tradingagents/agents/researchers/bull_researcher.py index 37cbc7b4..e05b09b1 100644 --- a/tradingagents/agents/researchers/bull_researcher.py +++ b/tradingagents/agents/researchers/bull_researcher.py @@ -3,8 +3,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_bull_researcher(llm, memory): def bull_node(state) -> dict: investment_debate_state = state["investment_debate_state"] @@ -24,7 +22,7 @@ def create_bull_researcher(llm, memory): for i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" - prompt = prompts["researchers"]["bull_researcher"] \ + prompt = get_prompts("researchers", "bull_researcher") \ .replace("{market_research_report}", market_research_report) \ .replace("{sentiment_report}", sentiment_report) \ .replace("{news_report}", news_report) \ diff --git a/tradingagents/agents/risk_mgmt/aggresive_debator.py b/tradingagents/agents/risk_mgmt/aggresive_debator.py index 7a9ea874..9457fb0d 100644 --- a/tradingagents/agents/risk_mgmt/aggresive_debator.py +++ b/tradingagents/agents/risk_mgmt/aggresive_debator.py @@ -2,8 +2,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_risky_debator(llm): def risky_node(state) -> dict: risk_debate_state = state["risk_debate_state"] @@ -20,7 +18,7 @@ def create_risky_debator(llm): trader_decision = state["trader_investment_plan"] - prompt = prompts["risk_mgmt"]["aggressive_debator"] \ + prompt = get_prompts("risk_mgmt", "aggressive_debator") \ .replace("{trader_decision}", trader_decision) \ .replace("{market_research_report}", market_research_report) \ .replace("{sentiment_report}", sentiment_report) \ diff --git a/tradingagents/agents/risk_mgmt/conservative_debator.py b/tradingagents/agents/risk_mgmt/conservative_debator.py index e16b434e..e3bfd7ce 100644 --- a/tradingagents/agents/risk_mgmt/conservative_debator.py +++ b/tradingagents/agents/risk_mgmt/conservative_debator.py @@ -3,8 +3,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_safe_debator(llm): def safe_node(state) -> dict: risk_debate_state = state["risk_debate_state"] @@ -21,7 +19,7 @@ def create_safe_debator(llm): trader_decision = state["trader_investment_plan"] - prompt = prompts["risk_mgmt"]["conservative_debator"] \ + prompt = get_prompts("risk_mgmt", "conservative_debator") \ .replace("{trader_decision}", trader_decision) \ .replace("{market_research_report}", market_research_report) \ .replace("{sentiment_report}", sentiment_report) \ diff --git a/tradingagents/agents/risk_mgmt/neutral_debator.py b/tradingagents/agents/risk_mgmt/neutral_debator.py index 2cea3175..96dc188d 100644 --- a/tradingagents/agents/risk_mgmt/neutral_debator.py +++ b/tradingagents/agents/risk_mgmt/neutral_debator.py @@ -2,8 +2,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_neutral_debator(llm): def neutral_node(state) -> dict: risk_debate_state = state["risk_debate_state"] @@ -20,7 +18,7 @@ def create_neutral_debator(llm): trader_decision = state["trader_investment_plan"] - prompt = prompts["risk_mgmt"]["neutral_debator"] \ + prompt = get_prompts("risk_mgmt", "neutral_debator") \ .replace("{trader_decision}", trader_decision) \ .replace("{market_research_report}", market_research_report) \ .replace("{sentiment_report}", sentiment_report) \ diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index f24dd26f..3e95054a 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -3,8 +3,6 @@ import time import json from tradingagents.i18n import get_prompts -prompts = get_prompts() - def create_trader(llm, memory): def trader_node(state, name): company_name = state["company_of_interest"] @@ -23,7 +21,7 @@ def create_trader(llm, memory): context = { "role": "user", - "content": prompts["trader"]["user_message"] \ + "content": get_prompts("trader", "user_message") \ .replace("{company_name}", company_name) \ .replace("{investment_plan}", investment_plan) } @@ -31,7 +29,7 @@ def create_trader(llm, memory): messages = [ { "role": "system", - "content": prompts["trader"]["system_message"] \ + "content": get_prompts("trader", "system_message") \ .replace("{past_memory_str}", past_memory_str), }, context, diff --git a/tradingagents/graph/reflection.py b/tradingagents/graph/reflection.py index 456157c4..59a06f25 100644 --- a/tradingagents/graph/reflection.py +++ b/tradingagents/graph/reflection.py @@ -4,8 +4,6 @@ from typing import Dict, Any from langchain_openai import ChatOpenAI from tradingagents.i18n import get_prompts -prompts = get_prompts() - class Reflector: """Handles reflection on decisions and updating memory.""" @@ -16,7 +14,7 @@ class Reflector: def _get_reflection_prompt(self) -> str: """Get the system prompt for reflection.""" - return prompts["reflection"]["system_message"] + return get_prompts("reflection", "system_message") def _extract_current_situation(self, current_state: Dict[str, Any]) -> str: """Extract the current market situation from the state.""" @@ -35,7 +33,7 @@ class Reflector: ("system", self.reflection_system_prompt), ( "human", - prompts["reflection"]["user_message"] \ + get_prompts("reflection", "user_message") \ .replace("{returns_losses}", returns_losses) \ .replace("{report}", report) \ .replace("{situation}", situation) diff --git a/tradingagents/graph/signal_processing.py b/tradingagents/graph/signal_processing.py index b87cd630..0f2f23e5 100644 --- a/tradingagents/graph/signal_processing.py +++ b/tradingagents/graph/signal_processing.py @@ -3,8 +3,6 @@ from langchain_openai import ChatOpenAI from tradingagents.i18n import get_prompts -prompts = get_prompts() - class SignalProcessor: """Processes trading signals to extract actionable decisions.""" @@ -25,7 +23,7 @@ class SignalProcessor: messages = [ ( "system", - prompts["signal_processor"]["system_message"], + get_prompts("signal_processor", "system_message"), ), ("human", full_signal), ] diff --git a/tradingagents/i18n/__init__.py b/tradingagents/i18n/__init__.py index d00ff886..b6f8aa6d 100644 --- a/tradingagents/i18n/__init__.py +++ b/tradingagents/i18n/__init__.py @@ -1,22 +1,37 @@ +from functools import reduce import importlib from tradingagents.default_config import DEFAULT_CONFIG -def get_lang(): +def get_value(dictionary: dict, *keys, default=None): + """ + Get values from a dictionary using a list of keys. + If a key is not found, it returns the default value. + """ + try: + return reduce((lambda d, key: d[key]), keys, dictionary) + except (KeyError, TypeError): + return default + +def get_lang(*keys, default="") -> str | dict: lang_code = DEFAULT_CONFIG.get("language", "zh") try: lang_module = importlib.import_module(f"tradingagents.i18n.{lang_code}") - return lang_module.LANG + if not keys: + return lang_module.LANG + return get_value(lang_module.LANG, *keys, default=default) except Exception: # fallback to zh from .interface.zh import LANG - return LANG + if not keys: + return LANG + return get_value(LANG, *keys, default=default) -def get_prompts(): +def get_prompts(*keys, default="") -> str: lang_code = DEFAULT_CONFIG.get("language", "zh") try: lang_module = importlib.import_module(f"tradingagents.i18n.{lang_code}") - return lang_module.PROMPTS + return get_value(lang_module.PROMPTS, *keys, default=default) except Exception: # fallback to zh from .prompts.zh import PROMPTS - return PROMPTS \ No newline at end of file + return get_value(PROMPTS, *keys, default=default) \ No newline at end of file