checkpoint before checking out cursor/fix-and-refactor-news-analyst-project-5e09

This commit is contained in:
Jiahao Zhang 2025-07-06 11:04:02 -07:00
parent f5e641fd1f
commit 3c07c782ad
35 changed files with 3983 additions and 691 deletions

1
.gitignore vendored
View File

@ -224,3 +224,4 @@ dmypy.json
# Pyre
.pyre/
*.xcuserstate

View File

@ -1,9 +1,9 @@
import Foundation
enum AppConfig {
// API Configuration
// MARK: - API Configuration
static let apiBaseURL: String = {
// Check for environment variable first (useful for CI/CD)
// Check for environment variable first (useful for CI/CD and custom URLs)
if let envURL = ProcessInfo.processInfo.environment["TRADINGAGENTS_API_URL"] {
return envURL
}
@ -11,11 +11,11 @@ enum AppConfig {
// Default URLs for different environments
#if DEBUG
#if targetEnvironment(simulator)
// For iOS Simulator
// For iOS Simulator - connect to localhost
return "http://localhost:8000"
#else
// For real device - UPDATE THIS WITH YOUR MAC'S IP
return "http://192.168.4.223:8000"
// For real device - connect to Mac's IP address
return "http://10.73.204.80:8000"
#endif
#else
// For production, update this to your deployed server URL
@ -23,11 +23,43 @@ enum AppConfig {
#endif
}()
// Network Configuration
static let requestTimeout: TimeInterval = 600.0
// MARK: - Network Configuration
static let requestTimeout: TimeInterval = 600.0 // 10 minutes for long analysis
static let streamTimeout: TimeInterval = 600.0 // 10 minutes for streaming
static let maxRetries = 3
// UI Configuration
// MARK: - API Endpoints
static let healthEndpoint = "/health"
static let analyzeEndpoint = "/analyze"
static let streamEndpoint = "/analyze/stream"
// MARK: - UI Configuration
static let defaultTicker = "AAPL"
static let animationDuration = 0.3
// MARK: - Debug Configuration
static let enableVerboseLogging = true
static let logNetworkRequests = true
// MARK: - Helper Methods
static func fullURL(for endpoint: String) -> String {
return apiBaseURL + endpoint
}
static func streamURL(for ticker: String) -> String {
return "\(apiBaseURL)\(streamEndpoint)?ticker=\(ticker)"
}
// For debugging - prints current configuration
static func printConfiguration() {
print("🔧 TradingAgents Configuration:")
print("📍 API Base URL: \(apiBaseURL)")
print("⏱️ Request Timeout: \(requestTimeout)s")
print("📡 Stream Timeout: \(streamTimeout)s")
#if targetEnvironment(simulator)
print("📱 Environment: iOS Simulator")
#else
print("📱 Environment: Real Device")
#endif
}
}

View File

@ -32,4 +32,88 @@ struct AnalysisResponse: Codable {
case processedSignal = "processed_signal"
case error
}
}
// MARK: - Live Activity Models
public struct AgentMessage: Identifiable, Equatable {
public let id = UUID()
public let timestamp: Date
public let content: String
public let type: MessageType
public enum MessageType {
case reasoning // Intermediate thinking/reasoning
case toolCall // Tool execution
case status // Status updates
case finalReport // Complete analysis report
}
public init(content: String, type: MessageType) {
self.content = content
self.type = type
self.timestamp = Date()
}
}
public struct AgentActivity: Identifiable, Equatable {
public let id = UUID()
public let name: String
public let displayName: String
public var status: AgentStatus
public var messages: [AgentMessage]
public var finalReport: String?
public let startTime: Date
public var completionTime: Date?
public enum AgentStatus: Equatable {
case pending
case inProgress
case completed
case error(String)
public static func == (lhs: AgentStatus, rhs: AgentStatus) -> Bool {
switch (lhs, rhs) {
case (.pending, .pending),
(.inProgress, .inProgress),
(.completed, .completed):
return true
case (.error(let lhsMessage), .error(let rhsMessage)):
return lhsMessage == rhsMessage
default:
return false
}
}
}
public init(name: String, displayName: String) {
self.name = name
self.displayName = displayName
self.status = .pending
self.messages = []
self.finalReport = nil
self.startTime = Date()
self.completionTime = nil
}
public mutating func addMessage(_ message: AgentMessage) {
messages.append(message)
}
public mutating func setFinalReport(_ report: String) {
finalReport = report
// Replace reasoning messages with final report
messages.removeAll { $0.type == .reasoning }
messages.append(AgentMessage(content: report, type: .finalReport))
}
public mutating func updateStatus(_ newStatus: AgentStatus) {
status = newStatus
if case .completed = newStatus {
completionTime = Date()
}
}
public static func == (lhs: AgentActivity, rhs: AgentActivity) -> Bool {
lhs.id == rhs.id
}
}

View File

@ -1,29 +1,31 @@
import Foundation
import Combine
import os.log
import OSLog
// MARK: - SSE Event Models
public struct SSEEvent: Codable {
// MARK: - SSE Event Model
struct SSEEvent: Codable {
let type: String
let message: String?
let agent: String?
let status: String?
let section: String?
let content: String?
let status: String?
}
// MARK: - Progress Models
// MARK: - Enhanced Progress Models
public struct AnalysisProgress {
public let currentAgent: String
public let message: String
public let reports: [String: String]
public let agentActivities: [String: AgentActivity]
public let isComplete: Bool
public let error: String?
public init(currentAgent: String, message: String, reports: [String: String], isComplete: Bool, error: String?) {
public init(currentAgent: String, message: String, reports: [String: String], agentActivities: [String: AgentActivity], isComplete: Bool, error: String?) {
self.currentAgent = currentAgent
self.message = message
self.reports = reports
self.agentActivities = agentActivities
self.isComplete = isComplete
self.error = error
}
@ -33,27 +35,6 @@ public struct AnalysisProgress {
public class TradingAgentsService: ObservableObject {
internal let logger = Logger(subsystem: "com.tradingagents.app", category: "TradingAgentsService")
private let baseURL: String = {
// Check for environment variable first
if let envURL = ProcessInfo.processInfo.environment["TRADINGAGENTS_API_URL"] {
return envURL
}
// Default URLs for different environments
#if DEBUG
#if targetEnvironment(simulator)
// For iOS Simulator
return "http://localhost:8000"
#else
// For real device - UPDATE THIS WITH YOUR MAC'S IP
return "http://192.168.4.223:8000"
#endif
#else
// For production, update this to your deployed server URL
return "https://api.tradingagents.com"
#endif
}()
private var eventSource: URLSessionDataTask?
private var streamingDelegate: SSEStreamDelegate?
private let session: URLSession
@ -63,13 +44,18 @@ public class TradingAgentsService: ObservableObject {
currentAgent: "",
message: "",
reports: [:],
agentActivities: [:],
isComplete: false,
error: nil
)
public init() {
self.session = URLSession.shared
logger.info("🚀 TradingAgentsService initialized with baseURL: \(self.baseURL)")
logger.info("🚀 TradingAgentsService initialized with baseURL: \(AppConfig.apiBaseURL)")
if AppConfig.enableVerboseLogging {
AppConfig.printConfiguration()
}
}
public func streamAnalysis(for ticker: String) -> AnyPublisher<AnalysisProgress, Never> {
@ -77,7 +63,7 @@ public class TradingAgentsService: ObservableObject {
logger.info("📡 Starting stream analysis for ticker: \(ticker)")
let urlString = "\(baseURL)/analyze/stream?ticker=\(ticker)"
let urlString = AppConfig.streamURL(for: ticker)
logger.info("🌐 Request URL: \(urlString)")
guard let url = URL(string: urlString) else {
@ -86,6 +72,7 @@ public class TradingAgentsService: ObservableObject {
currentAgent: "",
message: "",
reports: [:],
agentActivities: [:],
isComplete: true,
error: "Invalid URL: \(urlString)"
))
@ -96,7 +83,7 @@ public class TradingAgentsService: ObservableObject {
request.setValue("text/event-stream", forHTTPHeaderField: "Accept")
request.setValue("no-cache", forHTTPHeaderField: "Cache-Control")
request.setValue("keep-alive", forHTTPHeaderField: "Connection")
request.timeoutInterval = 600.0 // 10 minutes
request.timeoutInterval = AppConfig.streamTimeout
logger.info("📋 Request headers: \(request.allHTTPHeaderFields ?? [:])")
@ -114,8 +101,8 @@ public class TradingAgentsService: ObservableObject {
// Create session with delegate for streaming
let config = URLSessionConfiguration.default
config.timeoutIntervalForRequest = 600.0
config.timeoutIntervalForResource = 600.0
config.timeoutIntervalForRequest = AppConfig.requestTimeout
config.timeoutIntervalForResource = AppConfig.streamTimeout
let delegateSession = URLSession(configuration: config, delegate: delegate, delegateQueue: nil)
let task = delegateSession.dataTask(with: request)
@ -125,14 +112,22 @@ public class TradingAgentsService: ObservableObject {
self.streamingDelegate = delegate
self.eventSource = task
// Send initial status update
DispatchQueue.main.async {
subject.send(AnalysisProgress(
currentAgent: "Starting",
message: "🚀 Connecting to analysis service...",
reports: [:],
agentActivities: delegate.getAgentActivities(),
isComplete: false,
error: nil
))
}
task.resume()
logger.info("🚀 SSE Stream task started with delegate")
}
internal func formatAgentName(_ agent: String) -> String {
switch agent.lowercased() {
case "market": return "Market Analyst"
@ -141,7 +136,12 @@ public class TradingAgentsService: ObservableObject {
case "fundamentals": return "Fundamentals Analyst"
case "bull_researcher": return "Bull Researcher"
case "bear_researcher": return "Bear Researcher"
case "research_manager": return "Research Manager"
case "trader": return "Trading Team"
case "risky_analyst": return "Risky Analyst"
case "safe_analyst": return "Safe Analyst"
case "neutral_analyst": return "Neutral Analyst"
case "risk_manager": return "Risk Manager"
default: return agent.capitalized
}
}
@ -154,6 +154,7 @@ public class TradingAgentsService: ObservableObject {
case "fundamentals_report": return "Fundamentals Analysis"
case "investment_plan": return "Investment Plan"
case "trader_investment_plan": return "Trading Plan"
case "risk_analysis": return "Risk Analysis"
case "final_trade_decision": return "Final Decision"
default: return section.replacingOccurrences(of: "_", with: " ").capitalized
}
@ -173,12 +174,38 @@ private class SSEStreamDelegate: NSObject, URLSessionDataDelegate {
private weak var service: TradingAgentsService?
private var buffer = ""
private var currentReports: [String: String] = [:]
private var agentActivities: [String: AgentActivity] = [:]
private var currentAgentName = ""
var task: URLSessionDataTask?
init(subject: PassthroughSubject<AnalysisProgress, Never>, service: TradingAgentsService) {
self.subject = subject
self.service = service
super.init()
// Initialize all known agents
let knownAgents = [
("market", "Market Analyst"),
("social", "Social Media Analyst"),
("news", "News Analyst"),
("fundamentals", "Fundamentals Analyst"),
("bull_researcher", "Bull Researcher"),
("bear_researcher", "Bear Researcher"),
("research_manager", "Research Manager"),
("trader", "Trading Team"),
("risky_analyst", "Risky Analyst"),
("safe_analyst", "Safe Analyst"),
("neutral_analyst", "Neutral Analyst"),
("risk_manager", "Risk Manager")
]
for (name, displayName) in knownAgents {
agentActivities[name] = AgentActivity(name: name, displayName: displayName)
}
}
func getAgentActivities() -> [String: AgentActivity] {
return agentActivities
}
func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive response: URLResponse, completionHandler: @escaping (URLSession.ResponseDisposition) -> Void) {
@ -198,6 +225,7 @@ private class SSEStreamDelegate: NSObject, URLSessionDataDelegate {
currentAgent: "",
message: "",
reports: self.currentReports,
agentActivities: self.agentActivities,
isComplete: true,
error: "HTTP \(httpResponse.statusCode)"
))
@ -210,7 +238,12 @@ private class SSEStreamDelegate: NSObject, URLSessionDataDelegate {
guard let service = service else { return }
let newData = String(data: data, encoding: .utf8) ?? ""
service.logger.info("📦 Received \(data.count) bytes: \(String(newData.prefix(100)))...")
service.logger.info("📦 Received \(data.count) bytes")
// Log raw SSE data for debugging
if !newData.isEmpty {
service.logger.info("📡 Raw SSE Data: \(newData)")
}
// Add new data to buffer
buffer += newData
@ -219,28 +252,45 @@ private class SSEStreamDelegate: NSObject, URLSessionDataDelegate {
let lines = buffer.components(separatedBy: .newlines)
buffer = lines.last ?? "" // Keep incomplete line in buffer
service.logger.info("📋 Processing \(lines.count - 1) complete lines")
// Process all complete lines except the last (incomplete) one
for line in lines.dropLast() {
for (index, line) in lines.dropLast().enumerated() {
service.logger.info("📋 Line[\(index)]: \(line)")
if line.hasPrefix("data: ") {
let jsonString = String(line.dropFirst(6))
service.logger.info("🔍 Processing JSON: \(String(jsonString.prefix(50)))...")
service.logger.info("🔍 Extracting JSON: \(jsonString)")
if let jsonData = jsonString.data(using: .utf8),
let event = try? JSONDecoder().decode(SSEEvent.self, from: jsonData) {
service.logger.info("✅ Decoded event - Type: \(event.type)")
service.logger.info("✅ SSE Event Successfully Decoded:")
service.logger.info(" 📌 Type: \(event.type)")
service.logger.info(" 📌 Agent: \(event.agent ?? "nil")")
service.logger.info(" 📌 Status: \(event.status ?? "nil")")
service.logger.info(" 📌 Section: \(event.section ?? "nil")")
service.logger.info(" 📌 Message: \(event.message ?? "nil")")
service.logger.info(" 📌 Content length: \(event.content?.count ?? 0)")
if let content = event.content {
service.logger.info(" 📌 Content preview: \(String(content.prefix(200)))...")
}
DispatchQueue.main.async {
self.processSSEEvent(event, service: service)
}
} else {
service.logger.warning("⚠️ Failed to decode JSON: \(jsonString)")
service.logger.error("❌ Failed to decode SSE JSON: \(jsonString)")
}
} else if line.hasPrefix(":") {
service.logger.info("💬 SSE Comment: \(line)")
} else if !line.isEmpty {
service.logger.info("📝 SSE Other: \(line)")
}
}
}
func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didCompleteWithError error: Error?) {
func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
if let error = error {
service?.logger.error("❌ SSE Connection error: \(error.localizedDescription)")
DispatchQueue.main.async {
@ -248,6 +298,7 @@ private class SSEStreamDelegate: NSObject, URLSessionDataDelegate {
currentAgent: "",
message: "",
reports: self.currentReports,
agentActivities: self.agentActivities,
isComplete: true,
error: "Connection error: \(error.localizedDescription)"
))
@ -265,32 +316,143 @@ private class SSEStreamDelegate: NSObject, URLSessionDataDelegate {
currentAgent: "Starting",
message: event.message ?? "Starting analysis...",
reports: currentReports,
agentActivities: agentActivities,
isComplete: false,
error: nil
))
case "agent_status":
let agentName = service.formatAgentName(event.agent ?? "")
let statusMessage = event.status == "completed" ?
"\(agentName) completed" :
"🔄 Analyzing with \(agentName)..."
if let agentKey = event.agent {
let agentName = service.formatAgentName(agentKey)
currentAgentName = agentName
service.logger.info("👤 Processing Agent Status:")
service.logger.info(" 🔍 Agent Key: \(agentKey)")
service.logger.info(" 🔍 Agent Name: \(agentName)")
service.logger.info(" 🔍 Status: \(event.status ?? "nil")")
service.logger.info(" 🔍 Current Activities Count: \(self.agentActivities.count)")
// Create activity if it doesn't exist
if agentActivities[agentKey] == nil {
agentActivities[agentKey] = AgentActivity(name: agentKey, displayName: agentName)
service.logger.info(" ✅ Created new activity for: \(agentKey)")
} else {
service.logger.info(" ♻️ Using existing activity for: \(agentKey)")
}
// Update agent status
if var activity = agentActivities[agentKey] {
let statusMessage: String
let newStatus: AgentActivity.AgentStatus
switch event.status {
case "in_progress":
newStatus = .inProgress
statusMessage = "🔄 Analyzing with \(agentName)..."
activity.addMessage(AgentMessage(content: "Started analysis", type: .status))
service.logger.info(" 🚀 Setting \(agentKey) to IN_PROGRESS")
case "completed":
newStatus = .completed
statusMessage = "\(agentName) completed"
activity.addMessage(AgentMessage(content: "Analysis completed", type: .status))
service.logger.info(" ✅ Setting \(agentKey) to COMPLETED")
default:
newStatus = .inProgress
statusMessage = "🔄 Analyzing with \(agentName)..."
activity.addMessage(AgentMessage(content: event.status ?? "In progress", type: .status))
service.logger.info(" ⚠️ Unknown status '\(event.status ?? "nil")', defaulting to IN_PROGRESS")
}
activity.updateStatus(newStatus)
agentActivities[agentKey] = activity
// Log final agent activities state
service.logger.info(" 📊 Final Agent Activities State:")
for (key, activity) in agentActivities {
service.logger.info("\(key): \(String(describing: activity.status)) (\(activity.displayName))")
}
let activeCount = agentActivities.values.filter { $0.status == .inProgress }.count
let completedCount = agentActivities.values.filter { $0.status == .completed }.count
let pendingCount = agentActivities.values.filter { $0.status == .pending }.count
service.logger.info(" 📈 Status Summary: \(activeCount) active, \(completedCount) completed, \(pendingCount) pending")
subject.send(AnalysisProgress(
currentAgent: agentName,
message: statusMessage,
reports: currentReports,
agentActivities: agentActivities,
isComplete: false,
error: nil
))
} else {
service.logger.error(" ❌ Failed to update activity for agent: \(agentKey)")
}
} else {
service.logger.error(" ❌ Agent status event missing agent key!")
}
service.logger.info("👤 Agent: \(agentName), Status: \(event.status ?? "")")
subject.send(AnalysisProgress(
currentAgent: agentName,
message: statusMessage,
reports: currentReports,
isComplete: false,
error: nil
))
case "reasoning":
// Capture intermediate reasoning messages with agent assignment
if let content = event.content, !content.isEmpty {
service.logger.info("🧠 Processing Reasoning Event:")
service.logger.info(" 🔍 Content length: \(content.count)")
service.logger.info(" 🔍 Content preview: \(String(content.prefix(200)))...")
service.logger.info(" 🔍 Event agent: \(event.agent ?? "nil")")
// Use agent from event if available, otherwise find active agent
let agentKey: String
if let eventAgent = event.agent, !eventAgent.isEmpty {
agentKey = eventAgent
service.logger.info(" ✅ Using agent from event: \(agentKey)")
} else {
agentKey = findCurrentActiveAgent() ?? "market"
service.logger.info(" ⚠️ Using fallback agent: \(agentKey)")
}
// Create activity if it doesn't exist
if agentActivities[agentKey] == nil {
let displayName = service.formatAgentName(agentKey)
agentActivities[agentKey] = AgentActivity(name: agentKey, displayName: displayName)
service.logger.info(" ✅ Created new activity for agent: \(agentKey)")
} else {
service.logger.info(" ♻️ Using existing activity for agent: \(agentKey)")
}
if var activity = agentActivities[agentKey] {
let message = AgentMessage(content: content, type: .reasoning)
let previousMessageCount = activity.messages.count
activity.addMessage(message)
agentActivities[agentKey] = activity
service.logger.info(" 📝 Added reasoning message to \(agentKey)")
service.logger.info(" 📊 Messages for \(agentKey): \(previousMessageCount)\(activity.messages.count)")
service.logger.info(" 🎯 Sending progress update for \(activity.displayName)")
subject.send(AnalysisProgress(
currentAgent: currentAgentName.isEmpty ? activity.displayName : currentAgentName,
message: "💭 \(activity.displayName) thinking...",
reports: currentReports,
agentActivities: agentActivities,
isComplete: false,
error: nil
))
} else {
service.logger.error(" ❌ Failed to find or create activity for agent: \(agentKey)")
}
} else {
service.logger.warning("🧠 Reasoning event with empty or nil content")
}
case "progress":
if let content = event.content, let percentage = Int(content) {
service.logger.info("📊 Progress: \(percentage)%")
subject.send(AnalysisProgress(
currentAgent: service.progress.currentAgent,
currentAgent: currentAgentName,
message: "Progress: \(percentage)%",
reports: currentReports,
agentActivities: agentActivities,
isComplete: false,
error: nil
))
@ -300,10 +462,20 @@ private class SSEStreamDelegate: NSObject, URLSessionDataDelegate {
if let section = event.section, let content = event.content {
service.logger.info("📊 Report: \(section)")
currentReports[section] = content
// Find the agent that produced this report and set as final report
let agentKey = mapSectionToAgent(section)
if var activity = agentActivities[agentKey] {
activity.setFinalReport(content)
activity.updateStatus(.completed)
agentActivities[agentKey] = activity
}
subject.send(AnalysisProgress(
currentAgent: service.progress.currentAgent,
currentAgent: currentAgentName,
message: "📊 Updated \(service.formatSectionName(section))",
reports: currentReports,
agentActivities: agentActivities,
isComplete: false,
error: nil
))
@ -315,6 +487,7 @@ private class SSEStreamDelegate: NSObject, URLSessionDataDelegate {
currentAgent: "Complete",
message: "✅ Analysis completed successfully",
reports: currentReports,
agentActivities: agentActivities,
isComplete: true,
error: nil
))
@ -325,6 +498,7 @@ private class SSEStreamDelegate: NSObject, URLSessionDataDelegate {
currentAgent: "",
message: "",
reports: currentReports,
agentActivities: agentActivities,
isComplete: true,
error: event.message ?? "Unknown error"
))
@ -333,6 +507,28 @@ private class SSEStreamDelegate: NSObject, URLSessionDataDelegate {
service.logger.info(" Unknown event type: \(event.type)")
}
}
private func findCurrentActiveAgent() -> String? {
// Find the most recently active agent
return agentActivities.values
.filter { $0.status == .inProgress }
.max(by: { $0.startTime < $1.startTime })?
.name
}
private func mapSectionToAgent(_ section: String) -> String {
switch section {
case "market_report": return "market"
case "sentiment_report": return "social"
case "news_report": return "news"
case "fundamentals_report": return "fundamentals"
case "investment_plan": return "research_manager"
case "trader_investment_plan": return "trader"
case "risk_analysis": return "risk_manager"
case "final_trade_decision": return "risk_manager"
default: return "market"
}
}
}
// MARK: - API Errors

View File

@ -1,6 +1,6 @@
import Foundation
import Combine
import SwiftData
import SwiftUI
// MARK: - View Model
@MainActor
@ -8,13 +8,12 @@ class TradingAnalysisViewModel: ObservableObject {
// MARK: - Published Properties
@Published var ticker: String = ""
@Published var isAnalyzing: Bool = false
@Published var showingResults: Bool = false
@Published var errorMessage: String?
// MARK: - Streaming Properties
// MARK: - Live Activity Properties
@Published var currentAgent: String = ""
@Published var statusMessage: String = ""
@Published var analysisProgress: Double = 0.0
@Published var agentActivities: [AgentActivity] = []
@Published var reports: [String: String] = [:]
@Published var finalDecision: String = ""
@ -22,59 +21,7 @@ class TradingAnalysisViewModel: ObservableObject {
private let tradingService = TradingAgentsService()
private var cancellables = Set<AnyCancellable>()
// MARK: - SwiftData
var modelContext: ModelContext?
// MARK: - Constants
private let agentSteps = [
"Starting", "Market Analyst", "Social Media Analyst",
"News Analyst", "Fundamentals Analyst", "Bull Researcher",
"Bear Researcher", "Trading Team", "Complete"
]
init() {
setupSubscriptions()
}
private func setupSubscriptions() {
// Subscribe to service progress updates
tradingService.$progress
.receive(on: DispatchQueue.main)
.sink { [weak self] progress in
self?.updateProgress(progress)
}
.store(in: &cancellables)
}
private func updateProgress(_ progress: AnalysisProgress) {
currentAgent = progress.currentAgent
statusMessage = progress.message
reports = progress.reports
// Update progress percentage based on current agent
if let stepIndex = agentSteps.firstIndex(of: progress.currentAgent) {
analysisProgress = Double(stepIndex) / Double(agentSteps.count - 1)
}
// Handle completion
if progress.isComplete {
isAnalyzing = false
if progress.error == nil {
showingResults = true
finalDecision = reports["final_trade_decision"] ?? ""
// Save to history
saveToHistory()
} else {
errorMessage = progress.error
}
}
// Handle errors
if let error = progress.error {
errorMessage = error
isAnalyzing = false
}
}
func startAnalysis() {
guard !ticker.isEmpty else {
@ -83,20 +30,19 @@ class TradingAnalysisViewModel: ObservableObject {
}
// Reset state
isAnalyzing = true
showingResults = false
errorMessage = nil
isAnalyzing = true
agentActivities = []
reports = [:]
currentAgent = ""
statusMessage = ""
analysisProgress = 0.0
reports = [:]
finalDecision = ""
// Start streaming analysis
tradingService.streamAnalysis(for: ticker.uppercased())
.receive(on: DispatchQueue.main)
.sink { [weak self] progress in
self?.updateProgress(progress)
.sink { [weak self] analysisProgress in
self?.updateProgress(analysisProgress)
}
.store(in: &cancellables)
}
@ -108,19 +54,69 @@ class TradingAnalysisViewModel: ObservableObject {
statusMessage = "Analysis stopped"
}
private func updateProgress(_ progress: AnalysisProgress) {
currentAgent = progress.currentAgent
statusMessage = progress.message
reports = progress.reports
// Update agent activities with live messages
agentActivities = Array(progress.agentActivities.values)
.sorted { $0.startTime < $1.startTime }
// Update final decision
finalDecision = reports["final_trade_decision"] ?? ""
// Handle completion
if progress.isComplete {
isAnalyzing = false
if let error = progress.error {
errorMessage = error
}
}
// Handle errors
if let error = progress.error {
errorMessage = error
isAnalyzing = false
}
}
func resetAnalysis() {
stopAnalysis()
showingResults = false
errorMessage = nil
ticker = ""
currentAgent = ""
statusMessage = ""
analysisProgress = 0.0
agentActivities = []
reports = [:]
finalDecision = ""
}
// MARK: - Computed Properties
// MARK: - Helper Methods for UI
func getActiveAgents() -> [AgentActivity] {
return agentActivities.filter { $0.status == .inProgress }
}
func getCompletedAgents() -> [AgentActivity] {
return agentActivities.filter { $0.status == .completed }
}
func getPendingAgents() -> [AgentActivity] {
return agentActivities.filter { $0.status == .pending }
}
func getLatestMessagesForAgent(_ agentName: String) -> [AgentMessage] {
return agentActivities.first { $0.name == agentName }?.messages ?? []
}
func hasActivityToShow() -> Bool {
return !agentActivities.isEmpty && agentActivities.contains { !$0.messages.isEmpty }
}
var hasReports: Bool {
!reports.isEmpty
}
var formattedReports: [(title: String, content: String)] {
let reportOrder = [
("market_report", "Market Analysis"),
@ -128,7 +124,9 @@ class TradingAnalysisViewModel: ObservableObject {
("news_report", "News Analysis"),
("fundamentals_report", "Fundamentals Analysis"),
("investment_plan", "Investment Plan"),
("trader_investment_plan", "Trading Plan")
("trader_investment_plan", "Trading Plan"),
("risk_analysis", "Risk Analysis"),
("final_trade_decision", "Final Decision")
]
return reportOrder.compactMap { key, title in
@ -136,75 +134,4 @@ class TradingAnalysisViewModel: ObservableObject {
return (title: title, content: content)
}
}
var hasReports: Bool {
!reports.isEmpty
}
var progressPercentage: Int {
Int(analysisProgress * 100)
}
// MARK: - History Management
private func saveToHistory() {
guard let modelContext = modelContext else { return }
// Extract signal from final decision
let signal = extractSignal(from: finalDecision)
// Create full report by combining all sections
let fullReport = formattedReports.map { section in
"=== \(section.title) ===\n\n\(section.content)\n"
}.joined(separator: "\n\n")
// Create history entry
let history = AnalysisHistory(
ticker: ticker.uppercased(),
signal: signal,
finalDecision: finalDecision,
fullReport: fullReport,
marketReport: reports["market_report"],
sentimentReport: reports["sentiment_report"],
newsReport: reports["news_report"],
fundamentalsReport: reports["fundamentals_report"],
investmentPlan: reports["investment_plan"],
traderPlan: reports["trader_investment_plan"],
riskAnalysis: reports["risk_analysis"]
)
// Save to SwiftData
modelContext.insert(history)
do {
try modelContext.save()
print("✅ Analysis saved to history")
} catch {
print("❌ Failed to save analysis to history: \(error)")
}
}
private func extractSignal(from decision: String) -> String {
let uppercased = decision.uppercased()
// Check for explicit signals
if uppercased.contains("**BUY**") || uppercased.contains("BUY SIGNAL") {
return "BUY"
} else if uppercased.contains("**SELL**") || uppercased.contains("SELL SIGNAL") {
return "SELL"
} else if uppercased.contains("**HOLD**") || uppercased.contains("HOLD SIGNAL") {
return "HOLD"
}
// Check for context clues
if uppercased.contains("RECOMMEND BUYING") || uppercased.contains("STRONG BUY") {
return "BUY"
} else if uppercased.contains("RECOMMEND SELLING") || uppercased.contains("STRONG SELL") {
return "SELL"
} else if uppercased.contains("MAINTAIN POSITION") || uppercased.contains("HOLD POSITION") {
return "HOLD"
}
// Default to HOLD if unclear
return "HOLD"
}
}

View File

@ -2,52 +2,62 @@ import SwiftUI
// MARK: - Analysis Result View
struct AnalysisResultView: View {
let ticker: String
let reports: [(title: String, content: String)]
let finalDecision: String
let onDismiss: () -> Void
var body: some View {
NavigationView {
ScrollView {
VStack(alignment: .leading, spacing: 20) {
// Header
HeaderView(ticker: ticker)
.padding(.horizontal)
VStack(alignment: .leading, spacing: 20) {
// Header
Text("Final Reports")
.font(.title2)
.fontWeight(.bold)
.padding(.horizontal)
if reports.isEmpty && finalDecision.isEmpty {
// Empty state
VStack(spacing: 16) {
Image(systemName: "doc.text")
.font(.system(size: 50))
.foregroundColor(.secondary)
// Reports
VStack(spacing: 16) {
ForEach(reports, id: \.title) { report in
ReportCard(
title: report.title,
icon: iconForReport(report.title),
content: report.content
)
}
// Final Decision
if !finalDecision.isEmpty {
ReportCard(
title: "Final Decision",
icon: "checkmark.seal.fill",
content: finalDecision,
isHighlighted: true
)
}
}
.padding(.horizontal)
Text("No reports available yet")
.font(.headline)
.foregroundColor(.secondary)
Text("Reports will appear here as agents complete their analysis")
.font(.subheadline)
.foregroundColor(.secondary)
.multilineTextAlignment(.center)
}
.padding(.vertical)
}
.navigationTitle("Analysis Results")
.toolbar {
ToolbarItem(placement: .primaryAction) {
Button("Done") {
onDismiss()
.padding()
.frame(maxWidth: .infinity, maxHeight: .infinity)
} else {
// Reports
VStack(spacing: 16) {
ForEach(reports, id: \.title) { report in
ReportCard(
title: report.title,
icon: iconForReport(report.title),
content: report.content
)
}
// Final Decision
if !finalDecision.isEmpty {
ReportCard(
title: "Final Decision",
icon: "checkmark.seal.fill",
content: finalDecision,
isHighlighted: true
)
}
}
.padding(.horizontal)
}
Spacer()
}
.padding(.vertical)
}
private func iconForReport(_ title: String) -> String {
@ -58,34 +68,13 @@ struct AnalysisResultView: View {
case "Fundamentals Analysis": return "doc.text"
case "Investment Plan": return "lightbulb"
case "Trading Plan": return "chart.bar.fill"
case "Risk Analysis": return "shield.checkered"
case "Final Decision": return "checkmark.seal.fill"
default: return "doc.text"
}
}
}
// MARK: - Header View
struct HeaderView: View {
let ticker: String
var body: some View {
VStack(alignment: .leading, spacing: 8) {
HStack(alignment: .top) {
VStack(alignment: .leading, spacing: 4) {
Text(ticker)
.font(.largeTitle)
.fontWeight(.bold)
Text("Analysis Date: \(DateFormatter.shortDate.string(from: Date()))")
.font(.caption)
.foregroundStyle(.secondary)
}
Spacer()
}
}
}
}
// MARK: - Report Card
struct ReportCard: View {
let title: String

View File

@ -61,4 +61,279 @@ struct ErrorView: View {
}
.padding()
}
}
// MARK: - Agent Activity Views
struct AgentActivityCard: View {
let activity: AgentActivity
@State private var isExpanded = true
var body: some View {
VStack(alignment: .leading, spacing: 8) {
// Agent Header
HStack {
statusIcon
Text(activity.displayName)
.font(.headline)
.foregroundColor(.primary)
Spacer()
Text(statusText)
.font(.caption)
.foregroundColor(.secondary)
Button(action: { isExpanded.toggle() }) {
Image(systemName: isExpanded ? "chevron.up" : "chevron.down")
.foregroundColor(.secondary)
}
}
if isExpanded {
// Messages
if !activity.messages.isEmpty {
VStack(alignment: .leading, spacing: 4) {
ForEach(activity.messages) { message in
AgentMessageRow(message: message)
}
}
} else if activity.status == .pending {
Text("Waiting to start...")
.font(.caption)
.foregroundColor(.secondary)
.italic()
}
}
}
.padding()
.background(backgroundColorForStatus)
.cornerRadius(12)
.overlay(
RoundedRectangle(cornerRadius: 12)
.stroke(borderColorForStatus, lineWidth: 2)
)
}
@ViewBuilder
private var statusIcon: some View {
switch activity.status {
case .pending:
Image(systemName: "clock")
.foregroundColor(.orange)
case .inProgress:
Image(systemName: "gear")
.foregroundColor(.blue)
.rotationEffect(.degrees(Double(activity.messages.count) * 10))
case .completed:
Image(systemName: "checkmark.circle.fill")
.foregroundColor(.green)
case .error(_):
Image(systemName: "exclamationmark.triangle.fill")
.foregroundColor(.red)
}
}
private var statusText: String {
switch activity.status {
case .pending:
return "Pending"
case .inProgress:
return "Working..."
case .completed:
if let completionTime = activity.completionTime {
let duration = completionTime.timeIntervalSince(activity.startTime)
return "Completed (\(String(format: "%.1f", duration))s)"
}
return "Completed"
case .error(let message):
return "Error: \(message)"
}
}
private var backgroundColorForStatus: Color {
switch activity.status {
case .pending:
return Color(.systemGray6)
case .inProgress:
return Color.blue.opacity(0.1)
case .completed:
return Color.green.opacity(0.1)
case .error(_):
return Color.red.opacity(0.1)
}
}
private var borderColorForStatus: Color {
switch activity.status {
case .pending:
return Color.orange.opacity(0.3)
case .inProgress:
return Color.blue.opacity(0.5)
case .completed:
return Color.green.opacity(0.5)
case .error(_):
return Color.red.opacity(0.5)
}
}
}
struct AgentMessageRow: View {
let message: AgentMessage
var body: some View {
HStack(alignment: .top, spacing: 8) {
messageTypeIcon
VStack(alignment: .leading, spacing: 2) {
Text(message.content)
.font(.caption)
.foregroundColor(textColorForMessageType)
.lineLimit(messageTypeIsImportant ? nil : 3)
Text(formatTimestamp(message.timestamp))
.font(.caption2)
.foregroundColor(.secondary)
}
Spacer()
}
.padding(.leading, 8)
}
@ViewBuilder
private var messageTypeIcon: some View {
switch message.type {
case .reasoning:
Image(systemName: "brain")
.foregroundColor(.purple)
.font(.caption)
case .toolCall:
Image(systemName: "wrench")
.foregroundColor(.orange)
.font(.caption)
case .status:
Image(systemName: "info.circle")
.foregroundColor(.blue)
.font(.caption)
case .finalReport:
Image(systemName: "doc.text.fill")
.foregroundColor(.green)
.font(.caption)
}
}
private var textColorForMessageType: Color {
switch message.type {
case .reasoning:
return .purple
case .toolCall:
return .orange
case .status:
return .blue
case .finalReport:
return .primary
}
}
private var messageTypeIsImportant: Bool {
message.type == .finalReport
}
private func formatTimestamp(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.timeStyle = .medium
return formatter.string(from: date)
}
}
// MARK: - Live Activity Dashboard
struct LiveActivityDashboard: View {
let agentActivities: [AgentActivity]
var body: some View {
VStack(alignment: .leading, spacing: 16) {
Text("Live Agent Activity")
.font(.title2)
.fontWeight(.bold)
if agentActivities.isEmpty {
Text("No activity yet...")
.foregroundColor(.secondary)
.italic()
} else {
LazyVStack(spacing: 12) {
ForEach(agentActivities) { activity in
AgentActivityCard(activity: activity)
}
}
}
}
.padding()
}
}
// MARK: - Progress Overview
struct ProgressOverview: View {
let agentActivities: [AgentActivity]
var body: some View {
VStack(alignment: .leading, spacing: 12) {
Text("Progress Overview")
.font(.headline)
HStack(spacing: 16) {
ProgressIndicator(
title: "Completed",
count: completedCount,
color: .green
)
ProgressIndicator(
title: "Active",
count: activeCount,
color: .blue
)
ProgressIndicator(
title: "Pending",
count: pendingCount,
color: .orange
)
}
}
.padding()
.background(Color(.systemGray6))
.cornerRadius(12)
}
private var completedCount: Int {
agentActivities.filter { $0.status == .completed }.count
}
private var activeCount: Int {
agentActivities.filter { $0.status == .inProgress }.count
}
private var pendingCount: Int {
agentActivities.filter { $0.status == .pending }.count
}
}
struct ProgressIndicator: View {
let title: String
let count: Int
let color: Color
var body: some View {
VStack {
Text("\(count)")
.font(.title2)
.fontWeight(.bold)
.foregroundColor(color)
Text(title)
.font(.caption)
.foregroundColor(.secondary)
}
}
}

View File

@ -1,161 +1,123 @@
import SwiftUI
import SwiftData
struct TradingAnalysisView: View {
@StateObject private var viewModel = TradingAnalysisViewModel()
@Environment(\.modelContext) private var modelContext
@State private var selectedTab = 0
var body: some View {
NavigationView {
VStack(spacing: 20) {
// Header
headerSection
// Input Section
inputSection
// Progress Section (shown during analysis)
if viewModel.isAnalyzing {
progressSection
}
// Reports Section (shown as reports come in)
if viewModel.hasReports && !viewModel.showingResults {
reportsSection
}
Spacer()
}
.padding()
.navigationTitle("Trading Analysis")
.onAppear {
viewModel.modelContext = modelContext
}
.alert("Error", isPresented: .constant(viewModel.errorMessage != nil)) {
Button("OK") {
viewModel.errorMessage = nil
}
} message: {
Text(viewModel.errorMessage ?? "")
}
.sheet(isPresented: $viewModel.showingResults) {
AnalysisResultView(
ticker: viewModel.ticker,
reports: viewModel.formattedReports,
finalDecision: viewModel.finalDecision,
onDismiss: {
viewModel.resetAnalysis()
VStack(spacing: 0) {
// Header Section
VStack(spacing: 16) {
Text("Trading Agents Analysis")
.font(.largeTitle)
.fontWeight(.bold)
// Ticker Input
HStack {
TextField("Enter ticker (e.g., AAPL)", text: $viewModel.ticker)
.textFieldStyle(RoundedBorderTextFieldStyle())
.autocapitalization(.allCharacters)
.disabled(viewModel.isAnalyzing)
Button(action: {
if viewModel.isAnalyzing {
viewModel.stopAnalysis()
} else {
viewModel.startAnalysis()
}
}) {
Text(viewModel.isAnalyzing ? "Stop" : "Analyze")
.fontWeight(.semibold)
.foregroundColor(.white)
.padding(.horizontal, 20)
.padding(.vertical, 10)
.background(viewModel.isAnalyzing ? Color.red : Color.blue)
.cornerRadius(8)
}
.disabled(viewModel.ticker.isEmpty && !viewModel.isAnalyzing)
}
)
}
}
}
// MARK: - View Components
private var headerSection: some View {
VStack(spacing: 8) {
Text("TradingAgents")
.font(.largeTitle)
.fontWeight(.bold)
.foregroundColor(.primary)
Text("AI-Powered Stock Analysis")
.font(.subheadline)
.foregroundColor(.secondary)
}
}
private var inputSection: some View {
VStack(spacing: 16) {
HStack {
TextField("Enter ticker symbol (e.g., AAPL)", text: $viewModel.ticker)
.textFieldStyle(RoundedBorderTextFieldStyle())
.autocapitalization(.allCharacters)
.autocorrectionDisabled(true)
.disabled(viewModel.isAnalyzing)
if viewModel.isAnalyzing {
Button("Stop") {
viewModel.stopAnalysis()
// Progress Overview
if viewModel.hasActivityToShow() {
ProgressOverview(agentActivities: viewModel.agentActivities)
}
.foregroundColor(.red)
} else {
Button("Analyze") {
viewModel.startAnalysis()
}
.disabled(viewModel.ticker.isEmpty)
}
}
if !viewModel.ticker.isEmpty && !viewModel.isAnalyzing {
Text("Tap 'Analyze' to start real-time analysis")
.font(.footnote)
.foregroundColor(.secondary)
}
}
}
private var progressSection: some View {
VStack(spacing: 16) {
// Progress Bar
VStack(spacing: 8) {
HStack {
Text("Analysis Progress")
.font(.headline)
Spacer()
Text("\(viewModel.progressPercentage)%")
.font(.caption)
.foregroundColor(.secondary)
}
ProgressView(value: viewModel.analysisProgress)
.progressViewStyle(LinearProgressViewStyle(tint: .blue))
}
// Current Agent Status
if !viewModel.currentAgent.isEmpty {
HStack {
Image(systemName: "brain.head.profile")
.foregroundColor(.blue)
VStack(alignment: .leading) {
Text(viewModel.currentAgent)
.font(.subheadline)
.fontWeight(.medium)
if !viewModel.statusMessage.isEmpty {
// Current Status
if viewModel.isAnalyzing && !viewModel.statusMessage.isEmpty {
HStack {
ProgressView()
.scaleEffect(0.8)
Text(viewModel.statusMessage)
.font(.caption)
.font(.subheadline)
.foregroundColor(.secondary)
}
.padding(.vertical, 8)
}
Spacer()
}
.padding()
.background(Color.gray.opacity(0.1))
.cornerRadius(8)
}
}
}
private var reportsSection: some View {
VStack(spacing: 12) {
HStack {
Text("Live Reports")
.font(.headline)
Spacer()
Text("\(viewModel.formattedReports.count) sections")
.font(.caption)
.foregroundColor(.secondary)
}
ScrollView {
LazyVStack(spacing: 8) {
ForEach(viewModel.formattedReports, id: \.title) { report in
ReportCardView(title: report.title, content: report.content)
.background(Color(.systemGray6))
// Content Tabs
if viewModel.hasActivityToShow() || viewModel.hasReports {
VStack(spacing: 0) {
// Tab Selector
Picker("View", selection: $selectedTab) {
Text("Live Activity").tag(0)
Text("Final Reports").tag(1)
}
.pickerStyle(SegmentedPickerStyle())
.padding()
// Tab Content
TabView(selection: $selectedTab) {
// Live Activity Tab
ScrollView {
LiveActivityDashboard(agentActivities: viewModel.agentActivities)
}
.tag(0)
// Final Reports Tab
ScrollView {
AnalysisResultView(
reports: viewModel.formattedReports,
finalDecision: viewModel.finalDecision
)
}
.tag(1)
}
.tabViewStyle(PageTabViewStyle(indexDisplayMode: .never))
}
} else if !viewModel.isAnalyzing {
// Placeholder when no activity
Spacer()
VStack(spacing: 16) {
Image(systemName: "chart.line.uptrend.xyaxis")
.font(.system(size: 60))
.foregroundColor(.secondary)
Text("Enter a ticker symbol and tap Analyze to start")
.font(.headline)
.foregroundColor(.secondary)
.multilineTextAlignment(.center)
}
.padding()
Spacer()
}
Spacer()
}
.frame(maxHeight: 300)
.navigationBarHidden(true)
}
.alert("Analysis Error", isPresented: .constant(viewModel.errorMessage != nil)) {
Button("OK") {
viewModel.errorMessage = nil
}
Button("Try Again") {
viewModel.startAnalysis()
}
} message: {
Text(viewModel.errorMessage ?? "")
}
}
}

142
backend/SERPAPI_SETUP.md Normal file
View File

@ -0,0 +1,142 @@
# SerpAPI Setup Guide
## Overview
The Trading Agents system now supports SerpAPI for significantly faster news retrieval. This replaces the slow web scraping approach with a fast, reliable API service.
## Performance Comparison
- **Web Scraping**: ~130 seconds for 300 articles
- **SerpAPI**: ~2-5 seconds for 300 articles
- **Speed Improvement**: 25-60x faster
## Setup Instructions
### 1. Get SerpAPI Key
1. Visit [SerpAPI](https://serpapi.com/)
2. Sign up for a free account
3. Get your API key from the dashboard
4. Free tier includes 100 searches per month
### 2. Set Environment Variable
Add your SerpAPI key to your environment:
```bash
# Option 1: Export in terminal
export SERPAPI_API_KEY="your_serpapi_key_here"
# Option 2: Add to .env file
echo "SERPAPI_API_KEY=your_serpapi_key_here" >> .env
# Option 3: Add to shell profile (permanent)
echo 'export SERPAPI_API_KEY="your_serpapi_key_here"' >> ~/.bashrc
source ~/.bashrc
```
### 3. Verify Setup
Run the test to verify SerpAPI is working:
```bash
python test_serpapi_news.py
```
You should see output like:
```
✅ SerpAPI Key Found: 1234567890...
🚀 Testing SerpAPI for query: AAPL
✅ SerpAPI Results: 100 articles in 2.34s
⚡ Speed Improvement: 56.7x faster with SerpAPI
```
## How It Works
The system automatically detects if a SerpAPI key is available:
1. **With SerpAPI key**: Uses fast SerpAPI service
2. **Without SerpAPI key**: Falls back to web scraping (slow but free)
## Integration
The SerpAPI integration is automatically used in:
- `get_google_news()` function
- News Analyst agent
- All news-related tools
No code changes needed - just set the environment variable!
## API Usage
The system uses the Google News search engine through SerpAPI:
```python
# Automatic usage in get_google_news()
news_data = get_google_news("AAPL", "2025-07-05", 7)
# Direct usage (if needed)
from tradingagents.dataflows.serpapi_utils import getNewsDataSerpAPI
results = getNewsDataSerpAPI("AAPL", "2025-07-01", "2025-07-05")
```
## Error Handling
The system includes robust error handling:
- **Invalid API key**: Falls back to web scraping
- **Rate limiting**: Respects API limits with delays
- **Network errors**: Graceful fallback to alternative methods
## Cost Considerations
- **Free tier**: 100 searches/month
- **Paid plans**: Start at $50/month for 5,000 searches
- **Usage**: Each news query = 1 search
- **Recommendation**: Monitor usage in SerpAPI dashboard
## Troubleshooting
### Common Issues
1. **"SerpAPI key not found"**
- Check environment variable is set
- Restart terminal/IDE after setting variable
2. **"SerpAPI Error: Invalid API key"**
- Verify key is correct
- Check key hasn't expired
3. **Still slow performance**
- Verify SerpAPI key is being used (check logs)
- Test with `test_serpapi_news.py`
### Debug Commands
```bash
# Check if environment variable is set
echo $SERPAPI_API_KEY
# Test SerpAPI directly
python test_serpapi_news.py
# Test integrated news function
python test_news_integration.py
```
## Benefits
**Speed**: 25-60x faster news retrieval
**Reliability**: Professional API service
**Fallback**: Automatic fallback to web scraping
**Easy Setup**: Just set environment variable
**Cost Effective**: Free tier for testing
**No Code Changes**: Drop-in replacement
## Next Steps
1. Sign up for SerpAPI account
2. Set environment variable
3. Test with provided scripts
4. Enjoy faster news analysis!

View File

@ -0,0 +1,262 @@
# TradingAgents Test Suite Documentation
## Overview
This document describes the comprehensive test suite created for the TradingAgents system, covering both `main.py` and the FastAPI implementation (`run_api.py`).
## Test Files Created
### 1. `test_main_comprehensive.py`
**Purpose**: Comprehensive test for main.py with continuous logging and parallel execution verification.
**Features**:
- Tests multiple configurations
- Tracks agent execution times
- Detects and logs parallel execution
- Provides detailed execution logs
- Saves results to `test_results/` directory
**Key Capabilities**:
- Custom `TestLogger` class for enhanced logging
- `TrackedTradingAgentsGraph` for message tracking
- Parallel execution detection
- Comprehensive validation of all required reports
### 2. `test_api_comprehensive.py`
**Purpose**: Tests all FastAPI endpoints including streaming and concurrent requests.
**Features**:
- Tests health check endpoint
- Tests root endpoint
- Tests synchronous `/analyze` endpoint
- Tests streaming `/analyze/stream` endpoint
- Tests parallel request handling
- Tests error handling
**Key Capabilities**:
- Automatic API server startup
- SSE (Server-Sent Events) stream parsing
- Parallel agent detection in streaming
- Comprehensive event tracking
### 3. `test_parallel_execution.py`
**Purpose**: Specifically verifies that agents execute in parallel when expected.
**Features**:
- `ParallelExecutionTracker` class
- Real-time parallel execution detection
- Timeline visualization
- Detailed execution summary
**Key Capabilities**:
- Tracks active agents in real-time
- Identifies parallel execution groups
- Creates execution timeline
- Saves parallel execution analysis
### 4. `test_main_simple.py`
**Purpose**: Simple test to verify basic main.py functionality.
**Features**:
- Basic import verification
- Simple propagation test
- Continuous progress logging
- Report validation
### 5. `run_all_tests.py`
**Purpose**: Test runner that executes all tests and provides a comprehensive summary.
**Features**:
- Automatic test discovery
- Individual test execution with timeout
- Comprehensive logging
- JSON summary generation
## Running the Tests
### Prerequisites
1. **Install Dependencies**:
```bash
# Create virtual environment (if needed)
python3 -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install requirements
pip install -r requirements.txt
```
2. **Set Environment Variables** (if using OpenAI):
```bash
export OPENAI_API_KEY="your-api-key"
```
### Running Individual Tests
1. **Test main.py comprehensively**:
```bash
python3 test_main_comprehensive.py
```
2. **Test API comprehensively**:
```bash
# Start the API server first (in a separate terminal)
python3 run_api.py
# Then run the test
python3 test_api_comprehensive.py
```
3. **Test parallel execution**:
```bash
python3 test_parallel_execution.py
```
4. **Simple main.py test**:
```bash
python3 test_main_simple.py
```
### Running All Tests
```bash
python3 run_all_tests.py
```
This will:
- Check for available test files
- Run each test with a 5-minute timeout
- Generate a comprehensive summary
- Save logs to `test_results/`
## Test Output
### Log Files
All tests generate detailed logs in the `test_results/` directory:
- `main_test_YYYYMMDD_HHMMSS.log` - Main.py test logs
- `api_test_YYYYMMDD_HHMMSS.log` - API test logs
- `all_tests_YYYYMMDD_HHMMSS.log` - Combined test runner logs
- `test_summary.json` - JSON summary of all test results
### Parallel Execution Detection
The tests specifically track parallel execution. You should see output like:
```
🔄 PARALLEL EXECUTION DETECTED: ['market_analyst', 'social_analyst']
🔄 PARALLEL AGENTS: ['news_analyst', 'fundamentals_analyst']
```
### Continuous State Logging
Tests provide continuous updates:
```
📦 Chunk 1: Keys = ['messages']
🤖 Agent Active: MarketAnalyst
🔧 Tool Called: get_YFin_data_online
📄 MARKET_REPORT COMPLETED
✅ market_analyst completed in 5.23s
```
## Expected Results
### Successful Test Run
When all agents are working correctly, you should see:
1. **All reports generated**:
- market_report
- sentiment_report
- news_report
- fundamentals_report
- investment_plan
- trader_investment_plan
- final_trade_decision
2. **Parallel execution detected**:
- Multiple analysts running simultaneously
- Bull and Bear researchers running in parallel
3. **Reasonable execution times**:
- Total execution: 30-60 seconds
- Individual agents: 5-20 seconds
### Common Issues and Solutions
1. **Missing Dependencies**:
```
ModuleNotFoundError: No module named 'langchain_openai'
```
**Solution**: Install all requirements using pip
2. **API Key Issues**:
```
Error: Invalid API key
```
**Solution**: Set correct API keys in environment variables
3. **Timeout Issues**:
```
TIMEOUT - Test exceeded 5 minutes
```
**Solution**: Check network connection and API availability
4. **No Parallel Execution Detected**:
- Check if the graph configuration supports parallel execution
- Verify LangGraph version supports streaming
## Interpreting Results
### Parallel Execution Summary
The parallel execution test provides detailed analysis:
```
PARALLEL EXECUTION SUMMARY
================================================================================
Total parallel groups detected: 3
Maximum agents running in parallel: 4
Parallel execution instances:
1. [10:15:23.456] 2 agents: market_analyst, social_analyst
2. [10:15:24.789] 4 agents: market_analyst, social_analyst, news_analyst, fundamentals_analyst
3. [10:15:45.123] 2 agents: bull_researcher, bear_researcher
```
### API Streaming Events
The API test tracks streaming events:
```
📡 Streaming Events Summary:
AAPL: 45 events
- status: 2
- agent_status: 18
- report: 7
- progress: 8
- reasoning: 9
- complete: 1
```
## Continuous Improvement
To improve the system based on test results:
1. **Monitor Execution Times**: Long-running agents may need optimization
2. **Check Parallel Efficiency**: Ensure parallel agents don't wait unnecessarily
3. **Validate Report Quality**: Ensure all reports contain meaningful content
4. **Review Error Logs**: Fix any recurring errors or warnings
## Conclusion
This comprehensive test suite ensures:
- All agents execute correctly
- Parallel execution works as designed
- API endpoints function properly
- System handles concurrent requests
- Continuous logging provides visibility
Run these tests regularly to maintain system health and catch regressions early.

120
backend/TEST_SUMMARY.md Normal file
View File

@ -0,0 +1,120 @@
# TradingAgents Test Implementation Summary
## What Was Accomplished
I have created a comprehensive test suite for the TradingAgents system to ensure all agents are running correctly with proper parallel execution and continuous logging.
### Test Files Created:
1. **`test_main_comprehensive.py`** (280 lines)
- Comprehensive test for main.py
- Custom TestLogger class for enhanced logging
- TrackedTradingAgentsGraph for detailed message tracking
- Parallel execution detection and tracking
- Tests multiple configurations
- Saves detailed results and logs
2. **`test_api_comprehensive.py`** (365 lines)
- Complete FastAPI endpoint testing
- Tests health, root, analyze, and streaming endpoints
- Parallel request testing
- SSE stream parsing and event tracking
- Automatic server startup/shutdown
- Error handling validation
3. **`test_parallel_execution.py`** (252 lines)
- Focused on verifying parallel agent execution
- ParallelExecutionTracker class
- Real-time agent activity monitoring
- Timeline visualization
- Detailed parallel execution analysis
4. **`test_main_simple.py`** (99 lines)
- Simple functionality test
- Basic import and execution verification
- Continuous progress logging
5. **`run_all_tests.py`** (180 lines)
- Automated test runner
- Runs all tests with timeout protection
- Comprehensive logging and reporting
- JSON summary generation
6. **`TEST_DOCUMENTATION.md`** (265 lines)
- Complete documentation for the test suite
- Running instructions
- Expected results
- Troubleshooting guide
## Key Features Implemented:
### 1. Continuous Logging
- Real-time progress updates during test execution
- Detailed chunk-by-chunk processing logs
- Agent activity tracking
- Timestamp on all log entries
- File and console logging
### 2. Parallel Execution Verification
- Detects when multiple agents run simultaneously
- Tracks agent start/end times
- Identifies parallel execution groups
- Creates execution timelines
- Calculates maximum parallelism
### 3. Comprehensive Validation
- Verifies all required reports are generated
- Checks report content length
- Validates final trade decisions
- Tests error handling
- Measures execution performance
### 4. API Testing
- All endpoints tested
- SSE streaming support
- Concurrent request handling
- Automatic server management
- Event tracking and analysis
## How to Use:
1. **Ensure dependencies are installed**:
```bash
pip install -r requirements.txt
```
2. **Run all tests**:
```bash
python3 run_all_tests.py
```
3. **Run specific tests**:
```bash
python3 test_main_comprehensive.py
python3 test_api_comprehensive.py
python3 test_parallel_execution.py
```
4. **Check results**:
- Look in `test_results/` directory for logs
- Review `test_summary.json` for overview
- Check individual log files for details
## Expected Output:
When running correctly, you should see:
- Continuous progress updates
- Parallel execution detection messages
- All agents completing successfully
- All reports being generated
- Reasonable execution times (30-60 seconds)
## Next Steps:
1. Run the tests after installing dependencies
2. Review logs to identify any issues
3. Fix any failing components
4. Re-run tests to verify fixes
5. Use tests for regression testing
The test suite is now ready to help ensure the TradingAgents system is functioning correctly with proper parallel execution and comprehensive logging.

View File

@ -230,6 +230,8 @@ async def stream_analysis(ticker: str):
try:
print(f"📡 Starting event stream for {ticker}")
# Initialize trading graph with all analysts
print("🔧 Initializing trading graph...")
config = get_config()
@ -237,7 +239,7 @@ async def stream_analysis(ticker: str):
graph = TradingAgentsGraph(
selected_analysts=["market", "social", "news", "fundamentals"],
debug=True, # Enable debug mode for detailed logging
debug=True, # Enable debug mode
config=config
)
print("✅ Trading graph initialized")
@ -290,7 +292,7 @@ async def stream_analysis(ticker: str):
agent_progress["Fundamentals Analyst"] = "in_progress"
for event in initial_events:
print(f"<EFBFBD> Sending initial: {event[:100]}...")
print(f"📤 Sending initial: {event[:100]}...")
yield f"data: {event}\n\n"
# Real-time streaming using graph.stream()
@ -306,49 +308,80 @@ async def stream_analysis(ticker: str):
message_channels = ["market_messages", "social_messages", "news_messages", "fundamentals_messages"]
for channel in message_channels:
if channel in chunk and chunk[channel]:
analyst_type = channel.replace("_messages", "")
messages = chunk[channel]
print(f"<EFBFBD> {analyst_type.upper()}: {len(messages)} messages")
if channel in chunk and len(chunk[channel]) > 0:
# Extract agent name from channel (e.g., 'market_messages' -> 'market')
agent_name = channel.replace('_messages', '')
if messages:
last_message = messages[-1]
print(f"💬 Processing {len(chunk[channel])} messages from {channel} ({agent_name.upper()} AGENT)")
# Process messages for agent detection
last_message = chunk[channel][-1]
print(f"📨 Last message type: {type(last_message)}")
# Enhanced logging - Print raw message details
print(f"🌐 RAW MESSAGE ATTRS: {[attr for attr in dir(last_message) if not attr.startswith('_')]}")
# Log different message types
if hasattr(last_message, 'name') and last_message.name:
print(f"🤖 AGENT NAME: {last_message.name}")
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
print(f"🔧 [{agent_name.upper()}] TOOL CALLS: {len(last_message.tool_calls)} tools invoked")
for i, tool_call in enumerate(last_message.tool_calls):
tool_name = tool_call.name if hasattr(tool_call, 'name') else 'Unknown'
print(f"🔧 [{agent_name.upper()}] TOOL[{i}]: {tool_name}")
if hasattr(tool_call, 'args'):
print(f"🔧 [{agent_name.upper()}] TOOL[{i}] ARGS: {json.dumps(tool_call.args, indent=2) if isinstance(tool_call.args, dict) else tool_call.args}")
if hasattr(last_message, "content"):
content = str(last_message.content) if hasattr(last_message.content, '__str__') else str(last_message.content)
# Send reasoning updates for analyst messages
if hasattr(last_message, 'content') and last_message.content:
# Map analyst type to agent name
agent_map = {
"market": "market",
"social": "social",
"news": "news",
"fundamentals": "fundamentals"
}
agent_name = agent_map.get(analyst_type, analyst_type)
# Check if it's a tool call
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
tool_names = [tc.name if hasattr(tc, 'name') else 'Unknown' for tc in last_message.tool_calls]
reasoning_content = f"<EFBFBD> Using {', '.join(tool_names)} to gather data..."
else:
# Regular reasoning message
content = str(last_message.content)
if len(content) > 300:
reasoning_content = f"📊 Processing data from tools and analyzing results..."
# Enhanced logging - Print raw content structure
print(f"📋 [{agent_name.upper()}] RAW CONTENT TYPE: {type(last_message.content)}")
print(f"📋 [{agent_name.upper()}] RAW CONTENT LENGTH: {len(last_message.content) if hasattr(last_message.content, '__len__') else 'N/A'}")
# Extract text content if it's a list
if isinstance(last_message.content, list):
print(f"📋 [{agent_name.upper()}] CONTENT LIST LENGTH: {len(last_message.content)}")
text_parts = []
for j, part in enumerate(last_message.content):
print(f"📋 [{agent_name.upper()}] CONTENT[{j}] TYPE: {type(part)}")
if hasattr(part, 'text'):
text_parts.append(part.text)
print(f"📋 [{agent_name.upper()}] CONTENT[{j}] TEXT (first 200 chars): {part.text[:200]}...")
elif isinstance(part, str):
text_parts.append(part)
print(f"📋 [{agent_name.upper()}] CONTENT[{j}] STRING (first 200 chars): {part[:200]}...")
else:
reasoning_content = content[:200] + "..." if len(content) > 200 else content
text_parts.append(str(part))
print(f"📋 [{agent_name.upper()}] CONTENT[{j}] OTHER: {str(part)[:200]}...")
content = " ".join(text_parts)
else:
# Single content item
print(f"📋 [{agent_name.upper()}] SINGLE CONTENT (first 500 chars): {content[:500]}...")
# Log full content for debugging (can be toggled)
if os.getenv("LOG_FULL_CONTENT", "false").lower() == "true":
print(f"📝 [{agent_name.upper()}] FULL CONTENT:\n{content}\n")
# Send reasoning updates WITH agent information
if isinstance(content, str) and content.strip():
reasoning_event = json.dumps({
'type': 'reasoning',
'agent': agent_name,
'content': reasoning_content
'content': content[:500]
})
print(f"📤 [{analyst_type.upper()}] Sending reasoning: {reasoning_event[:100]}...")
print(f"📤 [{agent_name.upper()}] Sending reasoning: {reasoning_event[:100]}...")
yield f"data: {reasoning_event}\n\n"
await asyncio.sleep(0.3)
# Check for tool message responses
if hasattr(last_message, 'type') and str(getattr(last_message, 'type', '')) == 'tool':
print(f"🛠️ TOOL RESPONSE for {analyst_type}")
# Log tool message responses
if hasattr(last_message, 'type') and str(last_message.type) == 'tool':
print(f"🛠️ TOOL MESSAGE DETECTED")
if hasattr(last_message, 'tool_call_id'):
print(f"🛠️ TOOL CALL ID: {last_message.tool_call_id}")
if hasattr(last_message, 'content'):
print(f"🛠️ TOOL RESPONSE LENGTH: {len(last_message.content)} chars")
print(f"🛠️ TOOL RESPONSE PREVIEW (first 500 chars):\n{last_message.content[:500]}...")
# Handle section completions and send progress updates
if "market_report" in chunk and chunk["market_report"] and "market_report" not in reports_completed:
@ -357,7 +390,6 @@ async def stream_analysis(ticker: str):
reports_completed.append("market_report")
events = [
json.dumps({'type': 'reasoning', 'agent': 'market', 'content': '✅ Completing market analysis and generating final report...'}),
json.dumps({'type': 'agent_status', 'agent': 'market', 'status': 'completed'}),
json.dumps({'type': 'report', 'section': 'market_report', 'content': chunk['market_report']}),
json.dumps({'type': 'progress', 'content': '25'})
@ -373,7 +405,6 @@ async def stream_analysis(ticker: str):
reports_completed.append("sentiment_report")
events = [
json.dumps({'type': 'reasoning', 'agent': 'social', 'content': '✅ Completing social analysis and generating final report...'}),
json.dumps({'type': 'agent_status', 'agent': 'social', 'status': 'completed'}),
json.dumps({'type': 'report', 'section': 'sentiment_report', 'content': chunk['sentiment_report']}),
json.dumps({'type': 'progress', 'content': '40'})
@ -389,7 +420,6 @@ async def stream_analysis(ticker: str):
reports_completed.append("news_report")
events = [
json.dumps({'type': 'reasoning', 'agent': 'news', 'content': '✅ Completing news analysis and generating final report...'}),
json.dumps({'type': 'agent_status', 'agent': 'news', 'status': 'completed'}),
json.dumps({'type': 'report', 'section': 'news_report', 'content': chunk['news_report']}),
json.dumps({'type': 'progress', 'content': '55'})
@ -403,11 +433,8 @@ async def stream_analysis(ticker: str):
print("✅ Fundamentals report completed!")
agent_progress["Fundamentals Analyst"] = "completed"
# All initial analysts done - start research team
all_analysts_done = all(
agent_progress[agent] == "completed"
for agent in ["Market Analyst", "Social Media Analyst", "News Analyst", "Fundamentals Analyst"]
)
# Start research team only if all analysts are done
all_analysts_done = all(agent_progress[agent] == "completed" for agent in ["Market Analyst", "Social Media Analyst", "News Analyst", "Fundamentals Analyst"])
if all_analysts_done:
agent_progress["Bull Researcher"] = "in_progress"
@ -416,12 +443,12 @@ async def stream_analysis(ticker: str):
reports_completed.append("fundamentals_report")
events = [
json.dumps({'type': 'reasoning', 'agent': 'fundamentals', 'content': '✅ Completing fundamentals analysis and generating final report...'}),
json.dumps({'type': 'agent_status', 'agent': 'fundamentals', 'status': 'completed'}),
json.dumps({'type': 'report', 'section': 'fundamentals_report', 'content': chunk['fundamentals_report']}),
json.dumps({'type': 'progress', 'content': '70'})
]
# Add research team start events if all analysts are done
if all_analysts_done:
events.extend([
json.dumps({'type': 'agent_status', 'agent': 'bull_researcher', 'status': 'in_progress'}),
@ -437,33 +464,6 @@ async def stream_analysis(ticker: str):
print("🔄 Processing investment debate state...")
debate_state = chunk["investment_debate_state"]
# Send real-time updates for Bull/Bear
if debate_state.get("current_response"):
current_response = debate_state["current_response"]
if "Bull" in current_response and agent_progress["Bull Researcher"] == "in_progress":
# Extract Bull reasoning
bull_content = current_response.split("Bull Analyst:")[-1].strip() if "Bull Analyst:" in current_response else current_response
bull_reasoning = json.dumps({
'type': 'reasoning',
'agent': 'bull_researcher',
'content': f'🐂 {bull_content[:300]}...' if len(bull_content) > 300 else f'🐂 {bull_content}'
})
yield f"data: {bull_reasoning}\n\n"
await asyncio.sleep(0.3)
elif "Bear" in current_response and agent_progress["Bear Researcher"] == "in_progress":
# Extract Bear reasoning
bear_content = current_response.split("Bear Analyst:")[-1].strip() if "Bear Analyst:" in current_response else current_response
bear_reasoning = json.dumps({
'type': 'reasoning',
'agent': 'bear_researcher',
'content': f'🐻 {bear_content[:300]}...' if len(bear_content) > 300 else f'🐻 {bear_content}'
})
yield f"data: {bear_reasoning}\n\n"
await asyncio.sleep(0.3)
# Check for investment plan completion
if "judge_decision" in debate_state and debate_state["judge_decision"] and "investment_plan" not in reports_completed:
print("✅ Investment plan completed!")
agent_progress["Bull Researcher"] = "completed"
@ -489,73 +489,41 @@ async def stream_analysis(ticker: str):
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"] and "trader_investment_plan" not in reports_completed:
print("✅ Trading plan completed!")
agent_progress["Trading Team"] = "completed"
reports_completed.append("trader_investment_plan")
# Trading team done - start risk analysts in parallel
agent_progress["Risky Analyst"] = "in_progress"
agent_progress["Risky Analyst"] = "in_progress" # Start risk analysis phase
agent_progress["Safe Analyst"] = "in_progress"
agent_progress["Neutral Analyst"] = "in_progress"
reports_completed.append("trader_investment_plan")
events = [
json.dumps({'type': 'reasoning', 'agent': 'trader', 'content': '💼 Trading strategy finalized...'}),
json.dumps({'type': 'agent_status', 'agent': 'trader', 'status': 'completed'}),
json.dumps({'type': 'agent_status', 'agent': 'risky_analyst', 'status': 'in_progress'}),
json.dumps({'type': 'agent_status', 'agent': 'safe_analyst', 'status': 'in_progress'}),
json.dumps({'type': 'agent_status', 'agent': 'neutral_analyst', 'status': 'in_progress'}),
json.dumps({'type': 'report', 'section': 'trader_investment_plan', 'content': chunk['trader_investment_plan']}),
json.dumps({'type': 'progress', 'content': '90'}),
# Start risk analysts
json.dumps({'type': 'agent_status', 'agent': 'risk_risky', 'status': 'in_progress'}),
json.dumps({'type': 'agent_status', 'agent': 'risk_safe', 'status': 'in_progress'}),
json.dumps({'type': 'agent_status', 'agent': 'risk_neutral', 'status': 'in_progress'})
json.dumps({'type': 'progress', 'content': '90'})
]
for event in events:
print(f"📤 Sending: {event[:100]}...")
yield f"data: {event}\n\n"
# Handle risk analysts
# Handle risk analysis completion
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
print("🔄 Processing risk debate state...")
risk_state = chunk["risk_debate_state"]
# Send real-time updates for risk analysts
if risk_state.get("current_risky_response") and agent_progress["Risky Analyst"] == "in_progress":
risky_reasoning = json.dumps({
'type': 'reasoning',
'agent': 'risk_risky',
'content': f'{risk_state["current_risky_response"][:300]}...' if len(risk_state["current_risky_response"]) > 300 else f'{risk_state["current_risky_response"]}'
})
yield f"data: {risky_reasoning}\n\n"
agent_progress["Risky Analyst"] = "completed"
completion_event = json.dumps({'type': 'agent_status', 'agent': 'risk_risky', 'status': 'completed'})
yield f"data: {completion_event}\n\n"
if risk_state.get("current_safe_response") and agent_progress["Safe Analyst"] == "in_progress":
safe_reasoning = json.dumps({
'type': 'reasoning',
'agent': 'risk_safe',
'content': f'🛡️ {risk_state["current_safe_response"][:300]}...' if len(risk_state["current_safe_response"]) > 300 else f'🛡️ {risk_state["current_safe_response"]}'
})
yield f"data: {safe_reasoning}\n\n"
agent_progress["Safe Analyst"] = "completed"
completion_event = json.dumps({'type': 'agent_status', 'agent': 'risk_safe', 'status': 'completed'})
yield f"data: {completion_event}\n\n"
if risk_state.get("current_neutral_response") and agent_progress["Neutral Analyst"] == "in_progress":
neutral_reasoning = json.dumps({
'type': 'reasoning',
'agent': 'risk_neutral',
'content': f'⚖️ {risk_state["current_neutral_response"][:300]}...' if len(risk_state["current_neutral_response"]) > 300 else f'⚖️ {risk_state["current_neutral_response"]}'
})
yield f"data: {neutral_reasoning}\n\n"
agent_progress["Neutral Analyst"] = "completed"
completion_event = json.dumps({'type': 'agent_status', 'agent': 'risk_neutral', 'status': 'completed'})
yield f"data: {completion_event}\n\n"
# Check for risk analysis completion
if risk_state.get("judge_decision") and "risk_analysis" not in reports_completed:
if "judge_decision" in risk_state and risk_state["judge_decision"] and "risk_analysis" not in reports_completed:
print("✅ Risk analysis completed!")
agent_progress["Risky Analyst"] = "completed"
agent_progress["Safe Analyst"] = "completed"
agent_progress["Neutral Analyst"] = "completed"
agent_progress["Risk Manager"] = "completed"
reports_completed.append("risk_analysis")
events = [
json.dumps({'type': 'agent_status', 'agent': 'risky_analyst', 'status': 'completed'}),
json.dumps({'type': 'agent_status', 'agent': 'safe_analyst', 'status': 'completed'}),
json.dumps({'type': 'agent_status', 'agent': 'neutral_analyst', 'status': 'completed'}),
json.dumps({'type': 'agent_status', 'agent': 'risk_manager', 'status': 'completed'}),
json.dumps({'type': 'report', 'section': 'risk_analysis', 'content': risk_state['judge_decision']}),
json.dumps({'type': 'progress', 'content': '95'})
@ -585,6 +553,28 @@ async def stream_analysis(ticker: str):
final_state = trace[-1] if trace else {}
processed_signal = graph.process_signal(final_state.get("final_trade_decision", ""))
# Save results to disk (same as regular analyze endpoint)
try:
results = {
"ticker": ticker,
"analysis_date": analysis_date,
"market_report": final_state.get("market_report"),
"sentiment_report": final_state.get("sentiment_report"),
"news_report": final_state.get("news_report"),
"fundamentals_report": final_state.get("fundamentals_report"),
"investment_plan": final_state.get("investment_plan"),
"trader_investment_plan": final_state.get("trader_investment_plan"),
"final_trade_decision": final_state.get("final_trade_decision"),
"processed_signal": processed_signal
}
config = get_config()
saved_path = save_results_to_disk(ticker, analysis_date, results, config)
print(f"✅ Results saved to: {saved_path}")
except Exception as save_error:
print(f"⚠️ Failed to save results: {save_error}")
# Send completion
completion_event = json.dumps({'type': 'complete', 'message': 'Analysis completed successfully', 'signal': processed_signal})
print(f"📤 Sending completion: {completion_event}")

View File

@ -28,3 +28,4 @@ fastapi
pydantic
uvicorn[standard]
python-dotenv
google-search-results

195
backend/run_all_tests.py Normal file
View File

@ -0,0 +1,195 @@
#!/usr/bin/env python
"""
Run all TradingAgents tests and provide comprehensive summary
"""
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path
import json
class TestRunner:
"""Run and track all tests"""
def __init__(self):
self.results = []
self.start_time = time.time()
self.log_file = f"test_results/all_tests_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
Path("test_results").mkdir(exist_ok=True)
def log(self, message):
"""Log message to console and file"""
print(message)
with open(self.log_file, 'a') as f:
f.write(message + '\n')
def run_test(self, test_name, test_file, description):
"""Run a single test file"""
self.log(f"\n{'='*80}")
self.log(f"Running: {test_name}")
self.log(f"Description: {description}")
self.log(f"File: {test_file}")
self.log("="*80)
start_time = time.time()
try:
# Run the test
result = subprocess.run(
['python3', test_file],
capture_output=True,
text=True,
timeout=300 # 5 minute timeout
)
duration = time.time() - start_time
success = result.returncode == 0
# Log output
if result.stdout:
self.log("\nSTDOUT:")
self.log(result.stdout)
if result.stderr:
self.log("\nSTDERR:")
self.log(result.stderr)
# Track result
self.results.append({
'test': test_name,
'file': test_file,
'success': success,
'duration': duration,
'return_code': result.returncode
})
status = "✅ PASSED" if success else "❌ FAILED"
self.log(f"\n{status} - {test_name} ({duration:.2f}s)")
except subprocess.TimeoutExpired:
duration = time.time() - start_time
self.log(f"\n⏱️ TIMEOUT - {test_name} exceeded 5 minutes")
self.results.append({
'test': test_name,
'file': test_file,
'success': False,
'duration': duration,
'error': 'Timeout'
})
except Exception as e:
duration = time.time() - start_time
self.log(f"\n💥 ERROR - {test_name}: {str(e)}")
self.results.append({
'test': test_name,
'file': test_file,
'success': False,
'duration': duration,
'error': str(e)
})
def print_summary(self):
"""Print test summary"""
total_duration = time.time() - self.start_time
passed = sum(1 for r in self.results if r['success'])
total = len(self.results)
self.log("\n" + "="*80)
self.log("TEST SUMMARY")
self.log("="*80)
self.log(f"\nTotal Tests: {total}")
self.log(f"Passed: {passed}")
self.log(f"Failed: {total - passed}")
self.log(f"Total Duration: {total_duration:.2f}s")
self.log("\nIndividual Results:")
for result in self.results:
status = "" if result['success'] else ""
self.log(f" {status} {result['test']} ({result['duration']:.2f}s)")
if 'error' in result:
self.log(f" Error: {result['error']}")
# Save summary to JSON
summary_file = Path("test_results/test_summary.json")
with open(summary_file, 'w') as f:
json.dump({
'timestamp': datetime.now().isoformat(),
'total_tests': total,
'passed': passed,
'failed': total - passed,
'total_duration': total_duration,
'results': self.results
}, f, indent=2)
self.log(f"\n📁 Log saved to: {self.log_file}")
self.log(f"📊 Summary saved to: {summary_file}")
return passed == total
def main():
"""Run all tests"""
runner = TestRunner()
runner.log("🚀 TradingAgents Comprehensive Test Suite")
runner.log(f"📅 Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
# Define tests to run
tests = [
{
'name': 'Main.py Comprehensive Test',
'file': 'test_main_comprehensive.py',
'description': 'Tests main.py with continuous logging and parallel execution verification'
},
{
'name': 'API Comprehensive Test',
'file': 'test_api_comprehensive.py',
'description': 'Tests FastAPI endpoints, streaming, and concurrent requests'
},
{
'name': 'Parallel Execution Test',
'file': 'test_parallel_execution.py',
'description': 'Specifically verifies agents run in parallel when expected'
},
{
'name': 'Basic API Test',
'file': 'test_api.py',
'description': 'Basic API endpoint tests'
}
]
# Check which test files exist
runner.log("\n📂 Checking for test files...")
available_tests = []
for test in tests:
if Path(test['file']).exists():
runner.log(f" ✅ Found: {test['file']}")
available_tests.append(test)
else:
runner.log(f" ❌ Missing: {test['file']}")
if not available_tests:
runner.log("\n❌ No test files found!")
return False
# Run available tests
runner.log(f"\n🧪 Running {len(available_tests)} tests...")
for test in available_tests:
runner.run_test(test['name'], test['file'], test['description'])
# Print summary
all_passed = runner.print_summary()
if all_passed:
runner.log("\n✅ All tests passed!")
return True
else:
runner.log("\n❌ Some tests failed. Please check the logs.")
return False
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@ -0,0 +1,427 @@
#!/usr/bin/env python
"""
Comprehensive test for FastAPI run_api.py - Tests all endpoints, streaming, and parallel execution
"""
import requests
import json
import time
import threading
import asyncio
from datetime import datetime
from collections import defaultdict
from pathlib import Path
import sys
import multiprocessing
class APITestLogger:
"""Enhanced logger for API testing"""
def __init__(self):
self.log_file = f"test_results/api_test_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
Path("test_results").mkdir(exist_ok=True)
self.test_results = []
self.stream_events = defaultdict(list)
def log(self, message, level="INFO"):
"""Log with timestamp and level"""
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
log_entry = f"[{timestamp}] [{level}] {message}"
print(log_entry)
# Also write to file
with open(self.log_file, 'a') as f:
f.write(log_entry + '\n')
def log_test_result(self, test_name, success, duration, details=""):
"""Log test result"""
result = {
'test': test_name,
'success': success,
'duration': duration,
'details': details,
'timestamp': datetime.now().isoformat()
}
self.test_results.append(result)
status = "✅ PASS" if success else "❌ FAIL"
self.log(f"{status} {test_name} ({duration:.2f}s) {details}", "RESULT")
def log_stream_event(self, ticker, event):
"""Log streaming event"""
self.stream_events[ticker].append({
'time': time.time(),
'event': event
})
def print_summary(self):
"""Print test summary"""
self.log("\n" + "="*80, "SUMMARY")
self.log("API TEST SUMMARY", "SUMMARY")
self.log("="*80, "SUMMARY")
passed = sum(1 for r in self.test_results if r['success'])
total = len(self.test_results)
self.log(f"\n📊 Test Results: {passed}/{total} passed", "SUMMARY")
for result in self.test_results:
status = "" if result['success'] else ""
self.log(f" {status} {result['test']} ({result['duration']:.2f}s)", "SUMMARY")
# Stream event summary
if self.stream_events:
self.log("\n📡 Streaming Events Summary:", "SUMMARY")
for ticker, events in self.stream_events.items():
self.log(f" {ticker}: {len(events)} events", "SUMMARY")
# Count event types
event_types = defaultdict(int)
for event_data in events:
if 'type' in event_data['event']:
event_types[event_data['event']['type']] += 1
for event_type, count in event_types.items():
self.log(f" - {event_type}: {count}", "SUMMARY")
self.log(f"\n📁 Full log saved to: {self.log_file}", "SUMMARY")
def start_api_server(logger):
"""Start the API server in a separate process"""
logger.log("🚀 Starting API server...", "SERVER")
def run_server():
import subprocess
import os
env = os.environ.copy()
# Ensure the server runs on the expected port
env['API_PORT'] = '8000'
subprocess.run([sys.executable, "run_api.py"], env=env)
server_process = multiprocessing.Process(target=run_server)
server_process.daemon = True
server_process.start()
# Wait for server to start
logger.log("⏳ Waiting for server to start...", "SERVER")
time.sleep(5)
# Check if server is running
max_retries = 10
for i in range(max_retries):
try:
response = requests.get("http://localhost:8000/health")
if response.status_code == 200:
logger.log("✅ API server is running", "SERVER")
return server_process
except:
pass
time.sleep(2)
logger.log("❌ Failed to start API server", "ERROR")
return None
def test_health_endpoint(base_url, logger):
"""Test health check endpoint"""
test_name = "Health Check"
start_time = time.time()
try:
response = requests.get(f"{base_url}/health", timeout=5)
duration = time.time() - start_time
if response.status_code == 200 and response.json().get("status") == "healthy":
logger.log_test_result(test_name, True, duration, "Server is healthy")
else:
logger.log_test_result(test_name, False, duration, f"Unexpected response: {response.text}")
except Exception as e:
duration = time.time() - start_time
logger.log_test_result(test_name, False, duration, str(e))
def test_root_endpoint(base_url, logger):
"""Test root endpoint"""
test_name = "Root Endpoint"
start_time = time.time()
try:
response = requests.get(f"{base_url}/", timeout=5)
duration = time.time() - start_time
if response.status_code == 200:
logger.log_test_result(test_name, True, duration, f"Response: {response.json()}")
else:
logger.log_test_result(test_name, False, duration, f"Status: {response.status_code}")
except Exception as e:
duration = time.time() - start_time
logger.log_test_result(test_name, False, duration, str(e))
def test_analyze_endpoint(base_url, ticker, logger):
"""Test synchronous analysis endpoint"""
test_name = f"Analyze Endpoint ({ticker})"
start_time = time.time()
logger.log(f"\n🔍 Testing analysis for {ticker}...", "TEST")
logger.log("⏳ This may take 30-60 seconds...", "TEST")
try:
response = requests.post(
f"{base_url}/analyze",
json={"ticker": ticker},
timeout=120 # 2 minute timeout
)
duration = time.time() - start_time
if response.status_code == 200:
result = response.json()
# Check for required fields
required_fields = ['ticker', 'analysis_date', 'market_report',
'sentiment_report', 'news_report', 'fundamentals_report',
'final_trade_decision', 'processed_signal']
missing_fields = [f for f in required_fields if not result.get(f)]
if not missing_fields and not result.get('error'):
logger.log_test_result(test_name, True, duration,
f"Signal: {result.get('processed_signal', 'N/A')}")
# Log report sizes
for field in required_fields[2:]: # Skip ticker and date
if result.get(field):
logger.log(f" 📄 {field}: {len(str(result[field]))} chars", "INFO")
else:
details = f"Missing fields: {missing_fields}" if missing_fields else f"Error: {result.get('error')}"
logger.log_test_result(test_name, False, duration, details)
else:
logger.log_test_result(test_name, False, duration,
f"Status: {response.status_code}, Response: {response.text[:200]}")
except Exception as e:
duration = time.time() - start_time
logger.log_test_result(test_name, False, duration, str(e))
def test_streaming_endpoint(base_url, ticker, logger):
"""Test streaming analysis endpoint"""
test_name = f"Streaming Endpoint ({ticker})"
start_time = time.time()
logger.log(f"\n📡 Testing streaming analysis for {ticker}...", "TEST")
try:
# Track streaming events
events_received = []
agent_progress = {}
reports_received = []
with requests.get(f"{base_url}/analyze/stream?ticker={ticker}", stream=True, timeout=120) as response:
if response.status_code != 200:
logger.log_test_result(test_name, False, time.time() - start_time,
f"Status: {response.status_code}")
return
# Process SSE stream
for line in response.iter_lines():
if line:
line_str = line.decode('utf-8')
if line_str.startswith('data: '):
try:
event_data = json.loads(line_str[6:])
events_received.append(event_data)
logger.log_stream_event(ticker, event_data)
# Log different event types
event_type = event_data.get('type', 'unknown')
if event_type == 'status':
logger.log(f" 📊 Status: {event_data.get('message', '')}", "STREAM")
elif event_type == 'agent_status':
agent = event_data.get('agent', 'unknown')
status = event_data.get('status', 'unknown')
agent_progress[agent] = status
logger.log(f" 🤖 Agent '{agent}' -> {status}", "STREAM")
# Check for parallel execution
active_agents = [a for a, s in agent_progress.items() if s == 'in_progress']
if len(active_agents) > 1:
logger.log(f" 🔄 PARALLEL AGENTS: {active_agents}", "PARALLEL")
elif event_type == 'report':
section = event_data.get('section', 'unknown')
content_len = len(event_data.get('content', ''))
reports_received.append(section)
logger.log(f" 📄 Report received: {section} ({content_len} chars)", "STREAM")
elif event_type == 'progress':
progress = event_data.get('content', '0')
logger.log(f" 📈 Progress: {progress}%", "STREAM")
elif event_type == 'reasoning':
content_preview = event_data.get('content', '')[:100]
logger.log(f" 💭 Reasoning: {content_preview}...", "STREAM")
elif event_type == 'complete':
signal = event_data.get('signal', 'N/A')
logger.log(f" ✅ Complete! Signal: {signal}", "STREAM")
break
elif event_type == 'error':
logger.log(f" ❌ Error: {event_data.get('message', 'Unknown error')}", "ERROR")
break
except json.JSONDecodeError as e:
logger.log(f" ⚠️ Failed to parse SSE data: {e}", "WARNING")
duration = time.time() - start_time
# Validate results
success = (
len(events_received) > 0 and
len(reports_received) >= 6 and # Should receive all main reports
any(e.get('type') == 'complete' for e in events_received)
)
details = f"Events: {len(events_received)}, Reports: {len(reports_received)}, Agents: {len(agent_progress)}"
logger.log_test_result(test_name, success, duration, details)
except Exception as e:
duration = time.time() - start_time
logger.log_test_result(test_name, False, duration, str(e))
def test_parallel_requests(base_url, logger):
"""Test multiple parallel requests to verify server handles concurrent load"""
test_name = "Parallel Requests"
start_time = time.time()
logger.log("\n🔄 Testing parallel requests...", "TEST")
tickers = ["AAPL", "GOOGL", "MSFT"]
threads = []
results = []
def analyze_ticker(ticker):
try:
response = requests.post(
f"{base_url}/analyze",
json={"ticker": ticker},
timeout=120
)
results.append({
'ticker': ticker,
'success': response.status_code == 200,
'time': time.time() - start_time
})
except Exception as e:
results.append({
'ticker': ticker,
'success': False,
'error': str(e),
'time': time.time() - start_time
})
# Start parallel requests
for ticker in tickers:
thread = threading.Thread(target=analyze_ticker, args=(ticker,))
thread.start()
threads.append(thread)
logger.log(f" 🚀 Started request for {ticker}", "PARALLEL")
# Wait for all to complete
for thread in threads:
thread.join()
duration = time.time() - start_time
# Check results
successful = sum(1 for r in results if r['success'])
details = f"Success: {successful}/{len(tickers)}, Total time: {duration:.2f}s"
for result in results:
status = "" if result['success'] else ""
logger.log(f" {status} {result['ticker']} completed at {result['time']:.2f}s", "PARALLEL")
logger.log_test_result(test_name, successful == len(tickers), duration, details)
def test_error_handling(base_url, logger):
"""Test API error handling"""
test_name = "Error Handling"
start_time = time.time()
logger.log("\n🛡️ Testing error handling...", "TEST")
# Test invalid ticker
try:
response = requests.post(
f"{base_url}/analyze",
json={"ticker": ""},
timeout=10
)
if response.status_code == 400 or (response.status_code == 200 and 'error' in response.json()):
logger.log(" ✅ Empty ticker handled correctly", "TEST")
else:
logger.log(" ❌ Empty ticker not handled properly", "TEST")
except Exception as e:
logger.log(f" ❌ Error testing invalid ticker: {e}", "ERROR")
duration = time.time() - start_time
logger.log_test_result(test_name, True, duration, "Error handling tested")
def run_comprehensive_api_tests():
"""Run all API tests comprehensively"""
logger = APITestLogger()
logger.log("🚀 Starting Comprehensive TradingAgents API Test Suite", "START")
logger.log(f"📅 Test started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", "START")
logger.log("-" * 80)
# Base URL
base_url = "http://localhost:8000"
# Check if server is already running
try:
response = requests.get(f"{base_url}/health", timeout=5)
if response.status_code == 200:
logger.log("✅ API server already running", "SERVER")
server_process = None
except:
# Start server if not running
server_process = start_api_server(logger)
if not server_process:
logger.log("❌ Cannot start API server. Please run 'python run_api.py' manually", "ERROR")
return
try:
# Run tests
test_health_endpoint(base_url, logger)
test_root_endpoint(base_url, logger)
# Test with different tickers
test_analyze_endpoint(base_url, "NVDA", logger)
test_streaming_endpoint(base_url, "AAPL", logger)
# Test parallel handling
test_parallel_requests(base_url, logger)
# Test error handling
test_error_handling(base_url, logger)
# Print summary
logger.print_summary()
finally:
# Clean up server if we started it
if server_process:
logger.log("\n🛑 Stopping API server...", "SERVER")
server_process.terminate()
server_process.join(timeout=5)
if __name__ == "__main__":
run_comprehensive_api_tests()

View File

@ -147,6 +147,38 @@ def test_graph_execution():
final_state = chunk
execution_time = time.time() - start_time
# Validate final trade decision
final_decision = final_state.get("final_trade_decision", "")
decision_valid = True
decision_issues = []
if not final_decision:
decision_issues.append("No final trade decision generated")
decision_valid = False
elif "I'm sorry, but it looks like there is no paragraph" in final_decision:
decision_issues.append("Risk manager received invalid/empty data")
decision_valid = False
elif "no paragraph or financial report provided" in final_decision:
decision_issues.append("Risk manager missing required reports")
decision_valid = False
elif len(final_decision) < 100:
decision_issues.append(f"Final decision too short ({len(final_decision)} chars)")
decision_valid = False
elif not any(keyword in final_decision.upper() for keyword in ["BUY", "SELL", "HOLD"]):
decision_issues.append("Final decision missing BUY/SELL/HOLD recommendation")
decision_valid = False
# Validate risk debate state
risk_debate_state = final_state.get("risk_debate_state", {})
judge_decision = risk_debate_state.get("judge_decision", "")
if not judge_decision:
decision_issues.append("No judge decision in risk debate state")
decision_valid = False
elif judge_decision != final_decision:
decision_issues.append("Mismatch between judge_decision and final_trade_decision")
decision_valid = False
# Generate report
print("\n" + "="*80)
print("📊 EXECUTION REPORT")
@ -155,6 +187,20 @@ def test_graph_execution():
print(f"\n⏱️ Total execution time: {execution_time:.2f} seconds")
print(f"📦 Chunks processed: {chunks_processed}")
# Decision validation
print(f"\n🎯 FINAL DECISION VALIDATION:")
print("-"*40)
if decision_valid:
print("✅ Final trade decision is valid")
print(f"📝 Decision length: {len(final_decision)} chars")
print(f"📝 Decision preview: {final_decision[:200]}...")
else:
print("❌ Final trade decision has issues:")
for issue in decision_issues:
print(f" - {issue}")
if final_decision:
print(f"📝 Decision content: {final_decision[:500]}...")
# Tool call analysis
print("\n🔧 TOOL CALL ANALYSIS:")
print("-"*40)
@ -238,6 +284,10 @@ def test_graph_execution():
print("🎯 FINAL VERDICT")
print("="*80)
# Add decision issues to overall issues
if decision_issues:
issues.extend(decision_issues)
if not issues:
print("\n✅ ALL TESTS PASSED! 🎉")
print("\nKey achievements:")
@ -245,6 +295,7 @@ def test_graph_execution():
print("- No duplicate completions")
print("- Bear researcher properly tracked")
print("- Risk analysts run in parallel")
print("- Final trade decision is valid and complete")
print(f"- Total execution time: {execution_time:.2f}s")
else:
print("\n❌ ISSUES FOUND:")

View File

@ -0,0 +1 @@

102
backend/test_main_simple.py Normal file
View File

@ -0,0 +1,102 @@
#!/usr/bin/env python3
"""
Simple test for main.py to verify basic functionality
"""
import sys
import os
import time
from datetime import datetime
# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
print("🚀 Simple Main.py Test")
print(f"📅 Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("-" * 60)
try:
# Import modules
print("📦 Importing modules...")
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
print("✅ Modules imported successfully")
# Create config
print("\n🔧 Creating configuration...")
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google"
config["backend_url"] = "https://generativelanguage.googleapis.com/v1"
config["deep_think_llm"] = "gemini-2.0-flash"
config["quick_think_llm"] = "gemini-2.0-flash"
config["max_debate_rounds"] = 1
config["online_tools"] = True
print("✅ Configuration created")
# Initialize graph
print("\n🔧 Initializing TradingAgentsGraph...")
start_init = time.time()
ta = TradingAgentsGraph(debug=True, config=config)
print(f"✅ Graph initialized in {time.time() - start_init:.2f}s")
# Run propagation
print("\n🔄 Running propagation for NVDA on 2024-05-10...")
print("⏳ This may take 30-60 seconds...")
# Track messages during propagation
message_count = 0
agent_activity = []
# Custom propagate to track activity
init_agent_state = ta.propagator.create_initial_state("NVDA", "2024-05-10")
args = ta.propagator.get_graph_args()
print("\n📊 Streaming analysis...")
start_prop = time.time()
for chunk_idx, chunk in enumerate(ta.graph.stream(init_agent_state, **args)):
# Log progress every 10 chunks
if chunk_idx % 10 == 0:
elapsed = time.time() - start_prop
print(f" ⏳ Processing chunk {chunk_idx} ({elapsed:.1f}s elapsed)")
# Track messages
if len(chunk.get("messages", [])) > 0:
message_count += len(chunk["messages"])
# Track completed reports
reports = ['market_report', 'sentiment_report', 'news_report',
'fundamentals_report', 'investment_plan', 'trader_investment_plan',
'final_trade_decision']
for report in reports:
if report in chunk and chunk[report]:
agent_activity.append((time.time() - start_prop, report))
print(f"{report} completed at {time.time() - start_prop:.1f}s")
prop_time = time.time() - start_prop
print(f"\n✅ Propagation completed in {prop_time:.2f}s")
print(f"📊 Total messages processed: {message_count}")
# Process signal
if hasattr(ta, 'curr_state') and ta.curr_state:
decision = ta.process_signal(ta.curr_state.get("final_trade_decision", ""))
print(f"📊 Final decision: {decision}")
# Validate results
print("\n🔍 Validating results:")
for report in reports:
if report in ta.curr_state and ta.curr_state[report]:
print(f"{report}: {len(str(ta.curr_state[report]))} chars")
else:
print(f"{report}: Missing")
print("\n✅ TEST PASSED - Main.py is working correctly")
except Exception as e:
print(f"\n❌ TEST FAILED: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)
print("\n" + "-" * 60)
print("Test completed successfully!")

View File

@ -0,0 +1,323 @@
#!/usr/bin/env python
"""
Test specifically for parallel execution verification
"""
import time
from datetime import datetime
from collections import defaultdict
import threading
import json
from pathlib import Path
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
class ParallelExecutionTracker:
"""Track parallel execution of agents"""
def __init__(self):
self.active_agents = {} # agent_name -> start_time
self.parallel_groups = [] # List of sets of agents that ran in parallel
self.agent_timeline = [] # List of (time, agent, action) tuples
self.lock = threading.Lock()
def agent_started(self, agent_name, timestamp=None):
"""Record agent start"""
timestamp = timestamp or time.time()
with self.lock:
self.active_agents[agent_name] = timestamp
self.agent_timeline.append((timestamp, agent_name, 'start'))
# Check if multiple agents are active
if len(self.active_agents) > 1:
parallel_set = set(self.active_agents.keys())
self.parallel_groups.append({
'agents': parallel_set,
'time': timestamp,
'count': len(parallel_set)
})
print(f"🔄 PARALLEL EXECUTION: {list(parallel_set)} at {datetime.fromtimestamp(timestamp).strftime('%H:%M:%S.%f')[:-3]}")
def agent_ended(self, agent_name, timestamp=None):
"""Record agent end"""
timestamp = timestamp or time.time()
with self.lock:
if agent_name in self.active_agents:
start_time = self.active_agents.pop(agent_name)
duration = timestamp - start_time
self.agent_timeline.append((timestamp, agent_name, 'end'))
print(f"{agent_name} completed in {duration:.2f}s")
def get_parallel_summary(self):
"""Get summary of parallel executions"""
summary = {
'total_parallel_groups': len(self.parallel_groups),
'max_parallel_agents': max((g['count'] for g in self.parallel_groups), default=0),
'parallel_groups': self.parallel_groups,
'timeline': sorted(self.agent_timeline, key=lambda x: x[0])
}
return summary
def test_parallel_execution():
"""Test that agents execute in parallel when expected"""
print("🚀 Testing Parallel Execution of TradingAgents")
print("=" * 80)
# Create results directory
results_dir = Path("test_results/parallel_execution")
results_dir.mkdir(parents=True, exist_ok=True)
# Configure for testing
config = DEFAULT_CONFIG.copy()
config.update({
"llm_provider": "google",
"backend_url": "https://generativelanguage.googleapis.com/v1",
"deep_think_llm": "gemini-2.0-flash",
"quick_think_llm": "gemini-2.0-flash",
"max_debate_rounds": 2,
"online_tools": True
})
# Create tracker
tracker = ParallelExecutionTracker()
# Custom TradingAgentsGraph to track execution
class TrackedGraph(TradingAgentsGraph):
def __init__(self, *args, tracker=None, **kwargs):
super().__init__(*args, **kwargs)
self.tracker = tracker
self.message_timestamps = []
def propagate(self, company_name, trade_date):
"""Enhanced propagate with parallel tracking"""
self.ticker = company_name
# Initialize state
init_agent_state = self.propagator.create_initial_state(company_name, trade_date)
args = self.propagator.get_graph_args()
trace = []
agent_states = {} # Track agent states
print(f"\n📊 Starting analysis for {company_name} on {trade_date}")
print("-" * 60)
# Process stream
for chunk_idx, chunk in enumerate(self.graph.stream(init_agent_state, **args)):
timestamp = time.time()
# Detect which agents are active based on chunk content
chunk_agents = set()
# Check for analyst reports
if "market_report" in chunk and chunk["market_report"] and "market_analyst" not in agent_states:
agent_states["market_analyst"] = "completed"
if self.tracker:
self.tracker.agent_ended("market_analyst", timestamp)
if "sentiment_report" in chunk and chunk["sentiment_report"] and "social_analyst" not in agent_states:
agent_states["social_analyst"] = "completed"
if self.tracker:
self.tracker.agent_ended("social_analyst", timestamp)
if "news_report" in chunk and chunk["news_report"] and "news_analyst" not in agent_states:
agent_states["news_analyst"] = "completed"
if self.tracker:
self.tracker.agent_ended("news_analyst", timestamp)
if "fundamentals_report" in chunk and chunk["fundamentals_report"] and "fundamentals_analyst" not in agent_states:
agent_states["fundamentals_analyst"] = "completed"
if self.tracker:
self.tracker.agent_ended("fundamentals_analyst", timestamp)
# Check messages for agent activity
if len(chunk.get("messages", [])) > 0:
last_message = chunk["messages"][-1]
# Try to identify agent from message
agent_name = None
if hasattr(last_message, 'name') and last_message.name:
agent_name = last_message.name
# Map common agent names
agent_mapping = {
"MarketAnalyst": "market_analyst",
"SocialMediaAnalyst": "social_analyst",
"NewsAnalyst": "news_analyst",
"FundamentalsAnalyst": "fundamentals_analyst",
"BullResearcher": "bull_researcher",
"BearResearcher": "bear_researcher",
"ResearchManager": "research_manager",
"Trader": "trader",
"RiskManager": "risk_manager"
}
if agent_name in agent_mapping:
mapped_name = agent_mapping[agent_name]
if mapped_name not in agent_states:
agent_states[mapped_name] = "active"
if self.tracker:
self.tracker.agent_started(mapped_name, timestamp)
# Check for tool calls which indicate agent activity
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
# Analysts are likely active when tools are called
tool_names = [tc.name if hasattr(tc, 'name') else '' for tc in last_message.tool_calls]
# Map tools to analysts
if any('YFin' in name or 'stockstats' in name for name in tool_names):
if "market_analyst" not in agent_states:
agent_states["market_analyst"] = "active"
if self.tracker:
self.tracker.agent_started("market_analyst", timestamp)
if any('reddit' in name or 'stock_news' in name for name in tool_names):
if "social_analyst" not in agent_states:
agent_states["social_analyst"] = "active"
if self.tracker:
self.tracker.agent_started("social_analyst", timestamp)
if any('news' in name or 'google_news' in name for name in tool_names):
if "news_analyst" not in agent_states:
agent_states["news_analyst"] = "active"
if self.tracker:
self.tracker.agent_started("news_analyst", timestamp)
if any('fundamentals' in name or 'simfin' in name or 'finnhub' in name for name in tool_names):
if "fundamentals_analyst" not in agent_states:
agent_states["fundamentals_analyst"] = "active"
if self.tracker:
self.tracker.agent_started("fundamentals_analyst", timestamp)
# Check for debate states indicating researcher activity
if "investment_debate_state" in chunk:
debate_state = chunk["investment_debate_state"]
if debate_state.get("bull_history") and "bull_researcher" not in agent_states:
agent_states["bull_researcher"] = "active"
if self.tracker:
self.tracker.agent_started("bull_researcher", timestamp)
if debate_state.get("bear_history") and "bear_researcher" not in agent_states:
agent_states["bear_researcher"] = "active"
if self.tracker:
self.tracker.agent_started("bear_researcher", timestamp)
if debate_state.get("judge_decision"):
# Mark researchers as completed
if "bull_researcher" in agent_states and agent_states["bull_researcher"] == "active":
agent_states["bull_researcher"] = "completed"
if self.tracker:
self.tracker.agent_ended("bull_researcher", timestamp)
if "bear_researcher" in agent_states and agent_states["bear_researcher"] == "active":
agent_states["bear_researcher"] = "completed"
if self.tracker:
self.tracker.agent_ended("bear_researcher", timestamp)
trace.append(chunk)
# Mark any remaining active agents as completed
final_timestamp = time.time()
for agent, state in agent_states.items():
if state == "active" and self.tracker:
self.tracker.agent_ended(agent, final_timestamp)
final_state = trace[-1] if trace else {}
self.curr_state = final_state
self._log_state(trade_date, final_state)
return final_state, self.process_signal(final_state["final_trade_decision"])
# Run test
print("\n🧪 Running parallel execution test...")
try:
# Create tracked graph
graph = TrackedGraph(
debug=True,
config=config,
tracker=tracker
)
# Run analysis
start_time = time.time()
final_state, decision = graph.propagate("AAPL", "2024-05-15")
total_time = time.time() - start_time
print(f"\n✅ Analysis completed in {total_time:.2f}s")
print(f"📊 Decision: {decision}")
# Get parallel execution summary
summary = tracker.get_parallel_summary()
print("\n" + "=" * 80)
print("PARALLEL EXECUTION SUMMARY")
print("=" * 80)
print(f"Total parallel groups detected: {summary['total_parallel_groups']}")
print(f"Maximum agents running in parallel: {summary['max_parallel_agents']}")
if summary['parallel_groups']:
print("\nParallel execution instances:")
for i, group in enumerate(summary['parallel_groups']):
agents_str = ", ".join(sorted(group['agents']))
timestamp_str = datetime.fromtimestamp(group['time']).strftime('%H:%M:%S.%f')[:-3]
print(f" {i+1}. [{timestamp_str}] {group['count']} agents: {agents_str}")
# Analyze timeline
print("\nExecution timeline:")
for timestamp, agent, action in summary['timeline'][:20]: # Show first 20 events
timestamp_str = datetime.fromtimestamp(timestamp).strftime('%H:%M:%S.%f')[:-3]
symbol = "▶️" if action == "start" else "⏹️"
print(f" [{timestamp_str}] {symbol} {agent} {action}")
if len(summary['timeline']) > 20:
print(f" ... and {len(summary['timeline']) - 20} more events")
# Save results
results_file = results_dir / "parallel_execution_summary.json"
with open(results_file, 'w') as f:
# Convert to serializable format
serializable_summary = {
'total_time': total_time,
'decision': decision,
'parallel_summary': {
'total_parallel_groups': summary['total_parallel_groups'],
'max_parallel_agents': summary['max_parallel_agents'],
'parallel_groups': [
{
'agents': list(g['agents']),
'time': g['time'],
'count': g['count']
}
for g in summary['parallel_groups']
],
'timeline': [
{
'timestamp': t,
'agent': a,
'action': act
}
for t, a, act in summary['timeline']
]
}
}
json.dump(serializable_summary, f, indent=2)
print(f"\n📁 Results saved to: {results_file}")
# Verify parallel execution occurred
if summary['total_parallel_groups'] > 0:
print("\n✅ PARALLEL EXECUTION VERIFIED!")
print(f" Found {summary['total_parallel_groups']} instances of parallel agent execution")
else:
print("\n⚠️ WARNING: No parallel execution detected!")
print(" This might indicate a performance issue or sequential execution")
except Exception as e:
print(f"\n❌ Error during test: {str(e)}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
test_parallel_execution()

View File

@ -0,0 +1,304 @@
#!/usr/bin/env python3
"""
Comprehensive test specifically for risk management flow:
1. Risk analysts (Risky, Safe, Neutral) generate proper responses
2. Risk aggregator combines responses correctly
3. Risk manager receives proper data and generates valid decisions
4. All risk management state transitions work properly
"""
import sys
import os
import time
import logging
from datetime import datetime
# Add the backend directory to the Python path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from tradingagents.graph.trading_graph import TradingAgentsGraph
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
def test_risk_management_flow():
"""Test the complete risk management flow with detailed validation."""
print("🎯 Starting comprehensive risk management flow test...")
print("=" * 80)
# Initialize the graph
try:
graph = TradingAgentsGraph(debug=True)
print("✅ Graph initialized successfully")
except Exception as e:
print(f"❌ Graph initialization failed: {e}")
return False, None
# Test parameters
company = "TSLA"
trade_date = "2025-07-05"
print(f"\n📊 Testing risk management for {company} on {trade_date}")
print("-" * 60)
# Track risk management specific states
risk_states = []
risk_analyst_responses = {}
risk_aggregator_called = False
risk_manager_called = False
final_state = None
start_time = time.time()
chunks_processed = 0
try:
# Run the analysis and get the final result
final_result = graph.graph.invoke(
graph.propagator.create_initial_state(company, trade_date),
{"recursion_limit": 100}
)
# For debugging, we can still track some basic execution info
print(f"✅ Graph execution completed successfully")
print(f"📊 Final result keys: {list(final_result.keys())}")
# Check if risk management components executed based on final result
chunks_processed = 1 # Just set a placeholder since we're not streaming
# Check the final result for risk management components
if "risk_debate_state" in final_result:
risk_state = final_result["risk_debate_state"]
print(f"🎯 Risk debate state found in final result")
# Track individual analyst responses from final state
if risk_state and "current_risky_response" in risk_state and risk_state["current_risky_response"]:
risk_analyst_responses["Risky Analyst"] = risk_state["current_risky_response"]
print(f"✅ Risky Analyst response found ({len(risk_state['current_risky_response'])} chars)")
if risk_state and "current_safe_response" in risk_state and risk_state["current_safe_response"]:
risk_analyst_responses["Safe Analyst"] = risk_state["current_safe_response"]
print(f"✅ Safe Analyst response found ({len(risk_state['current_safe_response'])} chars)")
if risk_state and "current_neutral_response" in risk_state and risk_state["current_neutral_response"]:
risk_analyst_responses["Neutral Analyst"] = risk_state["current_neutral_response"]
print(f"✅ Neutral Analyst response found ({len(risk_state['current_neutral_response'])} chars)")
# Track aggregator
if risk_state and "history" in risk_state and risk_state["history"]:
risk_aggregator_called = True
print(f"✅ Risk Aggregator history found ({len(risk_state['history'])} chars)")
# Track risk manager
if risk_state and "judge_decision" in risk_state and risk_state["judge_decision"]:
risk_manager_called = True
print(f"✅ Risk Manager decision found ({len(risk_state['judge_decision'])} chars)")
# Check for final_trade_decision
if "final_trade_decision" in final_result and final_result["final_trade_decision"]:
if not risk_manager_called:
risk_manager_called = True
print(f"✅ Final trade decision found ({len(final_result['final_trade_decision'])} chars)")
# Use the complete final result instead of streaming chunks
final_state = final_result
except Exception as e:
print(f"❌ Execution failed: {e}")
return False, None
execution_time = time.time() - start_time
# Comprehensive validation
print("\n" + "=" * 80)
print("🎯 RISK MANAGEMENT FLOW VALIDATION")
print("=" * 80)
issues = []
# 1. Validate risk analyst responses
print("\n📊 RISK ANALYST RESPONSES:")
print("-" * 40)
expected_analysts = ["Risky Analyst", "Safe Analyst", "Neutral Analyst"]
for analyst in expected_analysts:
if analyst in risk_analyst_responses:
response = risk_analyst_responses[analyst]
print(f"{analyst}: {len(response)} chars")
# Only validate response quality if it's not a placeholder
if response != "Response captured from execution logs":
if len(response) < 100:
issues.append(f"{analyst} response too short ({len(response)} chars)")
elif "I'm sorry" in response or "no paragraph" in response:
issues.append(f"{analyst} generated error response")
else:
issues.append(f"{analyst} did not generate response")
print(f"{analyst}: NO RESPONSE")
# 2. Validate risk aggregator
print(f"\n🔄 RISK AGGREGATOR:")
print("-" * 40)
if risk_aggregator_called:
print("✅ Risk Aggregator executed")
# Find the aggregated state
aggregated_state = None
for state in risk_states:
if state["state"].get("history"):
aggregated_state = state["state"]
break
if aggregated_state:
history = aggregated_state["history"]
print(f"✅ Combined history: {len(history)} chars")
# Validate that all analyst responses are included
for analyst in expected_analysts:
analyst_name = analyst.split()[0] # "Risky", "Safe", "Neutral"
if analyst_name not in history:
issues.append(f"Risk aggregator missing {analyst} response in history")
else:
print(f"{analyst} response included in history")
else:
# If no aggregated state found, but aggregator was called, that's still OK
print("⚠️ Risk aggregator executed but no combined history captured in state")
else:
issues.append("Risk Aggregator was not called")
print("❌ Risk Aggregator: NOT EXECUTED")
# 3. Validate risk manager
print(f"\n🎯 RISK MANAGER:")
print("-" * 40)
if risk_manager_called:
print("✅ Risk Manager executed")
# Find the final decision
final_decision = final_state.get("final_trade_decision", "")
judge_decision = final_state.get("risk_debate_state", {}).get("judge_decision", "")
# Debug: Print the actual final state keys
print(f"🔍 Final state keys: {list(final_state.keys())}")
print(f"🔍 Final trade decision length: {len(final_decision)}")
print(f"🔍 Judge decision length: {len(judge_decision)}")
if final_decision:
print(f"✅ Final trade decision: {len(final_decision)} chars")
# Validate decision content
if "I'm sorry" in final_decision or "no paragraph" in final_decision:
issues.append("Risk manager generated error response")
print("❌ Risk manager generated error response")
elif len(final_decision) < 100:
issues.append(f"Final decision too short ({len(final_decision)} chars)")
print(f"❌ Final decision too short ({len(final_decision)} chars)")
elif not any(keyword in final_decision.upper() for keyword in ["BUY", "SELL", "HOLD"]):
issues.append("Final decision missing BUY/SELL/HOLD recommendation")
print("❌ Final decision missing BUY/SELL/HOLD recommendation")
else:
print("✅ Final decision appears valid")
print(f"📝 Decision preview: {final_decision[:200]}...")
elif judge_decision:
# If no final_trade_decision but judge_decision exists, use that
print(f"✅ Judge decision found: {len(judge_decision)} chars")
print(f"📝 Judge decision preview: {judge_decision[:200]}...")
# Validate judge decision content
if "I'm sorry" in judge_decision or "no paragraph" in judge_decision:
issues.append("Risk manager generated error response")
print("❌ Risk manager generated error response")
elif len(judge_decision) < 100:
issues.append(f"Judge decision too short ({len(judge_decision)} chars)")
print(f"❌ Judge decision too short ({len(judge_decision)} chars)")
elif not any(keyword in judge_decision.upper() for keyword in ["BUY", "SELL", "HOLD"]):
issues.append("Judge decision missing BUY/SELL/HOLD recommendation")
print("❌ Judge decision missing BUY/SELL/HOLD recommendation")
else:
print("✅ Judge decision appears valid")
else:
# If no final decision but risk manager was called, that's still an issue
issues.append("Risk manager executed but did not generate final decision")
print("❌ Risk manager executed but no final trade decision generated")
# Validate consistency (only if both exist)
if judge_decision and final_decision and judge_decision != final_decision:
issues.append("Mismatch between judge_decision and final_trade_decision")
print("❌ Mismatch between judge_decision and final_trade_decision")
else:
issues.append("Risk Manager was not called")
print("❌ Risk Manager: NOT EXECUTED")
# 4. Validate state transitions
print(f"\n🔄 STATE TRANSITIONS:")
print("-" * 40)
if len(risk_states) > 0:
print(f"{len(risk_states)} risk state transitions captured")
# Check for proper progression
has_dispatcher = any("Risk Dispatcher" in state["keys"] for state in risk_states)
has_analysts = any(any(analyst in state["keys"] for analyst in expected_analysts) for state in risk_states)
has_aggregator = any("Risk Aggregator" in state["keys"] for state in risk_states)
has_judge = any("Risk Judge" in state["keys"] for state in risk_states)
if has_dispatcher:
print("✅ Risk Dispatcher executed")
else:
issues.append("Risk Dispatcher not found in state transitions")
if has_analysts:
print("✅ Risk Analysts executed")
else:
issues.append("Risk Analysts not found in state transitions")
if has_aggregator:
print("✅ Risk Aggregator executed")
else:
issues.append("Risk Aggregator not found in state transitions")
if has_judge:
print("✅ Risk Judge executed")
else:
issues.append("Risk Judge not found in state transitions")
else:
issues.append("No risk state transitions captured")
print("❌ No risk state transitions captured")
# Final verdict
print("\n" + "=" * 80)
print("🎯 FINAL VERDICT")
print("=" * 80)
print(f"\n⏱️ Total execution time: {execution_time:.2f} seconds")
print(f"📦 Chunks processed: {chunks_processed}")
print(f"🎯 Risk states captured: {len(risk_states)}")
if not issues:
print("\n✅ ALL RISK MANAGEMENT TESTS PASSED! 🎉")
print("\nKey achievements:")
print("- All 3 risk analysts generated valid responses")
print("- Risk aggregator properly combined responses")
print("- Risk manager generated valid final decision")
print("- All state transitions executed correctly")
print(f"- Total execution time: {execution_time:.2f}s")
else:
print("\n❌ RISK MANAGEMENT ISSUES FOUND:")
for i, issue in enumerate(issues, 1):
print(f"{i}. {issue}")
return not bool(issues), final_state
if __name__ == "__main__":
success, final_state = test_risk_management_flow()
if success:
print("\n🎉 Risk management flow test completed successfully!")
exit(0)
else:
print("\n❌ Risk management flow test failed!")
exit(1)

View File

@ -0,0 +1,142 @@
#!/usr/bin/env python3
import sys
import os
import time
import logging
# Add the backend directory to Python path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from tradingagents.agents.managers.risk_manager import create_risk_manager
from tradingagents.agents.utils.memory import FinancialSituationMemory
from langchain_openai import ChatOpenAI
from tradingagents.default_config import DEFAULT_CONFIG
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_risk_manager_isolated():
"""Test the risk manager in isolation with mock data."""
print("🎯 Testing Risk Manager in isolation...")
# Initialize LLM and memory
llm = ChatOpenAI(model="gpt-4o", temperature=0.1)
memory = FinancialSituationMemory("test_risk_memory", DEFAULT_CONFIG)
# Create the risk manager
risk_manager = create_risk_manager(llm, memory)
# Create mock state with all required data
mock_state = {
"company_of_interest": "TSLA",
"trade_date": "2025-07-05",
"market_report": "Market analysis shows TSLA is trading at $315.35 with strong technical indicators. RSI is at 65, MACD is bullish, and the stock is above its 50-day moving average. Volume is above average indicating strong interest.",
"news_report": "Recent news shows Tesla delivered 384,000 vehicles in Q2 2025, marking strong growth. However, there are concerns about increased competition from Chinese EV manufacturers and regulatory scrutiny over FSD technology.",
"fundamentals_report": "Tesla's fundamentals show P/E ratio of 162.31, indicating high valuation. The company has strong cash position of $37B and positive free cash flow. However, the high valuation multiples suggest premium pricing.",
"sentiment_report": "Social media sentiment is mixed with 58% positive mentions on Reddit. There's excitement about robotaxi launch but concerns about political controversies involving Elon Musk.",
"investment_plan": "Based on analysis, Tesla shows strong fundamentals but high valuation. The company has growth potential in autonomous driving and energy storage, but faces competition and regulatory challenges.",
"risk_debate_state": {
"risky_history": "Risky analyst argues for aggressive position due to growth potential in robotaxi and energy storage segments.",
"safe_history": "Safe analyst recommends caution due to high valuation and increasing competition from Chinese manufacturers.",
"neutral_history": "Neutral analyst suggests balanced approach, acknowledging both growth potential and risks.",
"history": "Combined debate between three risk analysts shows mixed perspectives on Tesla's risk profile. Key concerns include valuation, competition, and regulatory risks.",
"latest_speaker": "Neutral Analyst",
"current_risky_response": "Strong buy recommendation based on innovation and market leadership",
"current_safe_response": "Hold recommendation due to valuation concerns and market risks",
"current_neutral_response": "Moderate buy with position sizing to manage risk",
"count": 3
}
}
print("📊 Mock state created with all required fields")
print(f" - market_report: {len(mock_state['market_report'])} chars")
print(f" - news_report: {len(mock_state['news_report'])} chars")
print(f" - fundamentals_report: {len(mock_state['fundamentals_report'])} chars")
print(f" - sentiment_report: {len(mock_state['sentiment_report'])} chars")
print(f" - investment_plan: {len(mock_state['investment_plan'])} chars")
print(f" - risk_debate_history: {len(mock_state['risk_debate_state']['history'])} chars")
# Execute the risk manager
print("\n🎯 Executing risk manager...")
start_time = time.time()
try:
result = risk_manager(mock_state)
execution_time = time.time() - start_time
print(f"✅ Risk manager executed successfully in {execution_time:.2f}s")
print(f"📊 Result keys: {list(result.keys())}")
# Validate the result
issues = []
# Check for final_trade_decision
if "final_trade_decision" in result:
final_decision = result["final_trade_decision"]
print(f"✅ Final trade decision: {len(final_decision)} chars")
print(f"📝 Decision preview: {final_decision[:200]}...")
# Validate decision content
if "I'm sorry" in final_decision or "no paragraph" in final_decision:
issues.append("Risk manager generated error response")
print("❌ Risk manager generated error response")
elif len(final_decision) < 100:
issues.append(f"Final decision too short ({len(final_decision)} chars)")
print(f"❌ Final decision too short ({len(final_decision)} chars)")
elif not any(keyword in final_decision.upper() for keyword in ["BUY", "SELL", "HOLD"]):
issues.append("Final decision missing BUY/SELL/HOLD recommendation")
print("❌ Final decision missing BUY/SELL/HOLD recommendation")
else:
print("✅ Final decision appears valid and contains trading recommendation")
else:
issues.append("No final_trade_decision in result")
print("❌ No final_trade_decision in result")
# Check for risk_debate_state
if "risk_debate_state" in result:
risk_state = result["risk_debate_state"]
print(f"✅ Risk debate state: {len(str(risk_state))} chars")
if "judge_decision" in risk_state:
judge_decision = risk_state["judge_decision"]
print(f"✅ Judge decision: {len(judge_decision)} chars")
print(f"📝 Judge decision preview: {judge_decision[:200]}...")
else:
issues.append("No judge_decision in risk_debate_state")
print("❌ No judge_decision in risk_debate_state")
else:
issues.append("No risk_debate_state in result")
print("❌ No risk_debate_state in result")
# Final verdict
if not issues:
print("\n✅ ALL TESTS PASSED! 🎉")
print("Risk manager is working correctly:")
print("- Generates valid final trade decision")
print("- Contains proper BUY/SELL/HOLD recommendation")
print("- Updates risk debate state correctly")
print("- Handles all input data properly")
return True
else:
print("\n❌ ISSUES FOUND:")
for i, issue in enumerate(issues, 1):
print(f"{i}. {issue}")
return False
except Exception as e:
print(f"❌ Risk manager execution failed: {e}")
return False
if __name__ == "__main__":
success = test_risk_manager_isolated()
if success:
print("\n🎉 Risk manager isolated test completed successfully!")
print("The risk manager component is working correctly.")
print("The issue with the full system test is likely related to state management or streaming.")
exit(0)
else:
print("\n❌ Risk manager isolated test failed!")
exit(1)

View File

@ -0,0 +1,112 @@
#!/usr/bin/env python3
import sys
import os
import time
import logging
# Add the backend directory to Python path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from tradingagents.agents.managers.risk_manager import create_risk_manager
from tradingagents.agents.utils.memory import FinancialSituationMemory
from langchain_openai import ChatOpenAI
from tradingagents.default_config import DEFAULT_CONFIG
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_risk_manager_directly():
"""Test the risk manager node directly with mock data."""
print("🎯 Testing Risk Manager directly...")
# Initialize LLM and memory
llm = ChatOpenAI(model="gpt-4o", temperature=0.1)
memory = FinancialSituationMemory("test_risk_memory", DEFAULT_CONFIG)
# Create risk manager
risk_manager_node = create_risk_manager(llm, memory)
# Create mock state with all required data
mock_state = {
"company_of_interest": "TSLA",
"market_report": "Mock market analysis report with technical indicators showing bullish signals. RSI at 45, MACD positive crossover, price above 50-day SMA.",
"news_report": "Mock news report about Tesla's strong Q2 deliveries and energy storage deployment. Positive momentum in EV market.",
"fundamentals_report": "Mock fundamentals showing Tesla with P/E of 180, strong revenue growth, but high valuation concerns.",
"sentiment_report": "Mock social sentiment showing mixed reactions - positive on deliveries, negative on political controversies.",
"investment_plan": "Mock trader plan suggesting a cautious BUY position with 5% allocation due to mixed signals.",
"risk_debate_state": {
"history": """**Risky Analyst**: I recommend AGGRESSIVE BUY. Tesla's delivery numbers are strong and the EV market is expanding rapidly. This is a growth opportunity.
**Safe Analyst**: I recommend HOLD or REDUCE position. The P/E ratio of 180 is extremely high and political risks with Musk are concerning.
**Neutral Analyst**: I recommend MODERATE BUY with risk management. Tesla has strong fundamentals but high valuation requires careful position sizing.""",
"count": 1
}
}
print("🔍 Input state keys:", list(mock_state.keys()))
print("🔍 Market report length:", len(mock_state["market_report"]))
print("🔍 News report length:", len(mock_state["news_report"]))
print("🔍 Fundamentals report length:", len(mock_state["fundamentals_report"]))
print("🔍 Sentiment report length:", len(mock_state["sentiment_report"]))
print("🔍 Investment plan length:", len(mock_state["investment_plan"]))
print("🔍 Risk history length:", len(mock_state["risk_debate_state"]["history"]))
try:
# Execute risk manager
print("\n🚀 Executing risk manager...")
result = risk_manager_node(mock_state)
print("\n📊 RESULT ANALYSIS:")
print("=" * 50)
print("🔍 Result keys:", list(result.keys()))
# Check final_trade_decision
final_decision = result.get("final_trade_decision", "")
print(f"🔍 Final trade decision length: {len(final_decision)}")
if final_decision:
print("✅ Final trade decision found!")
print(f"📝 Preview: {final_decision[:300]}...")
# Validate content
if "I'm sorry" in final_decision or "no paragraph" in final_decision:
print("❌ Error response detected!")
return False
elif any(keyword in final_decision.upper() for keyword in ["BUY", "SELL", "HOLD"]):
print("✅ Valid decision with recommendation!")
return True
else:
print("⚠️ Decision found but no clear BUY/SELL/HOLD recommendation")
return False
else:
print("❌ No final trade decision found!")
# Check risk_debate_state
risk_state = result.get("risk_debate_state", {})
judge_decision = risk_state.get("judge_decision", "")
print(f"🔍 Judge decision length: {len(judge_decision)}")
if judge_decision:
print("✅ Judge decision found in risk_debate_state!")
print(f"📝 Preview: {judge_decision[:300]}...")
return True
else:
print("❌ No judge decision found either!")
return False
except Exception as e:
print(f"❌ Error executing risk manager: {e}")
return False
if __name__ == "__main__":
success = test_risk_manager_directly()
if success:
print("\n🎉 Risk manager direct test PASSED!")
exit(0)
else:
print("\n❌ Risk manager direct test FAILED!")
exit(1)

View File

@ -46,8 +46,10 @@ Volatility Indicators:
Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions.
IMPORTANT: After you have gathered all the necessary data through tool calls, you must provide a comprehensive final analysis report. Do not just make tool calls without providing a final written analysis.
""" + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
)
prompt = ChatPromptTemplate.from_messages(
@ -76,10 +78,33 @@ Volume-Based Indicators:
result = chain.invoke(state["messages"])
# Check if we have tool results in the conversation history
# Count tool messages to determine if we should generate a final report
messages = state.get("messages", [])
tool_message_count = sum(1 for msg in messages if hasattr(msg, 'type') and str(getattr(msg, 'type', '')) == 'tool')
# If no tool calls in current response and we have tool results, generate final report
report = ""
if len(result.tool_calls) == 0:
# Always generate a report when there are no more tool calls
report = result.content
elif tool_message_count >= 8: # If we have many tool results, force a final report
# Generate a final summary report even if there are tool calls
final_prompt = f"""Based on all the tool results and data you've gathered, provide a comprehensive final market analysis report for {ticker}.
Analyze the trends, patterns, and insights from the data. Include:
1. Technical analysis summary
2. Key indicators and their signals
3. Market trends and momentum
4. Risk factors and opportunities
5. Trading recommendations
Make sure to append a Markdown table at the end organizing key points."""
# Create a new prompt for final report generation
final_chain = prompt | llm
final_result = final_chain.invoke(state["messages"] + [result])
report = final_result.content
return {
"messages": [result],

View File

@ -1,63 +1,148 @@
import time
import json
import logging
logger = logging.getLogger(__name__)
def create_risk_manager(llm, memory):
def risk_manager_node(state) -> dict:
logger.info("🎯 Risk Manager: Starting final decision process")
# Extract basic information
company_name = state["company_of_interest"]
# Get risk debate state
risk_debate_state = state.get("risk_debate_state", {})
history = risk_debate_state.get("history", "")
# Extract all reports with validation
market_research_report = state.get("market_report", "")
news_report = state.get("news_report", "")
fundamentals_report = state.get("fundamentals_report", "") # FIX: was incorrectly assigned to news_report
sentiment_report = state.get("sentiment_report", "")
trader_plan = state.get("investment_plan", "")
# Validate required data
logger.info("🎯 Risk Manager: Validating input data...")
logger.info(f"🎯 Risk Manager: market_report length: {len(market_research_report)}")
logger.info(f"🎯 Risk Manager: news_report length: {len(news_report)}")
logger.info(f"🎯 Risk Manager: fundamentals_report length: {len(fundamentals_report)}")
logger.info(f"🎯 Risk Manager: sentiment_report length: {len(sentiment_report)}")
logger.info(f"🎯 Risk Manager: investment_plan length: {len(trader_plan)}")
logger.info(f"🎯 Risk Manager: risk_debate_history length: {len(history)}")
missing_data = []
if not market_research_report:
missing_data.append("market_report")
if not news_report:
missing_data.append("news_report")
if not fundamentals_report:
missing_data.append("fundamentals_report")
if not sentiment_report:
missing_data.append("sentiment_report")
if not trader_plan:
missing_data.append("investment_plan")
if not history:
missing_data.append("risk_analyst_debate")
if missing_data:
logger.warning(f"🎯 Risk Manager: Missing data: {missing_data}")
# Create a fallback response
fallback_response = f"""**Risk Management Decision: HOLD**
history = state["risk_debate_state"]["history"]
risk_debate_state = state["risk_debate_state"]
market_research_report = state["market_report"]
news_report = state["news_report"]
fundamentals_report = state["news_report"]
sentiment_report = state["sentiment_report"]
trader_plan = state["investment_plan"]
**Reason**: Insufficient data available for comprehensive risk analysis.
**Missing Information**: {', '.join(missing_data)}
**Recommendation**: Wait for complete analysis before making investment decisions. The following data is required:
- Market analysis and technical indicators
- News and world events impact
- Company fundamentals assessment
- Social sentiment analysis
- Risk team debate and perspectives
**Action**: Hold current position until all required analysis is complete."""
logger.info("🎯 Risk Manager: Generated fallback decision due to missing data")
new_risk_debate_state = risk_debate_state.copy()
new_risk_debate_state.update({
"judge_decision": fallback_response,
"latest_speaker": "Judge",
"count": risk_debate_state.get("count", 0) + 1
})
return {
"risk_debate_state": new_risk_debate_state,
"final_trade_decision": fallback_response,
}
# All data is present, proceed with normal analysis
logger.info("🎯 Risk Manager: All required data present, proceeding with analysis")
# Prepare current situation summary
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
# Get past memories
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
# Original simple prompt
prompt = f"""You are a risk management judge. You need to evaluate the debate between three risk analysts (Risky, Neutral, Safe/Conservative) and decide on the best course of action for the trader.
prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
Company: {company_name}
Guidelines for Decision-Making:
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
Trader's original plan: {trader_plan}
Deliverables:
- A clear and actionable recommendation: Buy, Sell, or Hold.
- Detailed reasoning anchored in the debate and past reflections.
---
**Analysts Debate History:**
Risk analysts debate:
{history}
---
Market research report: {market_research_report}
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
News report: {news_report}
Fundamentals report: {fundamentals_report}
Sentiment report: {sentiment_report}
Past lessons learned:
{past_memory_str}
Based on the above information, make a final decision on whether to BUY, SELL, or HOLD the stock. Provide detailed reasoning for your decision."""
logger.info("🎯 Risk Manager: Invoking LLM for final decision...")
logger.info(f"🎯 Risk Manager: Prompt length: {len(prompt)} chars")
logger.info(f"🎯 Risk Manager: Prompt preview: {prompt[:500]}...")
logger.info(f"🎯 Risk Manager: Full prompt sections:")
logger.info(f"🎯 Risk Manager: - Company: {company_name}")
logger.info(f"🎯 Risk Manager: - Trader plan preview: {trader_plan[:100]}...")
logger.info(f"🎯 Risk Manager: - Risk debate preview: {history[:100]}...")
logger.info(f"🎯 Risk Manager: - Market report preview: {market_research_report[:100]}...")
logger.info(f"🎯 Risk Manager: - News report preview: {news_report[:100]}...")
logger.info(f"🎯 Risk Manager: - Fundamentals report preview: {fundamentals_report[:100]}...")
logger.info(f"🎯 Risk Manager: - Sentiment report preview: {sentiment_report[:100]}...")
response = llm.invoke(prompt)
logger.info(f"🎯 Risk Manager: Decision received ({len(response.content)} chars)")
logger.info(f"🎯 Risk Manager: Decision preview: {response.content[:200]}...")
# Update risk debate state
new_risk_debate_state = {
"judge_decision": response.content,
"history": risk_debate_state["history"],
"risky_history": risk_debate_state["risky_history"],
"safe_history": risk_debate_state["safe_history"],
"neutral_history": risk_debate_state["neutral_history"],
"history": risk_debate_state.get("history", ""),
"risky_history": risk_debate_state.get("risky_history", ""),
"safe_history": risk_debate_state.get("safe_history", ""),
"neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Judge",
"current_risky_response": risk_debate_state["current_risky_response"],
"current_safe_response": risk_debate_state["current_safe_response"],
"current_neutral_response": risk_debate_state["current_neutral_response"],
"count": risk_debate_state["count"],
"current_risky_response": risk_debate_state.get("current_risky_response", ""),
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": risk_debate_state.get("current_neutral_response", ""),
"count": risk_debate_state.get("count", 0) + 1,
}
logger.info("🎯 Risk Manager: Final decision complete")
return {
"risk_debate_state": new_risk_debate_state,
"final_trade_decision": response.content,

View File

@ -1,11 +1,18 @@
from langchain_core.messages import AIMessage
import time
import json
import logging
logger = logging.getLogger(__name__)
def create_bear_researcher(llm, memory):
def bear_node(state) -> dict:
logger.info("🐻 Bear Researcher: Starting execution")
investment_debate_state = state["investment_debate_state"]
logger.info(f"🐻 Bear Researcher: Current debate state count: {investment_debate_state.get('count', 0)}")
logger.info(f"🐻 Bear Researcher: Current response starts with: {investment_debate_state.get('current_response', '')[:50]}...")
history = investment_debate_state.get("history", "")
bear_history = investment_debate_state.get("bear_history", "")
@ -44,7 +51,9 @@ Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
"""
logger.info("🐻 Bear Researcher: Invoking LLM...")
response = llm.invoke(prompt)
logger.info("🐻 Bear Researcher: LLM response received")
argument = f"Bear Analyst: {response.content}"
@ -56,6 +65,10 @@ Use this information to deliver a compelling bear argument, refute the bull's cl
"count": investment_debate_state["count"] + 1,
}
logger.info(f"🐻 Bear Researcher: New debate state count: {new_investment_debate_state['count']}")
logger.info(f"🐻 Bear Researcher: New current response starts with: {argument[:50]}...")
logger.info("🐻 Bear Researcher: Execution complete")
return {"investment_debate_state": new_investment_debate_state}
return bear_node

View File

@ -1,11 +1,18 @@
from langchain_core.messages import AIMessage
import time
import json
import logging
logger = logging.getLogger(__name__)
def create_bull_researcher(llm, memory):
def bull_node(state) -> dict:
logger.info("🐂 Bull Researcher: Starting execution")
investment_debate_state = state["investment_debate_state"]
logger.info(f"🐂 Bull Researcher: Current debate state count: {investment_debate_state.get('count', 0)}")
logger.info(f"🐂 Bull Researcher: Current response starts with: {investment_debate_state.get('current_response', '')[:50]}...")
history = investment_debate_state.get("history", "")
bull_history = investment_debate_state.get("bull_history", "")
@ -42,7 +49,9 @@ Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
"""
logger.info("🐂 Bull Researcher: Invoking LLM...")
response = llm.invoke(prompt)
logger.info("🐂 Bull Researcher: LLM response received")
argument = f"Bull Analyst: {response.content}"
@ -54,6 +63,10 @@ Use this information to deliver a compelling bull argument, refute the bear's co
"count": investment_debate_state["count"] + 1,
}
logger.info(f"🐂 Bull Researcher: New debate state count: {new_investment_debate_state['count']}")
logger.info(f"🐂 Bull Researcher: New current response starts with: {argument[:50]}...")
logger.info("🐂 Bull Researcher: Execution complete")
return {"investment_debate_state": new_investment_debate_state}
return bull_node

View File

@ -4,73 +4,120 @@ from typing_extensions import TypedDict, Optional
from langchain_openai import ChatOpenAI
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
from langgraph.graph import END, StateGraph, START
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
# Simple reducer that always takes the latest value
def update_value(left, right):
return right if right is not None else left
# Debate state reducer that can handle concurrent updates
def merge_debate_state(left, right):
"""Merge debate states from multiple agents safely."""
if left is None:
return right
if right is None:
return left
# Create a merged state
merged = left.copy() if isinstance(left, dict) else {}
# Update with right values, preserving existing data
for key, value in right.items():
if key == "count":
# For count, take the maximum to ensure proper sequencing
merged[key] = max(merged.get(key, 0), value)
elif key in ["bull_history", "bear_history", "history"]:
# For history fields, merge intelligently
existing = merged.get(key, "")
if value and value not in existing:
merged[key] = existing + "\n" + value if existing else value
elif value:
merged[key] = value
else:
# For other fields, take the latest non-empty value
if value:
merged[key] = value
elif key not in merged:
merged[key] = ""
return merged
# Risk debate state reducer
def merge_risk_debate_state(left, right):
"""Merge risk debate states from multiple agents safely."""
if left is None:
return right
if right is None:
return left
# Create a merged state
merged = left.copy() if isinstance(left, dict) else {}
# Update with right values
for key, value in right.items():
if key == "count":
# For count, take the maximum
merged[key] = max(merged.get(key, 0), value)
elif value: # Only update if value is not empty
merged[key] = value
elif key not in merged:
merged[key] = ""
return merged
# Researcher team state
class InvestDebateState(TypedDict):
bull_history: Annotated[
str, "Bullish Conversation history"
] # Bullish Conversation history
bear_history: Annotated[
str, "Bearish Conversation history"
] # Bullish Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
current_response: Annotated[str, "Latest response"] # Last response
judge_decision: Annotated[str, "Final judge decision"] # Last response
count: Annotated[int, "Length of the current conversation"] # Conversation length
bull_history: Annotated[str, update_value]
bear_history: Annotated[str, update_value]
history: Annotated[str, update_value]
current_response: Annotated[str, update_value]
judge_decision: Annotated[str, update_value]
count: Annotated[int, update_value]
# Risk management team state
class RiskDebateState(TypedDict):
risky_history: Annotated[
str, "Risky Agent's Conversation history"
] # Conversation history
safe_history: Annotated[
str, "Safe Agent's Conversation history"
] # Conversation history
neutral_history: Annotated[
str, "Neutral Agent's Conversation history"
] # Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
latest_speaker: Annotated[str, "Analyst that spoke last"]
current_risky_response: Annotated[
str, "Latest response by the risky analyst"
] # Last response
current_safe_response: Annotated[
str, "Latest response by the safe analyst"
] # Last response
current_neutral_response: Annotated[
str, "Latest response by the neutral analyst"
] # Last response
judge_decision: Annotated[str, "Judge's decision"]
count: Annotated[int, "Length of the current conversation"] # Conversation length
risky_history: Annotated[str, update_value]
safe_history: Annotated[str, update_value]
neutral_history: Annotated[str, update_value]
history: Annotated[str, update_value]
latest_speaker: Annotated[str, update_value]
current_risky_response: Annotated[str, update_value]
current_safe_response: Annotated[str, update_value]
current_neutral_response: Annotated[str, update_value]
judge_decision: Annotated[str, update_value]
count: Annotated[int, update_value]
class AgentState(MessagesState):
company_of_interest: Annotated[str, "Company that we are interested in trading"]
trade_date: Annotated[str, "What date we are trading at"]
sender: Annotated[str, "Agent that sent this message"]
# research step
market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
news_report: Annotated[
str, "Report from the News Researcher of current world affairs"
]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
# researcher team discussion step
investment_debate_state: Annotated[
InvestDebateState, "Current state of the debate on if to invest or not"
]
investment_plan: Annotated[str, "Plan generated by the Analyst"]
trader_investment_plan: Annotated[str, "Plan generated by the Trader"]
# risk management team discussion step
risk_debate_state: Annotated[
RiskDebateState, "Current state of the debate on evaluating risk"
]
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
class AgentState(TypedDict):
"""Represents the state of our multi-agent system with separate message channels."""
# Basic information
company_of_interest: Annotated[str, update_value]
trade_date: Annotated[str, update_value]
# Analyst message channels
market_messages: Annotated[Sequence[BaseMessage], add_messages]
social_messages: Annotated[Sequence[BaseMessage], add_messages]
news_messages: Annotated[Sequence[BaseMessage], add_messages]
fundamentals_messages: Annotated[Sequence[BaseMessage], add_messages]
# Reports from analysts (using update_value to handle concurrent updates)
market_report: Annotated[Optional[str], update_value]
sentiment_report: Annotated[Optional[str], update_value]
news_report: Annotated[Optional[str], update_value]
fundamentals_report: Annotated[Optional[str], update_value]
# Debate states (using custom reducers to prevent concurrent update errors)
investment_debate_state: Annotated[Optional[InvestDebateState], merge_debate_state]
risk_debate_state: Annotated[Optional[RiskDebateState], merge_risk_debate_state]
# Investment and trading plans
investment_plan: Annotated[Optional[str], update_value]
trader_investment_plan: Annotated[Optional[str], update_value]
final_trade_decision: Annotated[Optional[str], update_value]

View File

@ -1,5 +1,6 @@
from .finnhub_utils import get_data_in_range
from .googlenews_utils import getNewsData
from .serpapi_utils import getNewsDataSerpAPI
from .yfin_utils import YFinanceUtils
from .reddit_utils import fetch_top_from_category
from .stockstats_utils import StockstatsUtils

View File

@ -3,6 +3,7 @@ from .reddit_utils import fetch_top_from_category
from .yfin_utils import *
from .stockstats_utils import *
from .googlenews_utils import *
from .serpapi_utils import getNewsDataSerpAPI
from .finnhub_utils import get_data_in_range
from dateutil.relativedelta import relativedelta
from concurrent.futures import ThreadPoolExecutor
@ -14,10 +15,13 @@ from tqdm import tqdm
import yfinance as yf
from openai import OpenAI
from .config import get_config, set_config, DATA_DIR
from ..default_config import DEFAULT_CONFIG
from dotenv import load_dotenv
# Load environment variables so OpenAI tools can access API keys
load_dotenv()
# Load from project root directory (three levels up from this file)
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
load_dotenv(os.path.join(project_root, ".env"))
def get_finnhub_news(
@ -308,9 +312,15 @@ def get_google_news(
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
# Log the API call
logger.info(f"🌐 Calling getNewsData with query='{query}', start='{before}', end='{curr_date}'")
news_results = getNewsData(query, before, curr_date)
# Log the API call - try SerpAPI first, fallback to web scraping
serpapi_key = DEFAULT_CONFIG.get("serpapi_key", "")
if serpapi_key:
logger.info(f"🌐 Calling SerpAPI with query='{query}', start='{before}', end='{curr_date}'")
news_results = getNewsDataSerpAPI(query, before, curr_date, serpapi_key)
else:
logger.info(f"🌐 SerpAPI key not found, falling back to web scraping")
logger.info(f"🌐 Calling getNewsData with query='{query}', start='{before}', end='{curr_date}'")
news_results = getNewsData(query, before, curr_date)
# Enhanced logging - Raw response
logger.info(f"🌐 RAW RESPONSE TYPE: {type(news_results)}")

View File

@ -0,0 +1,196 @@
import os
import time
from datetime import datetime
from typing import List, Dict, Any
from serpapi import GoogleSearch
import logging
logger = logging.getLogger(__name__)
def getNewsDataSerpAPI(query: str, start_date: str, end_date: str, serpapi_key: str = None) -> List[Dict[str, Any]]:
"""
Get news data using SerpAPI (much faster than web scraping).
Args:
query: Search query string
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
serpapi_key: SerpAPI key (if not provided, will use environment variable)
Returns:
List of dictionaries with news data
"""
if not serpapi_key:
serpapi_key = os.getenv("SERPAPI_API_KEY")
if not serpapi_key:
logger.error("❌ SerpAPI key not found. Please set SERPAPI_API_KEY environment variable.")
raise ValueError("SerpAPI key not found. Please set SERPAPI_API_KEY environment variable.")
# Convert dates to Google News format if needed
if "-" in start_date:
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
start_date = start_dt.strftime("%m/%d/%Y")
if "-" in end_date:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
end_date = end_dt.strftime("%m/%d/%Y")
news_results = []
start_time = time.time()
try:
# SerpAPI parameters for Google News search
params = {
"engine": "google",
"q": query,
"tbm": "nws", # News search
"tbs": f"cdr:1,cd_min:{start_date},cd_max:{end_date}", # Date range
"api_key": serpapi_key,
"num": 100, # Get up to 100 results
"hl": "en", # Language
"gl": "us", # Country
}
logger.info(f"🔍 SerpAPI: Searching for '{query}' from {start_date} to {end_date}")
search = GoogleSearch(params)
results = search.get_dict()
# Check for errors
if "error" in results:
logger.error(f"❌ SerpAPI Error: {results['error']}")
raise Exception(f"SerpAPI Error: {results['error']}")
# Extract news results
news_items = results.get("news_results", [])
for item in news_items:
try:
news_result = {
"link": item.get("link", ""),
"title": item.get("title", "No title"),
"snippet": item.get("snippet", "No snippet"),
"date": item.get("date", "No date"),
"source": item.get("source", "Unknown source"),
}
news_results.append(news_result)
except Exception as e:
logger.warning(f"⚠️ Error processing news item: {e}")
continue
duration = time.time() - start_time
logger.info(f"✅ SerpAPI: Retrieved {len(news_results)} news items in {duration:.2f}s")
return news_results
except Exception as e:
duration = time.time() - start_time
logger.error(f"❌ SerpAPI Error after {duration:.2f}s: {str(e)}")
# Fallback to empty results rather than crashing
logger.info("🔄 Returning empty results as fallback")
return []
def getNewsDataSerpAPIWithPagination(query: str, start_date: str, end_date: str,
max_results: int = 300, serpapi_key: str = None) -> List[Dict[str, Any]]:
"""
Get news data using SerpAPI with pagination support for more results.
Args:
query: Search query string
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
max_results: Maximum number of results to fetch
serpapi_key: SerpAPI key (if not provided, will use environment variable)
Returns:
List of dictionaries with news data
"""
if not serpapi_key:
serpapi_key = os.getenv("SERPAPI_API_KEY")
if not serpapi_key:
logger.error("❌ SerpAPI key not found. Please set SERPAPI_API_KEY environment variable.")
raise ValueError("SerpAPI key not found. Please set SERPAPI_API_KEY environment variable.")
# Convert dates to Google News format if needed
if "-" in start_date:
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
start_date = start_dt.strftime("%m/%d/%Y")
if "-" in end_date:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
end_date = end_dt.strftime("%m/%d/%Y")
all_news_results = []
start_time = time.time()
page = 0
try:
while len(all_news_results) < max_results:
# SerpAPI parameters for Google News search
params = {
"engine": "google",
"q": query,
"tbm": "nws", # News search
"tbs": f"cdr:1,cd_min:{start_date},cd_max:{end_date}", # Date range
"api_key": serpapi_key,
"num": 100, # Get up to 100 results per page
"start": page * 100, # Pagination offset
"hl": "en", # Language
"gl": "us", # Country
}
logger.info(f"🔍 SerpAPI: Page {page + 1} - Searching for '{query}' from {start_date} to {end_date}")
search = GoogleSearch(params)
results = search.get_dict()
# Check for errors
if "error" in results:
logger.error(f"❌ SerpAPI Error: {results['error']}")
break
# Extract news results
news_items = results.get("news_results", [])
if not news_items:
logger.info(f"📭 No more results found on page {page + 1}")
break
for item in news_items:
try:
news_result = {
"link": item.get("link", ""),
"title": item.get("title", "No title"),
"snippet": item.get("snippet", "No snippet"),
"date": item.get("date", "No date"),
"source": item.get("source", "Unknown source"),
}
all_news_results.append(news_result)
if len(all_news_results) >= max_results:
break
except Exception as e:
logger.warning(f"⚠️ Error processing news item: {e}")
continue
page += 1
# Add small delay between requests to be respectful
time.sleep(0.5)
duration = time.time() - start_time
logger.info(f"✅ SerpAPI: Retrieved {len(all_news_results)} news items in {duration:.2f}s across {page} pages")
return all_news_results[:max_results] # Ensure we don't exceed max_results
except Exception as e:
duration = time.time() - start_time
logger.error(f"❌ SerpAPI Error after {duration:.2f}s: {str(e)}")
# Return whatever we managed to collect
logger.info(f"🔄 Returning {len(all_news_results)} partial results as fallback")
return all_news_results

View File

@ -1,8 +1,12 @@
import os
from dotenv import load_dotenv
# Get the backend directory (parent of tradingagents package)
BACKEND_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
PROJECT_ROOT = os.path.abspath(os.path.join(BACKEND_DIR, ".."))
BACKEND_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
# Load environment variables from project root
load_dotenv(os.path.join(PROJECT_ROOT, ".env"))
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
@ -26,4 +30,6 @@ DEFAULT_CONFIG = {
"max_recur_limit": 100,
# Tool settings
"online_tools": True,
# SerpAPI settings
"serpapi_key": os.getenv("SERPAPI_API_KEY", ""),
}

View File

@ -1,7 +1,9 @@
# TradingAgents/graph/conditional_logic.py
from tradingagents.agents.utils.agent_states import AgentState
import logging
logger = logging.getLogger(__name__)
class ConditionalLogic:
"""Handles conditional logic for determining graph flow."""
@ -45,13 +47,25 @@ class ConditionalLogic:
def should_continue_debate(self, state: AgentState) -> str:
"""Determine if debate should continue."""
logger.info("🔄 DEBATE CONDITIONAL: Starting evaluation")
debate_state = state["investment_debate_state"]
count = debate_state["count"]
current_response = debate_state.get("current_response", "")
max_rounds = 2 * self.max_debate_rounds
logger.info(f"🔄 DEBATE CONDITIONAL: Count={count}, Max={max_rounds}")
logger.info(f"🔄 DEBATE CONDITIONAL: Current response starts with: '{current_response[:50]}...'")
if (
state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds
): # 3 rounds of back-and-forth between 2 agents
if count >= max_rounds: # 3 rounds of back-and-forth between 2 agents
logger.info("🔄 DEBATE CONDITIONAL: → Research Manager (max rounds reached)")
return "Research Manager"
if state["investment_debate_state"]["current_response"].startswith("Bull"):
if current_response.startswith("Bull"):
logger.info("🔄 DEBATE CONDITIONAL: → Bear Researcher (Bull just spoke)")
return "Bear Researcher"
logger.info("🔄 DEBATE CONDITIONAL: → Bull Researcher (default/Bear just spoke)")
return "Bull Researcher"
def should_continue_risk_analysis(self, state: AgentState) -> str:

View File

@ -26,7 +26,13 @@ class ToolCallTracker:
def __init__(self):
self.call_history = {} # analyst_type -> {tool_name: [(params_hash, params_str)]}
self.call_counts = {} # analyst_type -> {tool_name: count}
self.max_total_calls = 3 # Maximum total tool calls per analyst
# Different limits for different analyst types
self.max_total_calls = {
"market": 20, # Market analyst needs more tool calls for comprehensive analysis
"social": 3,
"news": 3,
"fundamentals": 3
}
self.total_calls = {} # analyst_type -> total_count
def _hash_params(self, params: dict) -> str:
@ -35,6 +41,10 @@ class ToolCallTracker:
sorted_params = json.dumps(params, sort_keys=True)
return hashlib.md5(sorted_params.encode()).hexdigest()
def _get_max_calls_for_analyst(self, analyst_type: str) -> int:
"""Get the maximum number of calls allowed for a specific analyst type."""
return self.max_total_calls.get(analyst_type, 3) # Default to 3 if not specified
def can_call_tool(self, analyst_type: str, tool_name: str, params: dict) -> Tuple[bool, str]:
"""Check if a tool can be called with given parameters."""
if analyst_type not in self.call_history:
@ -43,21 +53,22 @@ class ToolCallTracker:
self.total_calls[analyst_type] = 0
# Check total call limit for this analyst
if self.total_calls[analyst_type] >= self.max_total_calls:
return False, f"Analyst {analyst_type} has reached maximum total tool calls ({self.max_total_calls})"
max_calls = self._get_max_calls_for_analyst(analyst_type)
if self.total_calls[analyst_type] >= max_calls:
return False, f"Analyst {analyst_type} has reached maximum total tool calls ({max_calls})"
# Initialize tool tracking if first time
if tool_name not in self.call_history[analyst_type]:
self.call_history[analyst_type][tool_name] = []
self.call_counts[analyst_type][tool_name] = 0
# Check for duplicate parameters
# Check for duplicate parameters - each request/query must be different
param_hash = self._hash_params(params)
param_str = json.dumps(params, sort_keys=True)
for existing_hash, existing_params in self.call_history[analyst_type][tool_name]:
if param_hash == existing_hash:
return False, f"Tool {tool_name} already called with identical parameters: {existing_params}"
return False, f"Tool {tool_name} already called with identical parameters. Each request must be different. Previous: {existing_params}"
return True, "OK"
@ -79,7 +90,8 @@ class ToolCallTracker:
self.call_counts[analyst_type][tool_name] += 1
self.total_calls[analyst_type] += 1
logger.info(f"🔧 Recorded tool call: {analyst_type}/{tool_name} (total calls: {self.total_calls[analyst_type]})")
max_calls = self._get_max_calls_for_analyst(analyst_type)
logger.info(f"🔧 Recorded tool call: {analyst_type}/{tool_name} (total calls: {self.total_calls[analyst_type]}/{max_calls})")
class GraphSetup:
@ -242,25 +254,74 @@ class GraphSetup:
messages = state.get(message_key, [])
report = state.get(report_key, "")
# If report exists or too many messages, go to aggregator
if report or len(messages) > 6:
# If report exists, go to aggregator
if report:
return "aggregator"
# If no messages, go to aggregator
if not messages:
return "aggregator"
last_message = messages[-1]
# Check for tool calls
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
# Check if we've already hit the tool call limit
total_calls = self.tool_tracker.total_calls.get(atype, 0)
if total_calls >= self.tool_tracker.max_total_calls:
logger.warning(f" - Decision: AGGREGATOR (tool call limit reached: {total_calls})")
return "aggregator"
return "tools"
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
return "aggregator"
# Count tool messages to see how much data we have
tool_message_count = sum(1 for msg in messages
if hasattr(msg, 'type') and str(getattr(msg, 'type', '')) == 'tool')
# Check total tool calls made
total_calls = self.tool_tracker.total_calls.get(atype, 0)
max_calls = self.tool_tracker.max_total_calls.get(atype, 3)
# Special handling for market analyst
if atype == "market":
# Market analyst needs more cycles to gather comprehensive data
if has_tool_calls and total_calls < max_calls:
return "tools"
elif tool_message_count >= 4 and not has_tool_calls:
# Has enough data and no more tool calls - should generate report
return "aggregator"
elif total_calls >= max_calls:
# Hit max calls - force completion
return "aggregator"
elif has_tool_calls:
return "tools"
else:
return "aggregator"
# Special handling for social analyst
elif atype == "social":
# Social analyst needs multiple tool calls for comprehensive analysis
if has_tool_calls and total_calls < max_calls:
return "tools"
elif tool_message_count >= 2 and not has_tool_calls:
# Has enough data and no more tool calls - should generate report
return "aggregator"
elif total_calls >= max_calls:
# Hit max calls - force completion
return "aggregator"
elif has_tool_calls:
return "tools"
else:
return "aggregator"
# For news and fundamentals analysts
else:
if has_tool_calls and total_calls < max_calls:
return "tools"
elif tool_message_count >= 1 and not has_tool_calls:
# Has data and no more tool calls - should generate report
return "aggregator"
elif total_calls >= max_calls:
# Hit max calls - force completion
return "aggregator"
elif has_tool_calls:
return "tools"
else:
return "aggregator"
return should_continue_analyst
# Define conditional logic for tools
@ -271,15 +332,16 @@ class GraphSetup:
# Check total tool calls
total_calls = self.tool_tracker.total_calls.get(atype, 0)
if total_calls >= self.tool_tracker.max_total_calls:
return "aggregator"
max_calls = self.tool_tracker.max_total_calls.get(atype, 3)
# If we have enough messages, likely complete
if len(messages) >= 6:
return "aggregator"
# Count tool messages
tool_message_count = sum(1 for msg in messages
if hasattr(msg, 'type') and str(getattr(msg, 'type', '')) == 'tool')
# Otherwise, go back to analyst
# After tools execute, always go back to analyst to generate report
# The analyst will decide whether to make more tool calls or generate final report
return "analyst"
return should_continue_after_tools
# Add conditional edges for each analyst
@ -400,6 +462,10 @@ class GraphSetup:
messages = state.get(message_key, [])
logger.info(f"🧠 {analyst_type} analyst: Processing {len(messages)} messages")
# Debug: Show message types and tool call counts
tool_message_count = sum(1 for msg in messages if hasattr(msg, 'type') and str(getattr(msg, 'type', '')) == 'tool')
logger.info(f"🧠 {analyst_type} analyst: Tool messages in history: {tool_message_count}")
# Create a temporary state with the analyst's messages
temp_state = state.copy()
temp_state["messages"] = messages
@ -414,24 +480,93 @@ class GraphSetup:
updated_messages = result.get("messages", messages)
logger.info(f"🧠 {analyst_type} analyst: Updated from {len(messages)} to {len(updated_messages)} messages")
# Check if this is a final response
# Check if the analyst node directly returned a report
direct_report = result.get(report_key, "")
if direct_report:
logger.info(f"🧠 {analyst_type} analyst: ✅ DIRECT REPORT GENERATED ({len(direct_report)} chars)")
# Mark this report as completed
self.completed_reports.add(f"{analyst_type}_report_completed")
logger.info(f"🧠 {analyst_type} analyst: ✅ SETTING {report_key}")
# Return updates with the direct report
update = {
message_key: updated_messages,
report_key: direct_report
}
logger.info(f"{analyst_type.upper()} ANALYST COMPLETE")
return update
# If no direct report, check if this is a final response from message content
report = ""
if updated_messages:
last_message = updated_messages[-1]
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
has_content = hasattr(last_message, 'content') and last_message.content
# Count tool messages
logger.info(f"🧠 {analyst_type} analyst: Last message has_tool_calls={has_tool_calls}, has_content={has_content}")
# Initialize content variable
content = str(last_message.content) if has_content else ""
# Count tool messages in the full conversation
tool_result_count = sum(1 for msg in updated_messages
if hasattr(msg, 'type') and str(getattr(msg, 'type', '')) == 'tool')
logger.info(f"🧠 {analyst_type} analyst: Total tool results: {tool_result_count}")
# Also check for ToolMessage instances
if tool_result_count == 0:
tool_result_count = sum(1 for msg in updated_messages if isinstance(msg, ToolMessage))
logger.info(f"🧠 {analyst_type} analyst: ToolMessage instances: {tool_result_count}")
# Check if we should generate a final report
should_generate_report = False
# Generate report if no tool calls and has content
if has_content and not has_tool_calls:
content = str(last_message.content)
# No more tool calls, has content - likely final response
should_generate_report = True
logger.info(f"🧠 {analyst_type} analyst: Final response detected (no tool calls)")
elif has_content and tool_result_count > 0:
# Has content and tool results - might be ready for final summary
if analyst_type == "market":
# Market analyst needs comprehensive data
if tool_result_count >= 4:
should_generate_report = True
logger.info(f"🧠 {analyst_type} analyst: Market analyst with sufficient tool results ({tool_result_count})")
elif analyst_type == "social":
# Social analyst needs multiple sources
if tool_result_count >= 2:
should_generate_report = True
logger.info(f"🧠 {analyst_type} analyst: Social analyst with sufficient tool results ({tool_result_count})")
elif tool_result_count >= 1:
# Other analysts need fewer tools
should_generate_report = True
logger.info(f"🧠 {analyst_type} analyst: {analyst_type} analyst with tool results ({tool_result_count})")
elif not has_content and tool_result_count > 0:
# Has tool results but no content yet - might need to force completion
total_calls = self.tool_tracker.total_calls.get(analyst_type, 0)
max_calls = self.tool_tracker.max_total_calls.get(analyst_type, 3)
if total_calls >= max_calls or tool_result_count >= 4:
# Force completion with available data
logger.info(f"🧠 {analyst_type} analyst: Forcing completion with available data ({tool_result_count} tool results)")
should_generate_report = True
# Create a summary from the tool results
content = f"Analysis for {state.get('company_of_interest', 'unknown')} based on {tool_result_count} data sources and technical analysis."
# Special handling for market analyst - if it has many tool calls but no content yet,
# it might need to go through another cycle
if analyst_type == "market" and has_tool_calls and not has_content:
logger.info(f"🧠 {analyst_type} analyst: Market analyst making more tool calls")
if should_generate_report and content:
# Only consider it a report if it has substantial content
if len(content) > 200 or (tool_result_count > 0 and len(content) > 50):
if len(content) > 100 or (tool_result_count > 0 and len(content) > 20):
report = content
logger.info(f"🧠 {analyst_type} analyst: ✅ FINAL REPORT GENERATED ({len(content)} chars)")
logger.info(f"🧠 {analyst_type} analyst: ✅ FINAL REPORT GENERATED FROM MESSAGE ({len(content)} chars)")
else:
logger.info(f"🧠 {analyst_type} analyst: Content too short for report ({len(content)} chars)")
else:
logger.info(f"🧠 {analyst_type} analyst: Not ready for final report yet")
# Return updates
update = {message_key: updated_messages}
@ -481,8 +616,12 @@ class GraphSetup:
for i, tool_call in enumerate(last_msg.tool_calls):
try:
# Get tool call details
if hasattr(tool_call, 'name'):
# Get tool call details - handle both dict and object formats
if isinstance(tool_call, dict):
tool_name = tool_call.get('name', '')
tool_args = tool_call.get('args', {})
tool_call_id = tool_call.get('id', 'unknown')
elif hasattr(tool_call, 'name'):
tool_name = tool_call.name
tool_args = tool_call.args if hasattr(tool_call, 'args') else {}
tool_call_id = tool_call.id if hasattr(tool_call, 'id') else 'unknown'
@ -490,6 +629,10 @@ class GraphSetup:
logger.error(f"{analyst_type} tools: Unknown tool call format")
continue
if not tool_name:
logger.error(f"{analyst_type} tools: Empty tool name")
continue
# Check if the tool can be called
can_call, reason = self.tool_tracker.can_call_tool(analyst_type, tool_name, tool_args)
if not can_call:
@ -564,36 +707,13 @@ class GraphSetup:
if missing_reports:
logger.warning(f"📊 Aggregator: ❌ Missing reports: {missing_reports}")
# Initialize debate states
initial_investment_debate = {
"bull_history": "",
"bear_history": "",
"history": "",
"current_response": "",
"judge_decision": "",
"count": 0
}
initial_risk_debate = {
"risky_history": "",
"safe_history": "",
"neutral_history": "",
"history": "",
"latest_speaker": "",
"current_risky_response": "",
"current_safe_response": "",
"current_neutral_response": "",
"judge_decision": "",
"count": 0
}
logger.info("📊 Aggregator: Marking analysis phase as complete")
logger.info("✅ AGGREGATOR COMPLETE")
# Don't initialize debate states here - let the Bull Researcher do it
# This prevents concurrent update errors
return {
"analysis_complete": True,
"investment_debate_state": initial_investment_debate,
"risk_debate_state": initial_risk_debate
"analysis_complete": True
}
return aggregate
@ -683,20 +803,48 @@ class GraphSetup:
logger.info(f" - Safe analysis: {'✅ COMPLETE' if safe_response else '❌ MISSING'} ({len(safe_response)} chars)")
logger.info(f" - Neutral analysis: {'✅ COMPLETE' if neutral_response else '❌ MISSING'} ({len(neutral_response)} chars)")
# Combine all responses for Risk Judge input
combined_history = ""
if risky_response:
combined_history += f"Risky Analyst: {risky_response}\n\n"
if safe_response:
combined_history += f"Safe Analyst: {safe_response}\n\n"
if neutral_response:
combined_history += f"Neutral Analyst: {neutral_response}\n\n"
# Validate that we have at least some analysis
total_responses = len([r for r in [risky_response, safe_response, neutral_response] if r])
if total_responses == 0:
logger.error("⚡ Risk Aggregator: ❌ NO RISK ANALYSES AVAILABLE")
# Create fallback history
combined_history = "No risk analysis available from any analyst. Unable to provide risk assessment."
elif total_responses < 3:
logger.warning(f"⚡ Risk Aggregator: ⚠️ Only {total_responses}/3 risk analyses available")
# Combine available responses
combined_history = ""
if risky_response:
combined_history += f"**Risky Analyst**: {risky_response}\n\n"
if safe_response:
combined_history += f"**Safe Analyst**: {safe_response}\n\n"
if neutral_response:
combined_history += f"**Neutral Analyst**: {neutral_response}\n\n"
# Add note about missing analyses
missing_analysts = []
if not risky_response:
missing_analysts.append("Risky")
if not safe_response:
missing_analysts.append("Safe")
if not neutral_response:
missing_analysts.append("Neutral")
combined_history += f"**Note**: Missing analysis from {', '.join(missing_analysts)} analyst(s). Decision based on available data only."
else:
logger.info("⚡ Risk Aggregator: ✅ All risk analyses complete")
# Combine all responses for Risk Judge input
combined_history = ""
combined_history += f"**Risky Analyst**: {risky_response}\n\n"
combined_history += f"**Safe Analyst**: {safe_response}\n\n"
combined_history += f"**Neutral Analyst**: {neutral_response}\n\n"
# Update risk debate state with combined history
updated_risk_state = risk_debate_state.copy()
updated_risk_state["history"] = combined_history
updated_risk_state["count"] = 1 # Mark as ready for judgment
logger.info(f"⚡ Risk Aggregator: Combined history length: {len(combined_history)} chars")
logger.info("⚡ Risk Aggregator: Risk analyses aggregated for final judgment")
logger.info("✅ RISK AGGREGATOR COMPLETE")

96
test_api.sh Executable file
View File

@ -0,0 +1,96 @@
#!/bin/bash
# TradingAgents API Test Script
# Usage: ./test_api.sh TICKER [OPTIONS]
# Example: ./test_api.sh TSLA
# Example: ./test_api.sh AAPL --limit 20
# Example: ./test_api.sh META --timeout 60
# Default values
TICKER=${1:-AAPL}
LIMIT=""
TIMEOUT=""
BASE_URL="http://localhost:8000"
# Parse additional arguments
shift
while [[ $# -gt 0 ]]; do
case $1 in
--limit)
LIMIT="| head -n $2"
shift 2
;;
--timeout)
TIMEOUT="timeout $2"
shift 2
;;
--health)
echo "🏥 Testing health endpoint..."
curl -s "$BASE_URL/health" && echo
exit 0
;;
--help)
echo "Usage: $0 TICKER [OPTIONS]"
echo ""
echo "Options:"
echo " --limit N Show only first N events"
echo " --timeout N Stop after N seconds"
echo " --health Test health endpoint only"
echo " --help Show this help"
echo ""
echo "Examples:"
echo " $0 TSLA"
echo " $0 AAPL --limit 20"
echo " $0 META --timeout 60"
echo " $0 --health"
exit 0
;;
*)
echo "Unknown option: $1"
echo "Use --help for usage information"
exit 1
;;
esac
done
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
echo -e "${BLUE}🚀 Testing TradingAgents API${NC}"
echo -e "${YELLOW}📊 Ticker: $TICKER${NC}"
echo -e "${YELLOW}🌐 URL: $BASE_URL/analyze/stream?ticker=$TICKER${NC}"
echo ""
# Test health first
echo -e "${BLUE}🏥 Checking server health...${NC}"
if curl -s "$BASE_URL/health" > /dev/null; then
echo -e "${GREEN}✅ Server is healthy${NC}"
else
echo -e "${RED}❌ Server is not responding${NC}"
echo -e "${YELLOW}💡 Make sure to start the server first:${NC}"
echo -e "${YELLOW} cd backend && source venv/bin/activate && python run_api.py${NC}"
exit 1
fi
echo ""
echo -e "${BLUE}📡 Starting streaming analysis...${NC}"
echo -e "${YELLOW}Press Ctrl+C to stop${NC}"
echo ""
# Build the command
CMD="curl -N -H \"Accept: text/event-stream\" -H \"Connection: keep-alive\" -H \"Cache-Control: no-cache\" \"$BASE_URL/analyze/stream?ticker=$TICKER\""
if [[ -n "$TIMEOUT" ]]; then
CMD="$TIMEOUT $CMD"
fi
if [[ -n "$LIMIT" ]]; then
CMD="$CMD $LIMIT"
fi
# Execute the command
eval $CMD