diff --git a/.gitignore b/.gitignore index da020cac..af770502 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,6 @@ eval_data/ *.egg-info/ .env -.env.local \ No newline at end of file +.env.local + +*/reports/* \ No newline at end of file diff --git a/cli/main.py b/cli/main.py index 305f2643..4559a699 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1042,6 +1042,10 @@ def run_analysis(): # Display the complete final report display_complete_report(final_state) + if config["save_report"]: + reports = extract_reports_from_final_state(final_state) + save_reports(selections["ticker"], reports, config["report_dir"]) + update_display(layout) diff --git a/cli/utils.py b/cli/utils.py index ebc4143e..40147da3 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -288,3 +288,63 @@ def select_llm_provider() -> tuple[str, str]: print(f"You selected: {display_name}\tURL: {url}") return display_name, url + +def extract_reports_from_final_state(final_state): + analyst_reports = [] + if final_state.get("market_report"): + analyst_reports.append(("Market Analyst", final_state["market_report"])) + if final_state.get("sentiment_report"): + analyst_reports.append(("Sentiment Analyst", final_state["sentiment_report"])) + if final_state.get("news_report"): + analyst_reports.append(("News Analyst", final_state["news_report"])) + if final_state.get("fundamentals_report"): + analyst_reports.append(("Fundamentals Analyst", final_state["fundamentals_report"])) + if final_state.get("investment_debate_state"): + debate_state = final_state["investment_debate_state"] + if debate_state.get("bull_history"): + analyst_reports.append(("Investment Debate - Bull", debate_state["bull_history"])) + if debate_state.get("bear_history"): + analyst_reports.append(("Investment Debate - Bear", debate_state["bear_history"])) + if debate_state.get("judge_decision"): + analyst_reports.append(("Investment Debate - Judge Decision", debate_state["judge_decision"])) + if final_state.get("trader_investment_plan"): + analyst_reports.append(("Trader Investment Plan", final_state["trader_investment_plan"])) + if final_state.get("risk_debate_state"): + risk_state = final_state["risk_debate_state"] + if risk_state.get("risky_history"): + analyst_reports.append(("Risk Debate - Risky", risk_state["risky_history"])) + if risk_state.get("safe_history"): + analyst_reports.append(("Risk Debate - Safe", risk_state["safe_history"])) + if risk_state.get("neutral_history"): + analyst_reports.append(("Risk Debate - Neutral", risk_state["neutral_history"])) + if risk_state.get("judge_decision"): + analyst_reports.append(("Risk Debate - Judge Decision", risk_state["judge_decision"])) + return {report_name: report_content for report_name, report_content in analyst_reports if report_content} + +def save_reports(ticker: str, reports: Dict[str, str], output_dir: str, filename = "") -> None: + """ + Save the generated reports to the specified output directory. + Args: + ticker (str): The ticker symbol for which the reports are generated. + reports (Dict[str, str]): A dictionary where keys are report names and values are report content. + output_dir (str): The directory where the reports will be saved. + filename (str): Optional filename to save the reports as a single file. If empty, the filename will be formatted as `{ticker}_reports_{time}.md`. + """ + import os + from datetime import datetime + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + if filename: + file_path = os.path.join(output_dir, filename) + else: + time_str = datetime.now().strftime("%Y%m%d_%H%M") + file_path = os.path.join(output_dir, f"{ticker}_reports_{time_str}.md") + + with open(file_path, "w", encoding="utf-8") as file: + file.write(f"# Reports for {ticker}\n\n") + file.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M')}\n\n") + for report_name, report_content in reports.items(): + file.write(f"## {report_name}\n\n") + file.write(report_content + "\n\n") diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 7f013c10..25a51594 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -7,6 +7,11 @@ DEFAULT_CONFIG = { os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "dataflows/data_cache", ), + # Output settings + "save_report": True, + "report_dir": os.path.join( + os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "reports" + ), # LLM settings "llm_provider": "qwen", "deep_think_llm": "qwen-plus",