348 lines
17 KiB
Python
348 lines
17 KiB
Python
from fastapi import FastAPI, Request, Form, BackgroundTasks, HTTPException
|
|
from fastapi.responses import HTMLResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
import jinja2
|
|
import os
|
|
from typing import Dict, Any
|
|
import threading
|
|
import time
|
|
from dotenv import load_dotenv
|
|
|
|
# Pull variables from a local .env file (when one exists) into os.environ.
load_dotenv()

# API keys the analysis pipeline needs. The Reddit credentials
# (REDDIT_CLIENT_ID, REDDIT_CLIENT_SECRET, REDDIT_USER_AGENT) are not
# currently required and are therefore left out of this check.
required_env_vars = [
    'FINNHUB_API_KEY',
    'OPENAI_API_KEY',
]

# Only warn at import time (the module keeps loading); the /start handler
# repeats this check before launching a run.
missing_vars = list(filter(lambda name: not os.getenv(name), required_env_vars))
if missing_vars:
    missing_list = ', '.join(missing_vars)
    print(f"Error: Missing required environment variables: {missing_list}")
    print("Please create a .env file with these variables or set them in your environment.")
|
|
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
app = FastAPI()

# In-memory storage for the process state. It is shared between request
# handlers and the background analysis thread, so access it only while
# holding app_state_lock.
app_state_lock = threading.Lock()
app_state: Dict[str, Any] = {
    "process_running": False,
    "company_symbol": None,
    "execution_tree": [],
    "overall_status": "idle",  # idle, in_progress, completed, error
    "overall_progress": 0,  # 0-100
}

# Resolve asset directories relative to this module rather than the current
# working directory, so static files and templates are found the same way
# regardless of where the server process was launched from.
_module_dir = os.path.dirname(os.path.abspath(__file__))

# Mount the static directory to serve CSS, JS, etc.
static_dir = os.path.join(_module_dir, "static")
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Setup Jinja2 for templating.
# NOTE(review): autoescaping is disabled, yet templates render user-supplied
# text (e.g. the company symbol placed in node names by /start). Confirm the
# templates escape it, or enable jinja2.select_autoescape(["html", "xml"])
# once the templates are checked for intentional raw-HTML output.
template_dir = os.path.join(_module_dir, "templates")
jinja_env = jinja2.Environment(loader=jinja2.FileSystemLoader(template_dir))
|
|
|
|
# Display metadata for every LangGraph node the UI tracks, in pipeline order.
_NODE_MAPPING: Dict[str, Dict[str, str]] = {
    "Market Analyst": {"name": "📈 Market Analysis", "phase": "data_collection"},
    "Social Analyst": {"name": "📱 Social Media Analysis", "phase": "data_collection"},
    "News Analyst": {"name": "📰 News Analysis", "phase": "data_collection"},
    "Fundamentals Analyst": {"name": "📊 Fundamental Analysis", "phase": "data_collection"},
    "Bull Researcher": {"name": "🐂 Bull Case Research", "phase": "research"},
    "Bear Researcher": {"name": "🐻 Bear Case Research", "phase": "research"},
    "Research Manager": {"name": "🔍 Research Synthesis", "phase": "research"},
    "Trade Planner": {"name": "📋 Trade Planning", "phase": "planning"},
    "Trader": {"name": "⚡ Trade Execution", "phase": "execution"},
    "Risky Analyst": {"name": "🚨 Risk Assessment (Aggressive)", "phase": "risk_analysis"},
    "Neutral Analyst": {"name": "⚖️ Risk Assessment (Neutral)", "phase": "risk_analysis"},
    "Safe Analyst": {"name": "🛡️ Risk Assessment (Conservative)", "phase": "risk_analysis"},
    "Risk Judge": {"name": "⚠️ Final Risk Evaluation", "phase": "risk_analysis"},
}

# Human-readable labels for the phase nodes that group the steps above.
_PHASE_NAMES: Dict[str, str] = {
    "data_collection": "📊 Data Collection",
    "research": "🔍 Research & Analysis",
    "planning": "📋 Trade Planning",
    "execution": "⚡ Trade Execution",
    "risk_analysis": "⚠️ Risk Management",
}


