Merge f4519bcb84 into fa4d01c23a
This commit is contained in:
commit
859a5e5502
27
cli/main.py
27
cli/main.py
|
|
@ -25,6 +25,12 @@ from rich.align import Align
|
||||||
from rich.rule import Rule
|
from rich.rule import Rule
|
||||||
|
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
from tradingagents.graph.analyst_execution import (
|
||||||
|
AnalystWallTimeTracker,
|
||||||
|
build_analyst_execution_plan,
|
||||||
|
get_initial_analyst_node,
|
||||||
|
sync_analyst_tracker_from_chunk,
|
||||||
|
)
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
from cli.utils import *
|
from cli.utils import *
|
||||||
|
|
@ -810,7 +816,7 @@ ANALYST_REPORT_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def update_analyst_statuses(message_buffer, chunk):
|
def update_analyst_statuses(message_buffer, chunk, wall_time_tracker=None):
|
||||||
"""Update analyst statuses based on accumulated report state.
|
"""Update analyst statuses based on accumulated report state.
|
||||||
|
|
||||||
Logic:
|
Logic:
|
||||||
|
|
@ -824,6 +830,9 @@ def update_analyst_statuses(message_buffer, chunk):
|
||||||
selected = message_buffer.selected_analysts
|
selected = message_buffer.selected_analysts
|
||||||
found_active = False
|
found_active = False
|
||||||
|
|
||||||
|
if wall_time_tracker is not None:
|
||||||
|
sync_analyst_tracker_from_chunk(wall_time_tracker, chunk)
|
||||||
|
|
||||||
for analyst_key in ANALYST_ORDER:
|
for analyst_key in ANALYST_ORDER:
|
||||||
if analyst_key not in selected:
|
if analyst_key not in selected:
|
||||||
continue
|
continue
|
||||||
|
|
@ -950,6 +959,11 @@ def run_analysis():
|
||||||
# Normalize analyst selection to predefined order (selection is a 'set', order is fixed)
|
# Normalize analyst selection to predefined order (selection is a 'set', order is fixed)
|
||||||
selected_set = {analyst.value for analyst in selections["analysts"]}
|
selected_set = {analyst.value for analyst in selections["analysts"]}
|
||||||
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
|
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
|
||||||
|
analyst_execution_plan = build_analyst_execution_plan(
|
||||||
|
selected_analyst_keys,
|
||||||
|
concurrency_limit=config["analyst_concurrency_limit"],
|
||||||
|
)
|
||||||
|
analyst_wall_time_tracker = AnalystWallTimeTracker(analyst_execution_plan)
|
||||||
|
|
||||||
# Initialize the graph with callbacks bound to LLMs
|
# Initialize the graph with callbacks bound to LLMs
|
||||||
graph = TradingAgentsGraph(
|
graph = TradingAgentsGraph(
|
||||||
|
|
@ -1032,8 +1046,9 @@ def run_analysis():
|
||||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Update agent status to in_progress for the first analyst
|
# Update agent status to in_progress for the first analyst
|
||||||
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
|
first_analyst = get_initial_analyst_node(analyst_execution_plan)
|
||||||
message_buffer.update_agent_status(first_analyst, "in_progress")
|
message_buffer.update_agent_status(first_analyst, "in_progress")
|
||||||
|
analyst_wall_time_tracker.mark_started(selected_analyst_keys[0])
|
||||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Create spinner text
|
# Create spinner text
|
||||||
|
|
@ -1073,7 +1088,11 @@ def run_analysis():
|
||||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||||
|
|
||||||
# Update analyst statuses based on report state (runs on every chunk)
|
# Update analyst statuses based on report state (runs on every chunk)
|
||||||
update_analyst_statuses(message_buffer, chunk)
|
update_analyst_statuses(
|
||||||
|
message_buffer,
|
||||||
|
chunk,
|
||||||
|
wall_time_tracker=analyst_wall_time_tracker,
|
||||||
|
)
|
||||||
|
|
||||||
# Research Team - Handle Investment Debate State
|
# Research Team - Handle Investment Debate State
|
||||||
if chunk.get("investment_debate_state"):
|
if chunk.get("investment_debate_state"):
|
||||||
|
|
@ -1162,6 +1181,7 @@ def run_analysis():
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"System", f"Completed analysis for {selections['analysis_date']}"
|
"System", f"Completed analysis for {selections['analysis_date']}"
|
||||||
)
|
)
|
||||||
|
message_buffer.add_message("System", analyst_wall_time_tracker.format_summary())
|
||||||
|
|
||||||
# Update final report sections
|
# Update final report sections
|
||||||
for section in message_buffer.report_sections.keys():
|
for section in message_buffer.report_sections.keys():
|
||||||
|
|
@ -1172,6 +1192,7 @@ def run_analysis():
|
||||||
|
|
||||||
# Post-analysis prompts (outside Live context for clean interaction)
|
# Post-analysis prompts (outside Live context for clean interaction)
|
||||||
console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n")
|
console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n")
|
||||||
|
console.print(f"[dim]{analyst_wall_time_tracker.format_summary()}[/dim]")
|
||||||
|
|
||||||
# Prompt to save report
|
# Prompt to save report
|
||||||
save_choice = typer.prompt("Save report?", default="Y").strip().upper()
|
save_choice = typer.prompt("Save report?", default="Y").strip().upper()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,84 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from tradingagents.graph.analyst_execution import (
|
||||||
|
AnalystWallTimeTracker,
|
||||||
|
build_analyst_execution_plan,
|
||||||
|
get_initial_analyst_node,
|
||||||
|
sync_analyst_tracker_from_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AnalystExecutionPlanTests(unittest.TestCase):
|
||||||
|
def test_build_plan_preserves_selected_order(self):
|
||||||
|
plan = build_analyst_execution_plan(["news", "market"], concurrency_limit=2)
|
||||||
|
|
||||||
|
self.assertEqual([spec.key for spec in plan.specs], ["news", "market"])
|
||||||
|
self.assertEqual(plan.concurrency_limit, 2)
|
||||||
|
self.assertEqual(plan.specs[0].agent_node, "News Analyst")
|
||||||
|
self.assertEqual(plan.specs[0].tool_node, "tools_news")
|
||||||
|
self.assertEqual(plan.specs[0].clear_node, "Msg Clear News")
|
||||||
|
|
||||||
|
def test_rejects_unknown_analyst_keys(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
build_analyst_execution_plan(["market", "macro"])
|
||||||
|
|
||||||
|
def test_requires_positive_concurrency_limit(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
build_analyst_execution_plan(["market"], concurrency_limit=0)
|
||||||
|
|
||||||
|
def test_get_initial_analyst_node_uses_plan_metadata(self):
|
||||||
|
plan = build_analyst_execution_plan(["fundamentals", "news"])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
get_initial_analyst_node(plan),
|
||||||
|
"Fundamentals Analyst",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AnalystWallTimeTrackerTests(unittest.TestCase):
|
||||||
|
def test_records_wall_time_when_analyst_completes(self):
|
||||||
|
plan = build_analyst_execution_plan(["market", "news"])
|
||||||
|
tracker = AnalystWallTimeTracker(plan)
|
||||||
|
|
||||||
|
tracker.mark_started("market", started_at=10.0)
|
||||||
|
tracker.mark_completed("market", completed_at=13.5)
|
||||||
|
|
||||||
|
self.assertEqual(tracker.get_wall_times(), {"market": 3.5})
|
||||||
|
|
||||||
|
def test_formats_summary_in_plan_order(self):
|
||||||
|
plan = build_analyst_execution_plan(["news", "market"])
|
||||||
|
tracker = AnalystWallTimeTracker(plan)
|
||||||
|
|
||||||
|
tracker.mark_started("market", started_at=20.0)
|
||||||
|
tracker.mark_completed("market", completed_at=22.25)
|
||||||
|
tracker.mark_started("news", started_at=10.0)
|
||||||
|
tracker.mark_completed("news", completed_at=14.0)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
tracker.format_summary(),
|
||||||
|
"Analyst wall time: News 4.00s | Market 2.25s",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_syncs_wall_time_from_sequential_chunks(self):
|
||||||
|
plan = build_analyst_execution_plan(["market", "news"])
|
||||||
|
tracker = AnalystWallTimeTracker(plan)
|
||||||
|
|
||||||
|
sync_analyst_tracker_from_chunk(tracker, {}, now=10.0)
|
||||||
|
self.assertEqual(tracker.get_wall_times(), {})
|
||||||
|
|
||||||
|
sync_analyst_tracker_from_chunk(
|
||||||
|
tracker,
|
||||||
|
{"market_report": "done"},
|
||||||
|
now=13.0,
|
||||||
|
)
|
||||||
|
self.assertEqual(tracker.get_wall_times(), {"market": 3.0})
|
||||||
|
|
||||||
|
sync_analyst_tracker_from_chunk(
|
||||||
|
tracker,
|
||||||
|
{"market_report": "done", "news_report": "done"},
|
||||||
|
now=18.0,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
tracker.get_wall_times(),
|
||||||
|
{"market": 3.0, "news": 5.0},
|
||||||
|
)
|
||||||
|
|
@ -22,6 +22,7 @@ DEFAULT_CONFIG = {
|
||||||
"max_debate_rounds": 1,
|
"max_debate_rounds": 1,
|
||||||
"max_risk_discuss_rounds": 1,
|
"max_risk_discuss_rounds": 1,
|
||||||
"max_recur_limit": 100,
|
"max_recur_limit": 100,
|
||||||
|
"analyst_concurrency_limit": 1,
|
||||||
# Data vendor configuration
|
# Data vendor configuration
|
||||||
# Category-level configuration (default for all tools in category)
|
# Category-level configuration (default for all tools in category)
|
||||||
"data_vendors": {
|
"data_vendors": {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,136 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from time import monotonic
|
||||||
|
from typing import Dict, Iterable, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AnalystNodeSpec:
|
||||||
|
key: str
|
||||||
|
agent_node: str
|
||||||
|
clear_node: str
|
||||||
|
tool_node: str
|
||||||
|
report_key: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AnalystExecutionPlan:
|
||||||
|
specs: List[AnalystNodeSpec]
|
||||||
|
concurrency_limit: int
|
||||||
|
|
||||||
|
|
||||||
|
ANALYST_NODE_SPECS: Dict[str, AnalystNodeSpec] = {
|
||||||
|
"market": AnalystNodeSpec(
|
||||||
|
key="market",
|
||||||
|
agent_node="Market Analyst",
|
||||||
|
clear_node="Msg Clear Market",
|
||||||
|
tool_node="tools_market",
|
||||||
|
report_key="market_report",
|
||||||
|
),
|
||||||
|
"social": AnalystNodeSpec(
|
||||||
|
key="social",
|
||||||
|
agent_node="Social Analyst",
|
||||||
|
clear_node="Msg Clear Social",
|
||||||
|
tool_node="tools_social",
|
||||||
|
report_key="sentiment_report",
|
||||||
|
),
|
||||||
|
"news": AnalystNodeSpec(
|
||||||
|
key="news",
|
||||||
|
agent_node="News Analyst",
|
||||||
|
clear_node="Msg Clear News",
|
||||||
|
tool_node="tools_news",
|
||||||
|
report_key="news_report",
|
||||||
|
),
|
||||||
|
"fundamentals": AnalystNodeSpec(
|
||||||
|
key="fundamentals",
|
||||||
|
agent_node="Fundamentals Analyst",
|
||||||
|
clear_node="Msg Clear Fundamentals",
|
||||||
|
tool_node="tools_fundamentals",
|
||||||
|
report_key="fundamentals_report",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_analyst_execution_plan(
|
||||||
|
selected_analysts: Iterable[str],
|
||||||
|
concurrency_limit: int = 1,
|
||||||
|
) -> AnalystExecutionPlan:
|
||||||
|
if concurrency_limit < 1:
|
||||||
|
raise ValueError("analyst concurrency limit must be >= 1")
|
||||||
|
|
||||||
|
specs: List[AnalystNodeSpec] = []
|
||||||
|
for analyst_key in selected_analysts:
|
||||||
|
spec = ANALYST_NODE_SPECS.get(analyst_key)
|
||||||
|
if spec is None:
|
||||||
|
raise ValueError(f"unknown analyst key: {analyst_key}")
|
||||||
|
specs.append(spec)
|
||||||
|
|
||||||
|
if not specs:
|
||||||
|
raise ValueError("at least one analyst must be selected")
|
||||||
|
|
||||||
|
return AnalystExecutionPlan(specs=specs, concurrency_limit=concurrency_limit)
|
||||||
|
|
||||||
|
|
||||||
|
def get_initial_analyst_node(plan: AnalystExecutionPlan) -> str:
|
||||||
|
return plan.specs[0].agent_node
|
||||||
|
|
||||||
|
|
||||||
|
class AnalystWallTimeTracker:
|
||||||
|
def __init__(self, plan: AnalystExecutionPlan):
|
||||||
|
self.plan = plan
|
||||||
|
self._started_at: Dict[str, float] = {}
|
||||||
|
self._wall_times: Dict[str, float] = {}
|
||||||
|
|
||||||
|
def mark_started(self, analyst_key: str, started_at: Optional[float] = None) -> None:
|
||||||
|
if analyst_key not in ANALYST_NODE_SPECS:
|
||||||
|
raise ValueError(f"unknown analyst key: {analyst_key}")
|
||||||
|
self._started_at.setdefault(analyst_key, monotonic() if started_at is None else started_at)
|
||||||
|
|
||||||
|
def mark_completed(
|
||||||
|
self,
|
||||||
|
analyst_key: str,
|
||||||
|
completed_at: Optional[float] = None,
|
||||||
|
) -> None:
|
||||||
|
if analyst_key not in ANALYST_NODE_SPECS:
|
||||||
|
raise ValueError(f"unknown analyst key: {analyst_key}")
|
||||||
|
if analyst_key in self._wall_times:
|
||||||
|
return
|
||||||
|
started_at = self._started_at.get(analyst_key)
|
||||||
|
if started_at is None:
|
||||||
|
return
|
||||||
|
finished_at = monotonic() if completed_at is None else completed_at
|
||||||
|
self._wall_times[analyst_key] = max(0.0, finished_at - started_at)
|
||||||
|
|
||||||
|
def get_wall_times(self) -> Dict[str, float]:
|
||||||
|
return dict(self._wall_times)
|
||||||
|
|
||||||
|
def format_summary(self) -> str:
|
||||||
|
parts = []
|
||||||
|
for spec in self.plan.specs:
|
||||||
|
duration = self._wall_times.get(spec.key)
|
||||||
|
if duration is not None:
|
||||||
|
label = spec.agent_node.removesuffix(" Analyst")
|
||||||
|
parts.append(f"{label} {duration:.2f}s")
|
||||||
|
if not parts:
|
||||||
|
return "Analyst wall time: pending"
|
||||||
|
return "Analyst wall time: " + " | ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def sync_analyst_tracker_from_chunk(
|
||||||
|
tracker: AnalystWallTimeTracker,
|
||||||
|
chunk: Dict[str, str],
|
||||||
|
now: Optional[float] = None,
|
||||||
|
) -> None:
|
||||||
|
current_time = monotonic() if now is None else now
|
||||||
|
active_found = False
|
||||||
|
|
||||||
|
for spec in tracker.plan.specs:
|
||||||
|
has_report = bool(chunk.get(spec.report_key))
|
||||||
|
|
||||||
|
if has_report:
|
||||||
|
tracker.mark_started(spec.key, started_at=current_time)
|
||||||
|
tracker.mark_completed(spec.key, completed_at=current_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not active_found:
|
||||||
|
tracker.mark_started(spec.key, started_at=current_time)
|
||||||
|
active_found = True
|
||||||
|
|
@ -7,6 +7,7 @@ from langgraph.prebuilt import ToolNode
|
||||||
from tradingagents.agents import *
|
from tradingagents.agents import *
|
||||||
from tradingagents.agents.utils.agent_states import AgentState
|
from tradingagents.agents.utils.agent_states import AgentState
|
||||||
|
|
||||||
|
from .analyst_execution import build_analyst_execution_plan
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -24,6 +25,7 @@ class GraphSetup:
|
||||||
invest_judge_memory,
|
invest_judge_memory,
|
||||||
portfolio_manager_memory,
|
portfolio_manager_memory,
|
||||||
conditional_logic: ConditionalLogic,
|
conditional_logic: ConditionalLogic,
|
||||||
|
analyst_concurrency_limit: int = 1,
|
||||||
):
|
):
|
||||||
"""Initialize with required components."""
|
"""Initialize with required components."""
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
|
|
@ -35,6 +37,7 @@ class GraphSetup:
|
||||||
self.invest_judge_memory = invest_judge_memory
|
self.invest_judge_memory = invest_judge_memory
|
||||||
self.portfolio_manager_memory = portfolio_manager_memory
|
self.portfolio_manager_memory = portfolio_manager_memory
|
||||||
self.conditional_logic = conditional_logic
|
self.conditional_logic = conditional_logic
|
||||||
|
self.analyst_concurrency_limit = analyst_concurrency_limit
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
||||||
|
|
@ -48,41 +51,17 @@ class GraphSetup:
|
||||||
- "news": News analyst
|
- "news": News analyst
|
||||||
- "fundamentals": Fundamentals analyst
|
- "fundamentals": Fundamentals analyst
|
||||||
"""
|
"""
|
||||||
if len(selected_analysts) == 0:
|
plan = build_analyst_execution_plan(
|
||||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
selected_analysts,
|
||||||
|
concurrency_limit=self.analyst_concurrency_limit,
|
||||||
# Create analyst nodes
|
|
||||||
analyst_nodes = {}
|
|
||||||
delete_nodes = {}
|
|
||||||
tool_nodes = {}
|
|
||||||
|
|
||||||
if "market" in selected_analysts:
|
|
||||||
analyst_nodes["market"] = create_market_analyst(
|
|
||||||
self.quick_thinking_llm
|
|
||||||
)
|
)
|
||||||
delete_nodes["market"] = create_msg_delete()
|
|
||||||
tool_nodes["market"] = self.tool_nodes["market"]
|
|
||||||
|
|
||||||
if "social" in selected_analysts:
|
analyst_factories = {
|
||||||
analyst_nodes["social"] = create_social_media_analyst(
|
"market": lambda: create_market_analyst(self.quick_thinking_llm),
|
||||||
self.quick_thinking_llm
|
"social": lambda: create_social_media_analyst(self.quick_thinking_llm),
|
||||||
)
|
"news": lambda: create_news_analyst(self.quick_thinking_llm),
|
||||||
delete_nodes["social"] = create_msg_delete()
|
"fundamentals": lambda: create_fundamentals_analyst(self.quick_thinking_llm),
|
||||||
tool_nodes["social"] = self.tool_nodes["social"]
|
}
|
||||||
|
|
||||||
if "news" in selected_analysts:
|
|
||||||
analyst_nodes["news"] = create_news_analyst(
|
|
||||||
self.quick_thinking_llm
|
|
||||||
)
|
|
||||||
delete_nodes["news"] = create_msg_delete()
|
|
||||||
tool_nodes["news"] = self.tool_nodes["news"]
|
|
||||||
|
|
||||||
if "fundamentals" in selected_analysts:
|
|
||||||
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
|
||||||
self.quick_thinking_llm
|
|
||||||
)
|
|
||||||
delete_nodes["fundamentals"] = create_msg_delete()
|
|
||||||
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
|
||||||
|
|
||||||
# Create researcher and manager nodes
|
# Create researcher and manager nodes
|
||||||
bull_researcher_node = create_bull_researcher(
|
bull_researcher_node = create_bull_researcher(
|
||||||
|
|
@ -108,12 +87,10 @@ class GraphSetup:
|
||||||
workflow = StateGraph(AgentState)
|
workflow = StateGraph(AgentState)
|
||||||
|
|
||||||
# Add analyst nodes to the graph
|
# Add analyst nodes to the graph
|
||||||
for analyst_type, node in analyst_nodes.items():
|
for spec in plan.specs:
|
||||||
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
workflow.add_node(spec.agent_node, analyst_factories[spec.key]())
|
||||||
workflow.add_node(
|
workflow.add_node(spec.clear_node, create_msg_delete())
|
||||||
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
|
workflow.add_node(spec.tool_node, self.tool_nodes[spec.key])
|
||||||
)
|
|
||||||
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
|
||||||
|
|
||||||
# Add other nodes
|
# Add other nodes
|
||||||
workflow.add_node("Bull Researcher", bull_researcher_node)
|
workflow.add_node("Bull Researcher", bull_researcher_node)
|
||||||
|
|
@ -127,27 +104,25 @@ class GraphSetup:
|
||||||
|
|
||||||
# Define edges
|
# Define edges
|
||||||
# Start with the first analyst
|
# Start with the first analyst
|
||||||
first_analyst = selected_analysts[0]
|
workflow.add_edge(START, plan.specs[0].agent_node)
|
||||||
workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst")
|
|
||||||
|
|
||||||
# Connect analysts in sequence
|
# Connect analysts in sequence
|
||||||
for i, analyst_type in enumerate(selected_analysts):
|
for i, spec in enumerate(plan.specs):
|
||||||
current_analyst = f"{analyst_type.capitalize()} Analyst"
|
current_analyst = spec.agent_node
|
||||||
current_tools = f"tools_{analyst_type}"
|
current_tools = spec.tool_node
|
||||||
current_clear = f"Msg Clear {analyst_type.capitalize()}"
|
current_clear = spec.clear_node
|
||||||
|
|
||||||
# Add conditional edges for current analyst
|
# Add conditional edges for current analyst
|
||||||
workflow.add_conditional_edges(
|
workflow.add_conditional_edges(
|
||||||
current_analyst,
|
current_analyst,
|
||||||
getattr(self.conditional_logic, f"should_continue_{analyst_type}"),
|
getattr(self.conditional_logic, f"should_continue_{spec.key}"),
|
||||||
[current_tools, current_clear],
|
[current_tools, current_clear],
|
||||||
)
|
)
|
||||||
workflow.add_edge(current_tools, current_analyst)
|
workflow.add_edge(current_tools, current_analyst)
|
||||||
|
|
||||||
# Connect to next analyst or to Bull Researcher if this is the last analyst
|
# Connect to next analyst or to Bull Researcher if this is the last analyst
|
||||||
if i < len(selected_analysts) - 1:
|
if i < len(plan.specs) - 1:
|
||||||
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
|
workflow.add_edge(current_clear, plan.specs[i + 1].agent_node)
|
||||||
workflow.add_edge(current_clear, next_analyst)
|
|
||||||
else:
|
else:
|
||||||
workflow.add_edge(current_clear, "Bull Researcher")
|
workflow.add_edge(current_clear, "Bull Researcher")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -117,6 +117,7 @@ class TradingAgentsGraph:
|
||||||
self.invest_judge_memory,
|
self.invest_judge_memory,
|
||||||
self.portfolio_manager_memory,
|
self.portfolio_manager_memory,
|
||||||
self.conditional_logic,
|
self.conditional_logic,
|
||||||
|
analyst_concurrency_limit=self.config.get("analyst_concurrency_limit", 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.propagator = Propagator()
|
self.propagator = Propagator()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue