Merge c510a8721a into 13b826a31d
This commit is contained in:
commit
42d8ce3f9f
45
cli/main.py
45
cli/main.py
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import datetime
|
import datetime
|
||||||
|
import os
|
||||||
import typer
|
import typer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
@ -398,7 +399,7 @@ def update_display(layout, spinner_text=None):
|
||||||
def get_user_selections():
|
def get_user_selections():
|
||||||
"""Get all user selections before starting the analysis display."""
|
"""Get all user selections before starting the analysis display."""
|
||||||
# Display ASCII art welcome message
|
# Display ASCII art welcome message
|
||||||
with open("./cli/static/welcome.txt", "r") as f:
|
with open("./cli/static/welcome.txt", "r", encoding="utf-8") as f:
|
||||||
welcome_ascii = f.read()
|
welcome_ascii = f.read()
|
||||||
|
|
||||||
# Create welcome box content
|
# Create welcome box content
|
||||||
|
|
@ -748,11 +749,6 @@ def run_analysis():
|
||||||
config["backend_url"] = selections["backend_url"]
|
config["backend_url"] = selections["backend_url"]
|
||||||
config["llm_provider"] = selections["llm_provider"].lower()
|
config["llm_provider"] = selections["llm_provider"].lower()
|
||||||
|
|
||||||
# Initialize the graph
|
|
||||||
graph = TradingAgentsGraph(
|
|
||||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create result directory
|
# Create result directory
|
||||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -761,6 +757,18 @@ def run_analysis():
|
||||||
log_file = results_dir / "message_tool.log"
|
log_file = results_dir / "message_tool.log"
|
||||||
log_file.touch(exist_ok=True)
|
log_file.touch(exist_ok=True)
|
||||||
|
|
||||||
|
# ACE skillbook
|
||||||
|
ace_skillbook_path = os.path.join(config.get("results_dir", "./results"), "ace_skillbook.json")
|
||||||
|
|
||||||
|
# Initialize the graph with ACE enabled
|
||||||
|
graph = TradingAgentsGraph(
|
||||||
|
[analyst.value for analyst in selections["analysts"]],
|
||||||
|
config=config,
|
||||||
|
debug=True,
|
||||||
|
ace_enabled=True,
|
||||||
|
ace_skillbook_path=ace_skillbook_path,
|
||||||
|
)
|
||||||
|
|
||||||
def save_message_decorator(obj, func_name):
|
def save_message_decorator(obj, func_name):
|
||||||
func = getattr(obj, func_name)
|
func = getattr(obj, func_name)
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
|
|
@ -768,7 +776,7 @@ def run_analysis():
|
||||||
func(*args, **kwargs)
|
func(*args, **kwargs)
|
||||||
timestamp, message_type, content = obj.messages[-1]
|
timestamp, message_type, content = obj.messages[-1]
|
||||||
content = content.replace("\n", " ") # Replace newlines with spaces
|
content = content.replace("\n", " ") # Replace newlines with spaces
|
||||||
with open(log_file, "a") as f:
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
f.write(f"{timestamp} [{message_type}] {content}\n")
|
f.write(f"{timestamp} [{message_type}] {content}\n")
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
@ -779,7 +787,7 @@ def run_analysis():
|
||||||
func(*args, **kwargs)
|
func(*args, **kwargs)
|
||||||
timestamp, tool_name, args = obj.tool_calls[-1]
|
timestamp, tool_name, args = obj.tool_calls[-1]
|
||||||
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
|
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
|
||||||
with open(log_file, "a") as f:
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
@ -792,7 +800,7 @@ def run_analysis():
|
||||||
content = obj.report_sections[section_name]
|
content = obj.report_sections[section_name]
|
||||||
if content:
|
if content:
|
||||||
file_name = f"{section_name}.md"
|
file_name = f"{section_name}.md"
|
||||||
with open(report_dir / file_name, "w") as f:
|
with open(report_dir / file_name, "w", encoding="utf-8") as f:
|
||||||
f.write(content)
|
f.write(content)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
@ -1079,6 +1087,17 @@ def run_analysis():
|
||||||
|
|
||||||
# Get final state and decision
|
# Get final state and decision
|
||||||
final_state = trace[-1]
|
final_state = trace[-1]
|
||||||
|
graph.curr_state = final_state # Ensure curr_state is set for ACE
|
||||||
|
|
||||||
|
# Trigger ACE learning from analysis (price-based)
|
||||||
|
if graph.ace_enabled:
|
||||||
|
message_buffer.add_message("ACE", "Learning from analysis...")
|
||||||
|
try:
|
||||||
|
graph._ace_learn_from_analysis()
|
||||||
|
message_buffer.add_message("ACE", "Learning cycle completed.")
|
||||||
|
except Exception as e:
|
||||||
|
message_buffer.add_message("ACE", f"Learning failed: {e}")
|
||||||
|
|
||||||
decision = graph.process_signal(final_state["final_trade_decision"])
|
decision = graph.process_signal(final_state["final_trade_decision"])
|
||||||
|
|
||||||
# Update all agent statuses to completed
|
# Update all agent statuses to completed
|
||||||
|
|
@ -1089,6 +1108,14 @@ def run_analysis():
|
||||||
"Analysis", f"Completed analysis for {selections['analysis_date']}"
|
"Analysis", f"Completed analysis for {selections['analysis_date']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Save ACE skillbook
|
||||||
|
if graph.ace_engine:
|
||||||
|
try:
|
||||||
|
saved_path = graph.save_ace_skillbook()
|
||||||
|
message_buffer.add_message("ACE", f"Skillbook saved to {saved_path}")
|
||||||
|
except Exception as e:
|
||||||
|
message_buffer.add_message("ACE", f"Failed to save skillbook: {e}")
|
||||||
|
|
||||||
# Update final report sections
|
# Update final report sections
|
||||||
for section in message_buffer.report_sections.keys():
|
for section in message_buffer.report_sections.keys():
|
||||||
if section in final_state:
|
if section in final_state:
|
||||||
|
|
|
||||||
|
|
@ -24,3 +24,5 @@ rich
|
||||||
questionary
|
questionary
|
||||||
langchain_anthropic
|
langchain_anthropic
|
||||||
langchain-google-genai
|
langchain-google-genai
|
||||||
|
python-dotenv
|
||||||
|
ace-framework
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""
|
||||||
|
Agentic Context Engineering (ACE) implementation for TradingAgents.
|
||||||
|
|
||||||
|
Uses the official Kayba ACE framework (pip install ace-framework).
|
||||||
|
Based on the ACE paper (arXiv:2510.04618) - enables agents to improve through
|
||||||
|
in-context learning instead of fine-tuning.
|
||||||
|
|
||||||
|
Core pattern:
|
||||||
|
1. INJECT: Add learned strategies to agent prompts
|
||||||
|
2. EXECUTE: Agent performs task using accumulated knowledge
|
||||||
|
3. LEARN: Reflector analyzes results, SkillManager updates skillbook
|
||||||
|
"""
|
||||||
|
|
||||||
|
from ace import (
|
||||||
|
ACELiteLLM,
|
||||||
|
Skillbook,
|
||||||
|
Skill,
|
||||||
|
Reflector,
|
||||||
|
SkillManager,
|
||||||
|
UpdateOperation,
|
||||||
|
UpdateBatch,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .kayba_ace import TradingACE, create_trading_ace
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ACELiteLLM",
|
||||||
|
"Skillbook",
|
||||||
|
"Skill",
|
||||||
|
"Reflector",
|
||||||
|
"SkillManager",
|
||||||
|
"UpdateOperation",
|
||||||
|
"UpdateBatch",
|
||||||
|
"TradingACE",
|
||||||
|
"create_trading_ace",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,223 @@
|
||||||
|
"""
|
||||||
|
Kayba ACE Integration for TradingAgents.
|
||||||
|
|
||||||
|
Uses the official ace-framework from Kayba (pip install ace-framework)
|
||||||
|
for self-improving trading agents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from ace import (
|
||||||
|
OnlineACE,
|
||||||
|
Agent,
|
||||||
|
Reflector,
|
||||||
|
SkillManager,
|
||||||
|
LiteLLMClient,
|
||||||
|
Sample,
|
||||||
|
TaskEnvironment,
|
||||||
|
EnvironmentResult,
|
||||||
|
Skillbook,
|
||||||
|
Skill,
|
||||||
|
AgentOutput
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TradingEnvironment(TaskEnvironment):
|
||||||
|
"""
|
||||||
|
Environment for evaluating the quality and consistency of trading analysis.
|
||||||
|
|
||||||
|
Instead of just looking at price, it evaluates the logical flow between
|
||||||
|
market data, sentiment, news, and the final decision.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def evaluate(self, sample: Sample, agent_output) -> EnvironmentResult:
|
||||||
|
"""
|
||||||
|
Evaluate the analytical rigor of the trading decision.
|
||||||
|
"""
|
||||||
|
# We provide a high-level goal for the reflector
|
||||||
|
feedback = (
|
||||||
|
"Evaluate the logical consistency of this analysis. "
|
||||||
|
"Check if the final decision is truly supported by the market report, "
|
||||||
|
"sentiment analysis, and news. Identify any contradictions or missed "
|
||||||
|
"signals that could lead to a sub-optimal trade, regardless of the price outcome."
|
||||||
|
)
|
||||||
|
|
||||||
|
return EnvironmentResult(
|
||||||
|
feedback=feedback,
|
||||||
|
ground_truth="A logically sound, consistent, and well-supported investment thesis."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TradingACE:
|
||||||
|
"""
|
||||||
|
Self-improving trading agent using Kayba's ACE framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "gpt-4o-mini",
|
||||||
|
skillbook_path: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize TradingACE.
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.skillbook_path = skillbook_path
|
||||||
|
|
||||||
|
# Initialize LLM client
|
||||||
|
self.client = LiteLLMClient(model=model, api_base=api_base)
|
||||||
|
|
||||||
|
# Create ACE components
|
||||||
|
self.agent = Agent(self.client)
|
||||||
|
self.reflector = Reflector(self.client)
|
||||||
|
self.skill_manager = SkillManager(self.client)
|
||||||
|
|
||||||
|
# Load or create skillbook
|
||||||
|
if skillbook_path and Path(skillbook_path).exists():
|
||||||
|
self.skillbook = Skillbook.load_from_file(skillbook_path)
|
||||||
|
else:
|
||||||
|
self.skillbook = Skillbook()
|
||||||
|
|
||||||
|
# Create OnlineACE adapter
|
||||||
|
self.ace = OnlineACE(
|
||||||
|
skillbook=self.skillbook,
|
||||||
|
agent=self.agent,
|
||||||
|
reflector=self.reflector,
|
||||||
|
skill_manager=self.skill_manager
|
||||||
|
)
|
||||||
|
|
||||||
|
self.environment = TradingEnvironment()
|
||||||
|
|
||||||
|
def learn_from_analysis(self, reports: Dict[str, str], decision: str):
|
||||||
|
"""
|
||||||
|
Learn from a trading analysis by reflecting on the consistency of all reports.
|
||||||
|
"""
|
||||||
|
ticker = reports.get("ticker", "Unknown")
|
||||||
|
date = reports.get("date", "Unknown")
|
||||||
|
|
||||||
|
print(f"ACE: Learning from analytical consistency for {ticker} on {date}")
|
||||||
|
|
||||||
|
# Combine all reports into a single context for the reflector
|
||||||
|
full_context = "\n\n".join([
|
||||||
|
f"MARKET REPORT:\n{reports.get('market', 'N/A')}",
|
||||||
|
f"SENTIMENT REPORT:\n{reports.get('sentiment', 'N/A')}",
|
||||||
|
f"NEWS REPORT:\n{reports.get('news', 'N/A')}",
|
||||||
|
f"FUNDAMENTALS REPORT:\n{reports.get('fundamentals', 'N/A')}",
|
||||||
|
f"INVESTMENT PLAN:\n{reports.get('plan', 'N/A')}"
|
||||||
|
])
|
||||||
|
|
||||||
|
sample = Sample(
|
||||||
|
question=(
|
||||||
|
f"Analyze the following multi-agent trading reports for {ticker} and "
|
||||||
|
"determine if the final decision is logically consistent with all data points."
|
||||||
|
),
|
||||||
|
context=full_context,
|
||||||
|
ground_truth="A perfectly consistent and well-reasoned investment thesis."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a proper AgentOutput instance representing the whole analysis
|
||||||
|
actual_output = AgentOutput(
|
||||||
|
reasoning=f"Full analysis process for {ticker}.",
|
||||||
|
final_answer=decision,
|
||||||
|
raw={"decision": decision}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. Evaluate the analytical rigor
|
||||||
|
eval_result = self.environment.evaluate(sample, actual_output)
|
||||||
|
print(f"ACE: Evaluation focus: {eval_result.feedback}")
|
||||||
|
|
||||||
|
# 2. Reflect on the consistency and quality
|
||||||
|
print("ACE: Reflecting on analytical quality...")
|
||||||
|
reflector_output = self.reflector.reflect(
|
||||||
|
question=sample.question,
|
||||||
|
agent_output=actual_output,
|
||||||
|
skillbook=self.skillbook,
|
||||||
|
ground_truth=eval_result.ground_truth,
|
||||||
|
feedback=eval_result.feedback
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"ACE: Reflection generated (reasoning length: {len(reflector_output.reasoning)})")
|
||||||
|
|
||||||
|
# 3. Update the skillbook with the new reflection
|
||||||
|
print("ACE: Updating skillbook...")
|
||||||
|
sm_output = self.skill_manager.update_skills(
|
||||||
|
skillbook=self.skillbook,
|
||||||
|
reflection=reflector_output,
|
||||||
|
question_context=sample.context,
|
||||||
|
progress=eval_result.feedback
|
||||||
|
)
|
||||||
|
|
||||||
|
if sm_output.update:
|
||||||
|
print("ACE: Applying update to skillbook...")
|
||||||
|
self.skillbook.apply_update(sm_output.update)
|
||||||
|
|
||||||
|
# Force save
|
||||||
|
self.save_skillbook()
|
||||||
|
print(f"ACE: Skillbook saved to {self.skillbook_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in ACE learning: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
def learn_from_trade(self, context: str, decision: str, result: str, market_data: str):
|
||||||
|
"""
|
||||||
|
Compatibility method for TradingAgentsGraph.
|
||||||
|
"""
|
||||||
|
reports = {
|
||||||
|
"market": market_data,
|
||||||
|
"ticker": context.split(" on ")[0],
|
||||||
|
"date": context.split(" on ")[1] if " on " in context else "Unknown"
|
||||||
|
}
|
||||||
|
self.learn_from_analysis(reports, decision)
|
||||||
|
|
||||||
|
def get_skills_context(self) -> str:
|
||||||
|
"""Get formatted skills for injection into prompts."""
|
||||||
|
if not self.skillbook.skills():
|
||||||
|
return ""
|
||||||
|
return self.skillbook.as_prompt()
|
||||||
|
|
||||||
|
def save_skillbook(self, path: Optional[str] = None) -> str:
|
||||||
|
"""Save the skillbook to file."""
|
||||||
|
save_path = path or self.skillbook_path or "ace_skillbook.json"
|
||||||
|
self.skillbook.save_to_file(save_path)
|
||||||
|
return save_path
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get ACE statistics."""
|
||||||
|
try:
|
||||||
|
skills = self.skillbook.skills()
|
||||||
|
count = len(skills)
|
||||||
|
except:
|
||||||
|
try:
|
||||||
|
count = self.skillbook.stats().get('skills', 0)
|
||||||
|
except:
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"skills_count": count,
|
||||||
|
"model": self.model
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_trading_ace(
|
||||||
|
config: Dict[str, Any],
|
||||||
|
skillbook_path: Optional[str] = None,
|
||||||
|
) -> TradingACE:
|
||||||
|
"""
|
||||||
|
Factory function to create TradingACE.
|
||||||
|
"""
|
||||||
|
model = config.get("quick_think_llm", "gpt-4o-mini")
|
||||||
|
api_base = config.get("backend_url")
|
||||||
|
|
||||||
|
return TradingACE(
|
||||||
|
model=model,
|
||||||
|
skillbook_path=skillbook_path,
|
||||||
|
api_base=api_base
|
||||||
|
)
|
||||||
|
|
@ -3,7 +3,7 @@ import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_trader(llm, memory):
|
def create_trader(llm, memory, ace_context=""):
|
||||||
def trader_node(state, name):
|
def trader_node(state, name):
|
||||||
company_name = state["company_of_interest"]
|
company_name = state["company_of_interest"]
|
||||||
investment_plan = state["investment_plan"]
|
investment_plan = state["investment_plan"]
|
||||||
|
|
@ -22,9 +22,14 @@ def create_trader(llm, memory):
|
||||||
else:
|
else:
|
||||||
past_memory_str = "No past memories found."
|
past_memory_str = "No past memories found."
|
||||||
|
|
||||||
|
# Add ACE context if available
|
||||||
|
ace_str = ""
|
||||||
|
if ace_context:
|
||||||
|
ace_str = f"\n\nLearned Trading Strategies (ACE):\n{ace_context}"
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.{ace_str}",
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,9 @@ DEFAULT_CONFIG = {
|
||||||
"max_debate_rounds": 1,
|
"max_debate_rounds": 1,
|
||||||
"max_risk_discuss_rounds": 1,
|
"max_risk_discuss_rounds": 1,
|
||||||
"max_recur_limit": 100,
|
"max_recur_limit": 100,
|
||||||
|
# ACE (Agentic Context Engineering) settings
|
||||||
|
"ace_enabled": os.getenv("ACE_ENABLED", "true").lower() in {"true", "1", "yes", "on"},
|
||||||
|
"ace_skillbook_path": os.getenv("ACE_SKILLBOOK_PATH", "results/ace_skillbook.json"),
|
||||||
# Data vendor configuration
|
# Data vendor configuration
|
||||||
# Category-level configuration (default for all tools in category)
|
# Category-level configuration (default for all tools in category)
|
||||||
"data_vendors": {
|
"data_vendors": {
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ class GraphSetup:
|
||||||
invest_judge_memory,
|
invest_judge_memory,
|
||||||
risk_manager_memory,
|
risk_manager_memory,
|
||||||
conditional_logic: ConditionalLogic,
|
conditional_logic: ConditionalLogic,
|
||||||
|
ace_context: str = "",
|
||||||
):
|
):
|
||||||
"""Initialize with required components."""
|
"""Initialize with required components."""
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
|
|
@ -36,6 +37,7 @@ class GraphSetup:
|
||||||
self.invest_judge_memory = invest_judge_memory
|
self.invest_judge_memory = invest_judge_memory
|
||||||
self.risk_manager_memory = risk_manager_memory
|
self.risk_manager_memory = risk_manager_memory
|
||||||
self.conditional_logic = conditional_logic
|
self.conditional_logic = conditional_logic
|
||||||
|
self.ace_context = ace_context
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
||||||
|
|
@ -95,7 +97,7 @@ class GraphSetup:
|
||||||
research_manager_node = create_research_manager(
|
research_manager_node = create_research_manager(
|
||||||
self.deep_thinking_llm, self.invest_judge_memory
|
self.deep_thinking_llm, self.invest_judge_memory
|
||||||
)
|
)
|
||||||
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory, self.ace_context)
|
||||||
|
|
||||||
# Create risk analysis nodes
|
# Create risk analysis nodes
|
||||||
risky_analyst = create_risky_debator(self.quick_thinking_llm)
|
risky_analyst = create_risky_debator(self.quick_thinking_llm)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# TradingAgents/graph/trading_graph.py
|
# TradingAgents/graph/trading_graph.py
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
from datetime import date
|
from datetime import date
|
||||||
|
|
@ -21,6 +22,7 @@ from tradingagents.agents.utils.agent_states import (
|
||||||
RiskDebateState,
|
RiskDebateState,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import set_config
|
from tradingagents.dataflows.config import set_config
|
||||||
|
from tradingagents.ace import TradingACE, create_trading_ace
|
||||||
|
|
||||||
# Import the new abstract tool methods from agent_utils
|
# Import the new abstract tool methods from agent_utils
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
|
@ -42,6 +44,8 @@ from .propagation import Propagator
|
||||||
from .reflection import Reflector
|
from .reflection import Reflector
|
||||||
from .signal_processing import SignalProcessor
|
from .signal_processing import SignalProcessor
|
||||||
|
|
||||||
|
from tradingagents.ace import TradingACE, create_trading_ace
|
||||||
|
|
||||||
|
|
||||||
class TradingAgentsGraph:
|
class TradingAgentsGraph:
|
||||||
"""Main class that orchestrates the trading agents framework."""
|
"""Main class that orchestrates the trading agents framework."""
|
||||||
|
|
@ -51,6 +55,8 @@ class TradingAgentsGraph:
|
||||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||||
debug=False,
|
debug=False,
|
||||||
config: Dict[str, Any] = None,
|
config: Dict[str, Any] = None,
|
||||||
|
ace_enabled: bool = True,
|
||||||
|
ace_skillbook_path: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the trading agents graph and components.
|
"""Initialize the trading agents graph and components.
|
||||||
|
|
||||||
|
|
@ -58,6 +64,8 @@ class TradingAgentsGraph:
|
||||||
selected_analysts: List of analyst types to include
|
selected_analysts: List of analyst types to include
|
||||||
debug: Whether to run in debug mode
|
debug: Whether to run in debug mode
|
||||||
config: Configuration dictionary. If None, uses default config
|
config: Configuration dictionary. If None, uses default config
|
||||||
|
ace_enabled: Whether to enable ACE learning (default: True)
|
||||||
|
ace_skillbook_path: Path to load/save ACE skillbook (optional)
|
||||||
"""
|
"""
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.config = config or DEFAULT_CONFIG
|
self.config = config or DEFAULT_CONFIG
|
||||||
|
|
@ -91,6 +99,29 @@ class TradingAgentsGraph:
|
||||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
||||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
||||||
|
|
||||||
|
# Initialize ACE Engine
|
||||||
|
self.ace_enabled = ace_enabled
|
||||||
|
self.ace_skillbook_path = ace_skillbook_path or self.config.get(
|
||||||
|
"ace_skillbook_path",
|
||||||
|
os.path.join(self.config.get("results_dir", "./results"), "ace_skillbook.json")
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.ace_enabled:
|
||||||
|
self.ace_engine = create_trading_ace(
|
||||||
|
config=self.config,
|
||||||
|
skillbook_path=self.ace_skillbook_path,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.ace_engine = None
|
||||||
|
|
||||||
|
if self.ace_enabled:
|
||||||
|
self.ace_engine = create_trading_ace(
|
||||||
|
config=self.config,
|
||||||
|
skillbook_path=self.ace_skillbook_path,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.ace_engine = None
|
||||||
|
|
||||||
# Create tool nodes
|
# Create tool nodes
|
||||||
self.tool_nodes = self._create_tool_nodes()
|
self.tool_nodes = self._create_tool_nodes()
|
||||||
|
|
||||||
|
|
@ -106,6 +137,7 @@ class TradingAgentsGraph:
|
||||||
self.invest_judge_memory,
|
self.invest_judge_memory,
|
||||||
self.risk_manager_memory,
|
self.risk_manager_memory,
|
||||||
self.conditional_logic,
|
self.conditional_logic,
|
||||||
|
ace_context=self.get_ace_context(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.propagator = Propagator()
|
self.propagator = Propagator()
|
||||||
|
|
@ -231,6 +263,7 @@ class TradingAgentsGraph:
|
||||||
with open(
|
with open(
|
||||||
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
|
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
|
||||||
"w",
|
"w",
|
||||||
|
encoding="utf-8"
|
||||||
) as f:
|
) as f:
|
||||||
json.dump(self.log_states_dict, f, indent=4)
|
json.dump(self.log_states_dict, f, indent=4)
|
||||||
|
|
||||||
|
|
@ -252,6 +285,125 @@ class TradingAgentsGraph:
|
||||||
self.curr_state, returns_losses, self.risk_manager_memory
|
self.curr_state, returns_losses, self.risk_manager_memory
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ACE Learning: Learn from trading execution
|
||||||
|
if self.ace_enabled and self.ace_engine and self.curr_state:
|
||||||
|
self._ace_learn(returns_losses)
|
||||||
|
|
||||||
|
def _ace_learn_from_analysis(self):
|
||||||
|
"""
|
||||||
|
Trigger ACE learning based on the analytical consistency of all reports.
|
||||||
|
"""
|
||||||
|
if not self.ace_enabled or not self.ace_engine or not self.curr_state:
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"DEBUG: ACE analytical reflection triggered for {self.curr_state.get('company_of_interest')}")
|
||||||
|
|
||||||
|
reports = {
|
||||||
|
"ticker": self.curr_state.get("company_of_interest", "Unknown"),
|
||||||
|
"date": str(self.curr_state.get("trade_date", "Unknown")),
|
||||||
|
"market": self.curr_state.get("market_report", ""),
|
||||||
|
"sentiment": self.curr_state.get("sentiment_report", ""),
|
||||||
|
"news": self.curr_state.get("news_report", ""),
|
||||||
|
"fundamentals": self.curr_state.get("fundamentals_report", ""),
|
||||||
|
"plan": self.curr_state.get("investment_plan", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
decision = self.curr_state.get("final_trade_decision", "")
|
||||||
|
|
||||||
|
decision = re.sub(r"\[ACE_METADATA: .*\]", "", decision).strip()
|
||||||
|
|
||||||
|
self.ace_engine.learn_from_analysis(
|
||||||
|
reports=reports,
|
||||||
|
decision=decision
|
||||||
|
)
|
||||||
|
|
||||||
|
def _ace_learn(self, returns_losses):
|
||||||
|
"""Apply ACE learning from the current trading state using Kayba ACE."""
|
||||||
|
if not self.curr_state:
|
||||||
|
return
|
||||||
|
|
||||||
|
context = f"{self.curr_state['company_of_interest']} on {self.curr_state['trade_date']}"
|
||||||
|
|
||||||
|
market_data = "\n\n".join([
|
||||||
|
f"Market Report:\n{self.curr_state.get('market_report', '')}",
|
||||||
|
f"Sentiment Report:\n{self.curr_state.get('sentiment_report', '')}",
|
||||||
|
f"News Report:\n{self.curr_state.get('news_report', '')}",
|
||||||
|
f"Fundamentals Report:\n{self.curr_state.get('fundamentals_report', '')}",
|
||||||
|
])
|
||||||
|
|
||||||
|
decision = self.curr_state.get("final_trade_decision", "")
|
||||||
|
|
||||||
|
self.ace_engine.learn_from_trade(
|
||||||
|
context=context,
|
||||||
|
decision=decision,
|
||||||
|
result=str(returns_losses),
|
||||||
|
market_data=market_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_ace_context(self) -> str:
|
||||||
|
"""Get ACE strategies context for injection into agent prompts."""
|
||||||
|
if self.ace_enabled and self.ace_engine:
|
||||||
|
return self.ace_engine.get_skills_context()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def save_ace_skillbook(self, path: Optional[str] = None) -> str:
|
||||||
|
"""Save the ACE skillbook to a file."""
|
||||||
|
if self.ace_engine:
|
||||||
|
return self.ace_engine.save_skillbook(path)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_ace_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get ACE learning statistics."""
|
||||||
|
if self.ace_engine:
|
||||||
|
return self.ace_engine.get_stats()
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _ace_learn_from_analysis(self):
|
||||||
|
"""
|
||||||
|
Trigger ACE learning based on the analytical consistency of all reports.
|
||||||
|
"""
|
||||||
|
if not self.ace_enabled or not self.ace_engine or not self.curr_state:
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"DEBUG: ACE analytical reflection triggered for {self.curr_state.get('company_of_interest')}")
|
||||||
|
|
||||||
|
reports = {
|
||||||
|
"ticker": self.curr_state.get("company_of_interest", "Unknown"),
|
||||||
|
"date": str(self.curr_state.get("trade_date", "Unknown")),
|
||||||
|
"market": self.curr_state.get("market_report", ""),
|
||||||
|
"sentiment": self.curr_state.get("sentiment_report", ""),
|
||||||
|
"news": self.curr_state.get("news_report", ""),
|
||||||
|
"fundamentals": self.curr_state.get("fundamentals_report", ""),
|
||||||
|
"plan": self.curr_state.get("investment_plan", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
decision = self.curr_state.get("final_trade_decision", "")
|
||||||
|
# Clean metadata tag if present
|
||||||
|
decision = re.sub(r"\[ACE_METADATA: .*\]", "", decision).strip()
|
||||||
|
|
||||||
|
self.ace_engine.learn_from_analysis(
|
||||||
|
reports=reports,
|
||||||
|
decision=decision
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_ace_context(self) -> str:
|
||||||
|
"""Get ACE strategies context for injection into agent prompts."""
|
||||||
|
if self.ace_enabled and self.ace_engine:
|
||||||
|
return self.ace_engine.get_skills_context()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def save_ace_skillbook(self, path: Optional[str] = None) -> str:
|
||||||
|
"""Save the ACE skillbook to a file."""
|
||||||
|
if self.ace_engine:
|
||||||
|
return self.ace_engine.save_skillbook(path)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_ace_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get ACE learning statistics."""
|
||||||
|
if self.ace_engine:
|
||||||
|
return self.ace_engine.get_stats()
|
||||||
|
return {}
|
||||||
|
|
||||||
def process_signal(self, full_signal):
|
def process_signal(self, full_signal):
|
||||||
"""Process a signal to extract the core decision."""
|
"""Process a signal to extract the core decision."""
|
||||||
return self.signal_processor.process_signal(full_signal)
|
return self.signal_processor.process_signal(full_signal)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue