From 23dd1b694dac4c286c299b9a14bff2b103257309 Mon Sep 17 00:00:00 2001 From: Zhigong Liu Date: Mon, 20 Apr 2026 16:34:25 -0400 Subject: [PATCH] feat: add real sentiment analyst with Reddit and Fear & Greed dataflows Co-Authored-By: Claude Sonnet 4.6 --- tests/test_sentiment_tools.py | 215 ++++++++++++++++++ .../agents/analysts/social_media_analyst.py | 23 +- tradingagents/agents/utils/sentiment_tools.py | 35 +++ tradingagents/dataflows/fear_greed.py | 56 +++++ tradingagents/dataflows/interface.py | 16 ++ tradingagents/dataflows/reddit_sentiment.py | 211 +++++++++++++++++ tradingagents/default_config.py | 1 + tradingagents/graph/setup.py | 19 +- tradingagents/graph/trading_graph.py | 5 +- 9 files changed, 569 insertions(+), 12 deletions(-) create mode 100644 tests/test_sentiment_tools.py create mode 100644 tradingagents/agents/utils/sentiment_tools.py create mode 100644 tradingagents/dataflows/fear_greed.py create mode 100644 tradingagents/dataflows/reddit_sentiment.py diff --git a/tests/test_sentiment_tools.py b/tests/test_sentiment_tools.py new file mode 100644 index 00000000..1d56fdab --- /dev/null +++ b/tests/test_sentiment_tools.py @@ -0,0 +1,215 @@ +"""Mock-based unit tests for Reddit sentiment and Fear & Greed dataflows.""" + +import pytest +import requests +from unittest.mock import MagicMock, patch + +from tradingagents.dataflows.reddit_sentiment import get_reddit_sentiment +from tradingagents.dataflows.fear_greed import get_fear_greed + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_post(post_id, title, score=100, num_comments=50, upvote_ratio=0.9, + flair=None, created_utc=9_999_999_999): + return {"kind": "t3", "data": { + "id": post_id, + "title": title, + "score": score, + "num_comments": num_comments, + "upvote_ratio": upvote_ratio, + "link_flair_text": flair, + "created_utc": created_utc, + }} + + +def _search_response(posts): + return {"data": {"children": posts}} + + +def _comment_response(comments): + comment_items = [ + {"kind": "t1", "data": {"author": "user1", "body": c}} + for c in comments + ] + return [ + {"data": {"children": []}}, # post listing (unused) + {"data": {"children": comment_items}}, + ] + + +# --------------------------------------------------------------------------- +# Reddit — get_reddit_sentiment +# --------------------------------------------------------------------------- + +class TestRedditSentiment: + + def _patch_search(self, posts_by_subreddit): + """Return a mock requests.get that returns given posts per subreddit.""" + def fake_get(url, params=None, headers=None, timeout=None): + resp = MagicMock() + resp.ok = True + resp.status_code = 200 + resp.encoding = "utf-8" + subreddit = url.split("/r/")[1].split("/")[0] + posts = posts_by_subreddit.get(subreddit, []) + resp.json.return_value = _search_response(posts) + return resp + return fake_get + + def test_happy_path_returns_formatted_post(self): + posts = {"wallstreetbets": [ + _make_post("abc1", "NVDA calls printing today", score=500, num_comments=80, upvote_ratio=0.92) + ], "stocks": [], "options": []} + + with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \ + patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""): + result = get_reddit_sentiment("NVDA", days=7) + + assert "NVDA" in result + assert "NVDA calls printing today" in result + assert "Score: 500" in result + assert "Comments: 80" in result + assert "92%" in result + + def test_no_posts_returns_informative_message(self): + empty = {"wallstreetbets": [], "stocks": [], "options": []} + + with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(empty)), \ + patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""): + result = get_reddit_sentiment("XYZQ", days=7) + + assert "No Reddit posts found" in result + assert "XYZQ" in result + + def test_429_skips_subreddit_and_returns_no_posts_message(self): + """429 from all subreddits → no posts collected → informative message returned.""" + def rate_limited(*args, **kwargs): + resp = MagicMock() + resp.ok = False + resp.status_code = 429 + return resp + + with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=rate_limited), \ + patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""): + result = get_reddit_sentiment("NVDA", days=7) + + assert "No Reddit posts found" in result + assert "NVDA" in result + + def test_network_error_skips_subreddit_and_returns_no_posts_message(self): + """Network failure on all subreddits → no posts collected → informative message returned.""" + with patch("tradingagents.dataflows.reddit_sentiment.requests.get", + side_effect=requests.RequestException("connection reset")), \ + patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""): + result = get_reddit_sentiment("NVDA", days=7) + + assert "No Reddit posts found" in result + assert "NVDA" in result + + def test_title_filter_removes_off_topic_posts(self): + """Posts whose title doesn't contain ticker or company name are dropped.""" + posts = {"wallstreetbets": [ + _make_post("abc1", "SanDisk joins QQQ today", score=900), # off-topic + _make_post("abc2", "NVDA calls printing today", score=100), # on-topic + ], "stocks": [], "options": []} + + with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \ + patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""): + result = get_reddit_sentiment("NVDA", days=7) + + assert "SanDisk" not in result + assert "NVDA calls printing today" in result + + def test_company_name_keyword_matches_title(self): + """Posts containing company name but not ticker are included.""" + posts = {"wallstreetbets": [ + _make_post("abc1", "Nvidia GPU demand surging", score=200), + ], "stocks": [], "options": []} + + with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \ + patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value="Nvidia Corp"): + result = get_reddit_sentiment("NVDA", days=7) + + assert "Nvidia GPU demand surging" in result + + def test_deduplication_across_subreddits(self): + """Same post appearing in multiple subreddit results is only shown once.""" + same_post = _make_post("dup1", "NVDA bull case", score=50) + posts = { + "wallstreetbets": [same_post], + "stocks": [same_post], + "options": [], + } + + with patch("tradingagents.dataflows.reddit_sentiment.requests.get", side_effect=self._patch_search(posts)), \ + patch("tradingagents.dataflows.reddit_sentiment._get_company_name", return_value=""): + result = get_reddit_sentiment("NVDA", days=7) + + assert result.count("NVDA bull case") == 1 + + +# --------------------------------------------------------------------------- +# Fear & Greed — get_fear_greed +# --------------------------------------------------------------------------- + +class TestFearGreed: + + def _fng_response(self, days): + import time + data = [] + for i in range(days): + ts = int(time.time()) - i * 86400 + data.append({ + "value": str(30 + i), + "value_classification": "Fear", + "timestamp": str(ts), + }) + return {"data": data} + + def test_happy_path_returns_n_entries(self): + resp = MagicMock() + resp.ok = True + resp.status_code = 200 + resp.encoding = "utf-8" + resp.json.return_value = self._fng_response(7) + + with patch("tradingagents.dataflows.fear_greed.requests.get", return_value=resp): + result = get_fear_greed(7) + + lines = [l for l in result.splitlines() if "Score:" in l] + assert len(lines) == 7 + assert "Fear" in result + assert "/100" in result + + def test_single_day_returns_one_entry(self): + resp = MagicMock() + resp.ok = True + resp.status_code = 200 + resp.encoding = "utf-8" + resp.json.return_value = self._fng_response(1) + + with patch("tradingagents.dataflows.fear_greed.requests.get", return_value=resp): + result = get_fear_greed(1) + + lines = [l for l in result.splitlines() if "Score:" in l] + assert len(lines) == 1 + + def test_api_failure_returns_empty_string(self): + resp = MagicMock() + resp.ok = False + resp.status_code = 500 + + with patch("tradingagents.dataflows.fear_greed.requests.get", return_value=resp): + result = get_fear_greed(7) + + assert result == "" + + def test_network_error_returns_empty_string(self): + with patch("tradingagents.dataflows.fear_greed.requests.get", + side_effect=requests.RequestException("timeout")): + result = get_fear_greed(7) + + assert result == "" diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 34a53c46..93367cb4 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,5 +1,6 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news +from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction +from tradingagents.agents.utils.sentiment_tools import get_reddit_sentiment, get_market_fear_greed from tradingagents.dataflows.config import get_config @@ -9,12 +10,26 @@ def create_social_media_analyst(llm): instrument_context = build_instrument_context(state["company_of_interest"]) tools = [ - get_news, + get_reddit_sentiment, + get_market_fear_greed, ] system_message = ( - "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions." - + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" + "You are a Sentiment Analyst. Your job is to gauge retail investor sentiment " + "and macro market mood for a specific stock.\n\n" + "Use get_reddit_sentiment(ticker, days) to fetch recent Reddit posts from " + "r/wallstreetbets, r/stocks, and r/options. Analyse titles, scores, comment " + "counts, upvote ratios, and the actual comments to assess retail mood.\n\n" + "Use get_market_fear_greed(days) to fetch the CNN Fear & Greed Index — a " + "market-wide macro signal. Use it to contextualise retail sentiment: bullish " + "Reddit posts carry more weight in a Greed market; bearish posts in an Extreme " + "Fear market may signal capitulation.\n\n" + "If Reddit returns no posts (obscure or small-cap ticker), state that clearly — " + "absence of retail coverage is itself a signal.\n\n" + "Write a comprehensive sentiment report covering: overall retail mood (bullish / " + "bearish / mixed), engagement level, notable narratives, current Fear & Greed " + "reading, and implications for short-term trader sentiment. " + "Append a Markdown table at the end summarising key data points." + get_language_instruction() ) diff --git a/tradingagents/agents/utils/sentiment_tools.py b/tradingagents/agents/utils/sentiment_tools.py new file mode 100644 index 00000000..b257bf3c --- /dev/null +++ b/tradingagents/agents/utils/sentiment_tools.py @@ -0,0 +1,35 @@ +from langchain_core.tools import tool +from typing import Annotated +from tradingagents.dataflows.interface import route_to_vendor + + +@tool +def get_market_fear_greed( + days: Annotated[int, "Number of past days to fetch (default 7)"] = 7, +) -> str: + """ + Fetch the CNN Fear & Greed Index time series from alternative.me. + Returns one entry per day with a numeric score (0–100) and classification + label (Extreme Fear / Fear / Neutral / Greed / Extreme Greed). + This is a market-wide macro signal — not ticker-specific. Use it to + contextualise retail sentiment against the broader market mood. + """ + return route_to_vendor("get_market_fear_greed", days) + + +@tool +def get_reddit_sentiment( + ticker: Annotated[str, "Stock ticker symbol, e.g. NVDA or AAPL"], + days: Annotated[int, "Number of past days to search for posts (default 3)"] = 3, +) -> str: + """ + Fetch recent Reddit posts mentioning a stock ticker from investing subreddits + (r/wallstreetbets, r/stocks, r/options). Returns post titles, upvote scores, + comment counts, and upvote ratios as a formatted string. This is a + ticker-specific retail sentiment signal, not a news source. + Automatically searches by both ticker symbol and company name (e.g. NVDA OR + Nvidia) so posts using the company name are not missed. + Returns an empty string on API failure; returns a 'no posts found' message + for obscure tickers with no Reddit coverage. + """ + return route_to_vendor("get_reddit_sentiment", ticker, days) diff --git a/tradingagents/dataflows/fear_greed.py b/tradingagents/dataflows/fear_greed.py new file mode 100644 index 00000000..b8169b15 --- /dev/null +++ b/tradingagents/dataflows/fear_greed.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +"""CNN Fear & Greed Index fetching via alternative.me public API (no auth required).""" + +import logging +import requests +from datetime import datetime, timezone + +logger = logging.getLogger(__name__) + +_URL = "https://api.alternative.me/fng/" +_TIMEOUT = 10 + + +def get_fear_greed(days: int = 7) -> str: + """ + Fetch the CNN Fear & Greed Index time series from alternative.me. + + Returns one entry per day with a numeric score (0–100) and classification + label (Extreme Fear / Fear / Neutral / Greed / Extreme Greed). This is a + market-wide macro signal — not ticker-specific. + + Args: + days: Number of past days to fetch (default 7) + + Returns: + Formatted string of daily entries, or empty string on API failure. + """ + try: + r = requests.get(_URL, params={"limit": days}, timeout=_TIMEOUT) + except requests.RequestException as e: + logger.warning("Fear & Greed API request failed: %s", e) + return "" + + if not r.ok: + logger.warning("Fear & Greed API returned HTTP %s", r.status_code) + return "" + + r.encoding = "utf-8" + try: + data = r.json().get("data", []) + except ValueError: + logger.warning("Fear & Greed API returned invalid JSON") + return "" + + if not data: + return "" + + lines = [f"CNN Fear & Greed Index (last {days} days, most recent first):\n"] + for entry in data: + ts = int(entry.get("timestamp", 0)) + date = datetime.fromtimestamp(ts, tz=timezone.utc).strftime("%Y-%m-%d") + score = entry.get("value", "?") + label = entry.get("value_classification", "?") + lines.append(f"{date} | Score: {score}/100 | {label}") + + return "\n".join(lines) diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 0caf4b68..5d5b130f 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -23,6 +23,8 @@ from .alpha_vantage import ( get_global_news as get_alpha_vantage_global_news, ) from .alpha_vantage_common import AlphaVantageRateLimitError +from .reddit_sentiment import get_reddit_sentiment as get_reddit_sentiment_impl +from .fear_greed import get_fear_greed as get_fear_greed_impl # Configuration and routing logic from .config import get_config @@ -57,6 +59,13 @@ TOOLS_CATEGORIES = { "get_global_news", "get_insider_transactions", ] + }, + "sentiment_data": { + "description": "Retail sentiment and market mood data", + "tools": [ + "get_reddit_sentiment", + "get_market_fear_greed", + ] } } @@ -107,6 +116,13 @@ VENDOR_METHODS = { "alpha_vantage": get_alpha_vantage_insider_transactions, "yfinance": get_yfinance_insider_transactions, }, + # sentiment_data + "get_reddit_sentiment": { + "default": get_reddit_sentiment_impl, + }, + "get_market_fear_greed": { + "default": get_fear_greed_impl, + }, } def get_category_for_method(method: str) -> str: diff --git a/tradingagents/dataflows/reddit_sentiment.py b/tradingagents/dataflows/reddit_sentiment.py new file mode 100644 index 00000000..bd86a4e1 --- /dev/null +++ b/tradingagents/dataflows/reddit_sentiment.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +"""Reddit-based retail sentiment fetching via public JSON API (no auth required).""" + +import logging +import requests +import yfinance as yf +from datetime import datetime, timedelta, timezone + +logger = logging.getLogger(__name__) + +_HEADERS = {"User-Agent": "TradingAgentsBot/0.1"} +_SUBREDDITS = ["wallstreetbets", "stocks", "options"] +_TIMEOUT = 10 +_COMMENT_PREVIEW = 200 # max chars per comment shown to LLM +_TOP_POSTS_WITH_COMMENTS = 3 # fetch comments for this many top posts only + + +def _get_company_name(ticker: str) -> str: + """Look up the short company name for a ticker via yfinance.""" + try: + info = yf.Ticker(ticker).info + return info.get("shortName") or info.get("longName") or "" + except Exception: + return "" + + +def _search_subreddit(subreddit: str, query: str) -> list: + """Fetch up to 25 posts from a subreddit matching query. Returns raw post dicts.""" + params = { + "q": query, + "restrict_sr": 1, + "sort": "relevance", + "limit": 25, + "t": "week", + } + try: + r = requests.get( + f"https://www.reddit.com/r/{subreddit}/search.json", + params=params, + headers=_HEADERS, + timeout=_TIMEOUT, + ) + except requests.RequestException as e: + logger.warning("Reddit API request failed for r/%s: %s", subreddit, e) + return [] + + if r.status_code == 429: + logger.warning("Reddit API rate limit hit (429) for r/%s", subreddit) + return [] + if not r.ok: + logger.warning("Reddit API returned HTTP %s for r/%s", r.status_code, subreddit) + return [] + + r.encoding = "utf-8" + try: + return r.json().get("data", {}).get("children", []) + except ValueError: + logger.warning("Reddit API returned invalid JSON for r/%s", subreddit) + return [] + + +def _fetch_top_comments(subreddit: str, post_id: str, limit: int = 20) -> list[str]: + """ + Fetch top-level comments for a post, sorted by score. + Returns a list of comment body strings (truncated to _COMMENT_PREVIEW chars). + """ + try: + r = requests.get( + f"https://www.reddit.com/r/{subreddit}/comments/{post_id}.json", + params={"sort": "top", "limit": limit, "depth": 1}, + headers=_HEADERS, + timeout=_TIMEOUT, + ) + except requests.RequestException as e: + logger.warning("Reddit comment fetch failed for %s/%s: %s", subreddit, post_id, e) + return [] + + if r.status_code == 429: + logger.warning("Reddit API rate limit hit (429) fetching comments for %s", post_id) + return [] + if not r.ok: + logger.warning("Reddit comment API returned HTTP %s for %s", r.status_code, post_id) + return [] + + r.encoding = "utf-8" + try: + data = r.json() + except ValueError: + return [] + + # Response is [post_listing, comment_listing] + if len(data) < 2: + return [] + + _BOT_AUTHORS = {"automoderator", "visualmod"} + + comments = [] + for item in data[1].get("data", {}).get("children", []): + cdata = item.get("data", {}) + author = cdata.get("author", "").lower() + if author in _BOT_AUTHORS: + continue + body = cdata.get("body", "") + if not body or body == "[deleted]" or body == "[removed]": + continue + # Truncate and clean whitespace + body = " ".join(body.split()) + comments.append(body[:_COMMENT_PREVIEW]) + + return comments + + +def get_reddit_sentiment(ticker: str, days: int = 3) -> str: + """ + Fetch recent Reddit posts mentioning a ticker from investing subreddits. + + Searches r/wallstreetbets, r/stocks, and r/options via Reddit's public + JSON API (no authentication required). Runs separate queries for the ticker + symbol and company name so posts using either form are captured. Only keeps + posts whose title contains the ticker or company name. Fetches top comments + for the three highest-scoring posts to capture actual retail discussion. + + Args: + ticker: Stock ticker symbol (e.g., "NVDA") + days: Number of days to look back (default 3) + + Returns: + Formatted string of matching posts with top comments, or empty string + on API failure. + """ + cutoff = datetime.now(tz=timezone.utc) - timedelta(days=days) + + company_name = _get_company_name(ticker) + # Use the first meaningful word of the company name as a separate search term + # e.g. "NVIDIA Corporation" → "NVIDIA", "Apple Inc." → "Apple" + name_keyword = "" + if company_name: + first_word = company_name.split()[0] + if len(first_word) > 3 and first_word.upper() != ticker.upper(): + name_keyword = first_word + + search_terms = [ticker] + if name_keyword: + search_terms.append(name_keyword) + + seen_ids = set() + all_posts = [] + + for subreddit in _SUBREDDITS: + for term in search_terms: + for item in _search_subreddit(subreddit, term): + post = item.get("data", {}) + post_id = post.get("id", "") + + if post_id in seen_ids: + continue + + # Only keep posts whose title mentions ticker or company name + title_lower = post.get("title", "").lower() + if ticker.lower() not in title_lower and ( + not name_keyword or name_keyword.lower() not in title_lower + ): + continue + + created_utc = post.get("created_utc", 0) + post_time = datetime.fromtimestamp(created_utc, tz=timezone.utc) + if post_time < cutoff: + continue + + seen_ids.add(post_id) + all_posts.append({ + "id": post_id, + "subreddit": subreddit, + "title": post.get("title", ""), + "score": post.get("score", 0), + "num_comments": post.get("num_comments", 0), + "upvote_ratio": post.get("upvote_ratio", 0.0), + "flair": post.get("link_flair_text") or "", + }) + + if not all_posts: + return ( + f"No Reddit posts found mentioning {ticker} in the last {days} days " + f"across r/wallstreetbets, r/stocks, r/options." + ) + + # Sort by score descending — highest engagement first + all_posts.sort(key=lambda p: p["score"], reverse=True) + + # Fetch comments for top N posts only to keep API calls bounded + for post in all_posts[:_TOP_POSTS_WITH_COMMENTS]: + post["comments"] = _fetch_top_comments(post["subreddit"], post["id"]) + + label = f"{ticker}" + (f" / {name_keyword}" if name_keyword else "") + lines = [f"Reddit posts mentioning {label} (last {days} days, sorted by upvotes):\n"] + + for i, p in enumerate(all_posts): + flair = f" [{p['flair']}]" if p["flair"] else "" + lines.append( + f"r/{p['subreddit']}{flair} | " + f"Score: {p['score']} | " + f"Comments: {p['num_comments']} | " + f"Upvote ratio: {p['upvote_ratio']:.0%} | " + f"{p['title']}" + ) + comments = p.get("comments", []) + if comments: + for c in comments: + lines.append(f" > {c}") + + return "\n".join(lines) diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index a9b75e4b..fa2ae046 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -29,6 +29,7 @@ DEFAULT_CONFIG = { "technical_indicators": "yfinance", # Options: alpha_vantage, yfinance "fundamental_data": "yfinance", # Options: alpha_vantage, yfinance "news_data": "yfinance", # Options: alpha_vantage, yfinance + "sentiment_data": "default", # Reddit, Fear&Greed, Discord UW (no vendor alt) }, # Tool-level configuration (takes precedence over category-level) "tool_vendors": { diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index ae90489c..061e5021 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -104,14 +104,21 @@ class GraphSetup: self.deep_thinking_llm, self.portfolio_manager_memory ) + # Maps analyst selector keys to their display names in the graph. + # Add entries here when a key and its label should differ. + _labels = {"social": "Sentiment"} + + def _label(analyst_type: str) -> str: + return _labels.get(analyst_type, analyst_type.capitalize()) + # Create workflow workflow = StateGraph(AgentState) # Add analyst nodes to the graph for analyst_type, node in analyst_nodes.items(): - workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) + workflow.add_node(f"{_label(analyst_type)} Analyst", node) workflow.add_node( - f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type] + f"Msg Clear {_label(analyst_type)}", delete_nodes[analyst_type] ) workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) @@ -128,13 +135,13 @@ class GraphSetup: # Define edges # Start with the first analyst first_analyst = selected_analysts[0] - workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst") + workflow.add_edge(START, f"{_label(first_analyst)} Analyst") # Connect analysts in sequence for i, analyst_type in enumerate(selected_analysts): - current_analyst = f"{analyst_type.capitalize()} Analyst" + current_analyst = f"{_label(analyst_type)} Analyst" current_tools = f"tools_{analyst_type}" - current_clear = f"Msg Clear {analyst_type.capitalize()}" + current_clear = f"Msg Clear {_label(analyst_type)}" # Add conditional edges for current analyst workflow.add_conditional_edges( @@ -146,7 +153,7 @@ class GraphSetup: # Connect to next analyst or to Bull Researcher if this is the last analyst if i < len(selected_analysts) - 1: - next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst" + next_analyst = f"{_label(selected_analysts[i+1])} Analyst" workflow.add_edge(current_clear, next_analyst) else: workflow.add_edge(current_clear, "Bull Researcher") diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 78bc13e5..2f15b026 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -32,6 +32,7 @@ from tradingagents.agents.utils.agent_utils import ( get_insider_transactions, get_global_news ) +from tradingagents.agents.utils.sentiment_tools import get_reddit_sentiment, get_market_fear_greed from .conditional_logic import ConditionalLogic from .setup import GraphSetup @@ -166,8 +167,8 @@ class TradingAgentsGraph: ), "social": ToolNode( [ - # News tools for social media analysis - get_news, + get_reddit_sentiment, + get_market_fear_greed, ] ), "news": ToolNode(