def _nested(state: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return the sub-dict stored under *key*, treating a missing or None value as empty."""
    return state.get(key) or {}


# For each tracked node, a predicate that is truthy once that node's output
# field is populated in the streamed state. Dict order is the order in which
# newly completed steps are appended to the tree.
# NOTE(review): "Trade Planner" keys off trader_investment_plan and "Trader"
# off investment_plan; confirm these field-to-node pairings against the
# graph's state schema.
_STEP_OUTPUT_CHECKS = {
    "Market Analyst": lambda s: s.get("market_report"),
    "Social Analyst": lambda s: s.get("sentiment_report"),
    "News Analyst": lambda s: s.get("news_report"),
    "Fundamentals Analyst": lambda s: s.get("fundamentals_report"),
    "Bull Researcher": lambda s: _nested(s, "investment_debate_state").get("bull_history"),
    "Bear Researcher": lambda s: _nested(s, "investment_debate_state").get("bear_history"),
    "Research Manager": lambda s: _nested(s, "investment_debate_state").get("judge_decision"),
    "Trade Planner": lambda s: s.get("trader_investment_plan"),
    "Trader": lambda s: s.get("investment_plan"),
    "Risky Analyst": lambda s: _nested(s, "risk_debate_state").get("risky_history"),
    "Neutral Analyst": lambda s: _nested(s, "risk_debate_state").get("neutral_history"),
    "Safe Analyst": lambda s: _nested(s, "risk_debate_state").get("safe_history"),
    "Risk Judge": lambda s: s.get("final_trade_decision"),
}


def _step_id(step_name: str) -> str:
    """Return the execution-tree id for *step_name*, e.g. 'data_collection_market_analyst'."""
    phase_id = _NODE_MAPPING[step_name]["phase"]
    return f"{phase_id}_{step_name.lower().replace(' ', '_')}"


def _record_completed_step(root_node: Dict[str, Any], step_name: str) -> None:
    """Append *step_name* as a completed child of its phase node under *root_node*.

    The phase node is created on first use and is marked completed as soon as
    it contains at least one step.
    """
    node_info = _NODE_MAPPING[step_name]
    phase_id = node_info["phase"]
    phase_label = _PHASE_NAMES.get(phase_id, phase_id)

    phase_category = next(
        (child for child in root_node["children"] if child["id"] == phase_id), None
    )
    if phase_category is None:
        phase_category = {
            "id": phase_id,
            "name": phase_label,
            "status": "in_progress",
            "content": f"Phase: {phase_label}",
            "children": [],
            "timestamp": time.time(),
        }
        root_node["children"].append(phase_category)

    phase_category["children"].append({
        "id": _step_id(step_name),
        "name": node_info["name"],
        "status": "completed",
        "content": f"✅ {node_info['name']} completed successfully",
        "children": [],
        "timestamp": time.time(),
    })
    # Mark phase as completed once it has steps.
    phase_category["status"] = "completed"


def update_execution_state(state: Dict[str, Any]):
    """Callback that mirrors LangGraph progress into ``app_state["execution_tree"]``.

    LangGraph streams the full accumulated state on each call, so completed
    steps are inferred from which output fields are populated. Every newly
    completed step is recorded exactly once, in pipeline order. After that,
    ``overall_progress`` is recomputed from the number of recorded steps.

    Args:
        state: The accumulated graph state for the current run.
    """
    print(f"📡 Callback received state keys: {list(state.keys())}")

    with app_state_lock:
        tree = app_state["execution_tree"]
        # On the first callback, replace the "initialization" placeholder
        # with the real root node.
        if not tree or (len(tree) == 1 and tree[0]["id"] == "initialization"):
            app_state["execution_tree"] = [{
                "id": "root",
                "name": f"Trading Analysis for {app_state['company_symbol']}",
                "status": "in_progress",
                "content": f"Analyzing {app_state['company_symbol']} using multiple trading agents",
                "children": [],
                "timestamp": time.time(),
            }]
        root_node = app_state["execution_tree"][0]

        # Compare against the exact ids _record_completed_step writes, so a
        # step already in the tree is never added a second time.
        recorded_ids = {
            step["id"]
            for phase in root_node["children"]
            for step in phase.get("children", [])
        }
        new_steps = [
            step_name
            for step_name, has_output in _STEP_OUTPUT_CHECKS.items()
            if has_output(state) and _step_id(step_name) not in recorded_ids
        ]

        if not new_steps:
            print("⏳ No new step detected or step already processed")
            return

        for step_name in new_steps:
            print(f"🎯 Processing step: {step_name}")
            _record_completed_step(root_node, step_name)

        total_steps = len(_NODE_MAPPING)
        completed_steps = sum(len(phase["children"]) for phase in root_node["children"])
        app_state["overall_progress"] = min(100, int((completed_steps / max(total_steps, 1)) * 100))

        print(f"📊 Progress updated: {app_state['overall_progress']}% ({completed_steps}/{total_steps} steps)")
|
|
|
|
def _backend_url_for(provider: str) -> str:
    """Map an LLM provider name to the API base URL the graph should call.

    Unrecognised providers fall back to the OpenAI endpoint. For Ollama the
    host comes from the OLLAMA_HOST environment variable (default "localhost").
    """
    if provider == "ollama":
        return f"http://{os.getenv('OLLAMA_HOST', 'localhost')}:11434/v1"
    return {
        "openrouter": "https://openrouter.ai/api/v1",
        "google": "https://generativelanguage.googleapis.com/v1",
        "anthropic": "https://api.anthropic.com/",
    }.get(provider, "https://api.openai.com/v1")


def run_trading_process(company_symbol: str, config: Dict[str, Any]):
    """Execute one TradingAgentsGraph run and publish its outcome to ``app_state``.

    Args:
        company_symbol: Ticker symbol to analyse.
        config: User selections read by key: llm_provider, max_debate_rounds,
            cost_per_trade, quick_think_llm, deep_think_llm, analysis_date.

    Exceptions are not propagated. Any failure is written into the execution
    tree and ``overall_status`` becomes "error". ``process_running`` is
    cleared on every exit path.
    """
    with app_state_lock:
        app_state["overall_status"] = "in_progress"
        app_state["overall_progress"] = 0

    try:
        from tradingagents.default_config import DEFAULT_CONFIG

        provider = config["llm_provider"]

        # Start from the library defaults and overlay the user's selections.
        custom_config = DEFAULT_CONFIG.copy()
        custom_config["llm_provider"] = provider
        custom_config["max_debate_rounds"] = config["max_debate_rounds"]
        custom_config["cost_per_trade"] = config["cost_per_trade"]

        # Google models are stored under "gemini_"-prefixed keys; every other
        # provider uses the plain key names.
        model_key_prefix = "gemini_" if provider == "google" else ""
        custom_config[f"{model_key_prefix}quick_think_llm"] = config["quick_think_llm"]
        custom_config[f"{model_key_prefix}deep_think_llm"] = config["deep_think_llm"]

        custom_config["backend_url"] = _backend_url_for(provider)

        print(f"🚀 Initializing TradingAgentsGraph for {company_symbol}")
        graph = TradingAgentsGraph(config=custom_config)
        analysis_date = config["analysis_date"]  # user-selected trade date
        print(f"🔄 Starting propagation for {company_symbol} on {analysis_date}")

        # Intermediate graph states are handed to update_execution_state.
        final_state, processed_signal = graph.propagate(
            company_symbol,
            trade_date=analysis_date,
            on_step_callback=update_execution_state,
        )
        print(f"✅ Propagation completed for {company_symbol}")

        with app_state_lock:
            app_state["overall_status"] = "completed"
            app_state["overall_progress"] = 100
            tree = app_state["execution_tree"]
            if tree:
                summary_node = tree[0]
                summary_node["status"] = "completed"
                summary_node["content"] = f"✅ Analysis completed successfully!\n\nFinal Decision: {processed_signal}\n\nFull State: {str(final_state)}"

    except Exception as exc:
        import traceback
        stack_text = traceback.format_exc()
        failure_text = f"Error during execution: {str(exc)}\n\n{stack_text}"
        with app_state_lock:
            app_state["overall_status"] = "error"
            app_state["overall_progress"] = 100
            tree = app_state["execution_tree"]
            if tree:
                tree[0]["status"] = "error"
                tree[0]["content"] = failure_text
            # Also surface the failure as its own top-level tree item.
            tree.append({
                "id": "error",
                "name": "Process Error",
                "status": "error",
                "content": failure_text,
                "children": [],
                "timestamp": time.time(),
            })

    finally:
        with app_state_lock:
            app_state["process_running"] = False
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Render the full index page from the shared application state.

    The lock is held while rendering so the background worker cannot mutate
    the execution tree in the middle of the render, as the other
    state-reading handlers already do.
    """
    template = jinja_env.get_template("index.html")
    with app_state_lock:
        return template.render(app_state=app_state)
|
|
|
|
@app.post("/start", response_class=HTMLResponse)
async def start_process(
    background_tasks: BackgroundTasks,
    company_symbol: str = Form(...),
    llm_provider: str = Form(...),
    quick_think_llm: str = Form(...),
    deep_think_llm: str = Form(...),
    max_debate_rounds: int = Form(...),
    cost_per_trade: float = Form(...),
    analysis_date: str = Form(...)
):
    """Start a new trading analysis and return the refreshed left panel.

    The whole check-and-start sequence runs under ``app_state_lock`` so it
    cannot interleave with the background worker or another /start request:

    * if a run is already in progress, the current tree is returned unchanged;
    * if required environment variables are missing, a configuration-error
      node replaces the tree and no run is started;
    * otherwise the run configuration is stored, an "initialization"
      placeholder node is shown, and ``run_trading_process`` is scheduled as
      a background task.
    """
    left_panel = jinja_env.get_template("_partials/left_panel.html")

    with app_state_lock:
        # Never overwrite the tree of a run that is still executing.
        if app_state["process_running"]:
            return left_panel.render(tree=app_state["execution_tree"], app_state=app_state)

        # Re-check the required environment variables at request time.
        missing_vars = [var for var in required_env_vars if not os.getenv(var)]
        if missing_vars:
            app_state["overall_status"] = "error"
            app_state["execution_tree"] = [{
                "id": "error",
                "name": "Configuration Error",
                "status": "error",
                "content": f"Missing required environment variables: {', '.join(missing_vars)}. Please check .env.example file.",
                "children": [],
                "timestamp": time.time(),
            }]
            return left_panel.render(tree=app_state["execution_tree"], app_state=app_state)

        app_state["process_running"] = True
        app_state["company_symbol"] = company_symbol
        app_state["overall_status"] = "in_progress"
        app_state["overall_progress"] = 5  # Show initial progress

        # Store all configuration parameters for the background run.
        app_state["config"] = {
            "llm_provider": llm_provider,
            "quick_think_llm": quick_think_llm,
            "deep_think_llm": deep_think_llm,
            "max_debate_rounds": max_debate_rounds,
            "cost_per_trade": cost_per_trade,
            "analysis_date": analysis_date,
        }

        # Placeholder node shown until the first graph callback replaces it.
        app_state["execution_tree"] = [{
            "id": "initialization",
            "name": f"🚀 Initializing Trading Analysis for {company_symbol}",
            "status": "in_progress",
            "content": f"Starting comprehensive trading analysis for {company_symbol}...\n\nConfiguration:\n• LLM Provider: {llm_provider}\n• Quick Think Model: {quick_think_llm}\n• Deep Think Model: {deep_think_llm}\n• Max Debate Rounds: {max_debate_rounds}\n• Cost Per Trade: ${cost_per_trade}\n• Analysis Date: {analysis_date}\n\nInitializing trading agents and preparing analysis pipeline...",
            "children": [],
            "timestamp": time.time(),
        }]

        background_tasks.add_task(run_trading_process, company_symbol, app_state["config"])

        return left_panel.render(tree=app_state["execution_tree"], app_state=app_state)
|
|
|
|
@app.get("/status", response_class=HTMLResponse)
async def get_status():
    """Return the left-panel partial rendered from the current execution tree.

    Rendering happens while holding ``app_state_lock`` so the tree cannot be
    mutated by the background worker mid-render.
    """
    panel_template = jinja_env.get_template("_partials/left_panel.html")
    with app_state_lock:
        current_tree = app_state["execution_tree"]
        return panel_template.render(tree=current_tree, app_state=app_state)
|
|
|
|
def find_item_in_tree(item_id: str, tree: list) -> Dict[str, Any] | None:
|
|
"""Recursively searches the execution tree for an item by its ID."""
|
|
for item in tree:
|
|
if item["id"] == item_id:
|
|
return item
|
|
if item["children"]:
|
|
found_child = find_item_in_tree(item_id, item["children"])
|
|
if found_child:
|
|
return found_child
|
|
return None
|
|
|
|
@app.get("/content/{item_id}", response_class=HTMLResponse)
async def get_item_content(item_id: str):
    """Render the right-panel partial for one execution-tree item.

    Responds with a 404 HTML snippet when no item (or an empty item) matches
    *item_id*. The lookup and render both happen under ``app_state_lock``.
    """
    with app_state_lock:
        match = find_item_in_tree(item_id, app_state["execution_tree"])
        if not match:
            return HTMLResponse(content="<p>Item not found.</p>", status_code=404)
        detail_template = jinja_env.get_template("_partials/right_panel.html")
        return detail_template.render(content=match.get("content", "No content available."))
|
|
|
|
# To run this app:
|
|
# uvicorn webapp.main:app --reload
|