diff --git a/cli/main.py b/cli/main.py index b53fefa9..1647ea01 100644 --- a/cli/main.py +++ b/cli/main.py @@ -900,8 +900,6 @@ def extract_content_string(content): """Extract string content from various message formats. Returns None if no meaningful text content is found. """ - import ast - def is_empty(val): """Check if value is empty using Python's truthiness.""" if val is None or val == "": @@ -910,10 +908,11 @@ def extract_content_string(content): s = val.strip() if not s: return True - try: - return not bool(ast.literal_eval(s)) - except (ValueError, SyntaxError): - return False # Can't parse = real text + # Check for common string representations of "empty" values + # to avoid using unsafe ast.literal_eval + if s.lower() in ("[]", "{}", "()", "none", "false", "0", "0.0", '""', "''"): + return True + return False return not bool(val) if is_empty(content): diff --git a/tests/cli/test_stats_handler.py b/tests/cli/test_stats_handler.py new file mode 100644 index 00000000..8bad63fb --- /dev/null +++ b/tests/cli/test_stats_handler.py @@ -0,0 +1,108 @@ +import threading +import pytest +from cli.stats_handler import StatsCallbackHandler +from langchain_core.outputs import LLMResult, Generation +from langchain_core.messages import AIMessage + +def test_stats_handler_initial_state(): + handler = StatsCallbackHandler() + stats = handler.get_stats() + assert stats == { + "llm_calls": 0, + "tool_calls": 0, + "tokens_in": 0, + "tokens_out": 0, + } + +def test_stats_handler_on_llm_start(): + handler = StatsCallbackHandler() + handler.on_llm_start(serialized={}, prompts=["test"]) + assert handler.llm_calls == 1 + assert handler.get_stats()["llm_calls"] == 1 + +def test_stats_handler_on_chat_model_start(): + handler = StatsCallbackHandler() + handler.on_chat_model_start(serialized={}, messages=[[]]) + assert handler.llm_calls == 1 + assert handler.get_stats()["llm_calls"] == 1 + +def test_stats_handler_on_tool_start(): + handler = StatsCallbackHandler() + handler.on_tool_start(serialized={}, input_str="test tool") + assert handler.tool_calls == 1 + assert handler.get_stats()["tool_calls"] == 1 + +def test_stats_handler_on_llm_end_with_usage(): + handler = StatsCallbackHandler() + + # Mock usage metadata + usage_metadata = {"input_tokens": 10, "output_tokens": 20} + message = AIMessage(content="test response") + message.usage_metadata = usage_metadata + generation = Generation(message=message, text="test response") + response = LLMResult(generations=[[generation]]) + + handler.on_llm_end(response) + + stats = handler.get_stats() + assert stats["tokens_in"] == 10 + assert stats["tokens_out"] == 20 + +def test_stats_handler_on_llm_end_no_usage(): + handler = StatsCallbackHandler() + + # Generation without message/usage_metadata + generation = Generation(text="test response") + response = LLMResult(generations=[[generation]]) + + handler.on_llm_end(response) + + stats = handler.get_stats() + assert stats["tokens_in"] == 0 + assert stats["tokens_out"] == 0 + +def test_stats_handler_on_llm_end_empty_generations(): + handler = StatsCallbackHandler() + response = LLMResult(generations=[[]]) + handler.on_llm_end(response) + + response_none = LLMResult(generations=[]) + # on_llm_end does try response.generations[0][0], so generations=[] will trigger IndexError which is handled. + handler.on_llm_end(response_none) + + assert handler.tokens_in == 0 + assert handler.tokens_out == 0 + +def test_stats_handler_thread_safety(): + handler = StatsCallbackHandler() + num_threads = 10 + increments_per_thread = 100 + + def worker(): + for _ in range(increments_per_thread): + handler.on_llm_start({}, []) + handler.on_tool_start({}, "") + + # Mock usage metadata for on_llm_end + usage_metadata = {"input_tokens": 1, "output_tokens": 1} + message = AIMessage(content="x") + message.usage_metadata = usage_metadata + generation = Generation(message=message, text="x") + response = LLMResult(generations=[[generation]]) + handler.on_llm_end(response) + + threads = [] + for _ in range(num_threads): + t = threading.Thread(target=worker) + threads.append(t) + t.start() + + for t in threads: + t.join() + + stats = handler.get_stats() + expected_calls = num_threads * increments_per_thread + assert stats["llm_calls"] == expected_calls + assert stats["tool_calls"] == expected_calls + assert stats["tokens_in"] == expected_calls + assert stats["tokens_out"] == expected_calls diff --git a/tests/unit/test_finnhub_scanner_utils.py b/tests/unit/test_finnhub_scanner_utils.py new file mode 100644 index 00000000..d248c5e7 --- /dev/null +++ b/tests/unit/test_finnhub_scanner_utils.py @@ -0,0 +1,35 @@ +"""Unit tests for utility functions in finnhub_scanner.py.""" + +from tradingagents.dataflows.finnhub_scanner import _safe_fmt + +def test_safe_fmt_none_returns_default_fallback(): + assert _safe_fmt(None) == "N/A" + +def test_safe_fmt_none_returns_custom_fallback(): + assert _safe_fmt(None, fallback="Missing") == "Missing" + +def test_safe_fmt_valid_float_returns_default_format(): + assert _safe_fmt(123.456) == "$123.46" + +def test_safe_fmt_valid_int_returns_default_format(): + assert _safe_fmt(100) == "$100.00" + +def test_safe_fmt_numeric_string_returns_default_format(): + assert _safe_fmt("45.678") == "$45.68" + +def test_safe_fmt_custom_format(): + assert _safe_fmt(123.456, fmt="{:.3f}") == "123.456" + +def test_safe_fmt_non_numeric_string_returns_original_string(): + # float("abc") raises ValueError, should return "abc" + assert _safe_fmt("abc") == "abc" + +def test_safe_fmt_unsupported_type_returns_str_representation(): + # float([]) raises TypeError, should return "[]" + assert _safe_fmt([]) == "[]" + +def test_safe_fmt_zero_returns_formatted_zero(): + assert _safe_fmt(0) == "$0.00" + +def test_safe_fmt_negative_number(): + assert _safe_fmt(-1.23) == "$-1.23" diff --git a/tests/unit/test_notebook_sync.py b/tests/unit/test_notebook_sync.py index 1ecfc049..aaa29ce8 100644 --- a/tests/unit/test_notebook_sync.py +++ b/tests/unit/test_notebook_sync.py @@ -61,18 +61,26 @@ def test_sync_performs_delete_then_add(mock_nlm_path): # Check list call args, kwargs = mock_run.call_args_list[0] assert "list" in args[0] + assert "--json" in args[0] + assert "--" in args[0] assert notebook_id in args[0] # Check delete call args, kwargs = mock_run.call_args_list[1] assert "delete" in args[0] + assert "-y" in args[0] + assert "--" in args[0] + assert notebook_id in args[0] assert source_id in args[0] # Check add call args, kwargs = mock_run.call_args_list[2] assert "add" in args[0] - assert "--text" in args[0] - assert content in args[0] + assert "--file" in args[0] + assert str(digest_path) in args[0] + assert "--wait" in args[0] + assert "--" in args[0] + assert notebook_id in args[0] def test_sync_adds_directly_when_none_exists(mock_nlm_path): """Should add new source directly if no existing one is found.""" diff --git a/tests/unit/test_security_notebook_sync.py b/tests/unit/test_security_notebook_sync.py new file mode 100644 index 00000000..5403ac21 --- /dev/null +++ b/tests/unit/test_security_notebook_sync.py @@ -0,0 +1,107 @@ +import json +import os +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from tradingagents.notebook_sync import sync_to_notebooklm + +@pytest.fixture +def mock_nlm_path(tmp_path): + nlm = tmp_path / "nlm" + nlm.touch(mode=0o755) + return str(nlm) + +def test_security_argument_injection(mock_nlm_path, tmp_path): + """ + Test that positional arguments starting with a hyphen are handled safely + and that content is passed via file to avoid ARG_MAX issues and injection. + """ + # Malicious notebook_id that looks like a flag + notebook_id = "--some-flag" + digest_path = tmp_path / "malicious.md" + digest_path.write_text("Some content") + date = "2026-03-19" + + with patch.dict(os.environ, {"NOTEBOOKLM_ID": notebook_id}): + with patch("shutil.which", return_value=mock_nlm_path): + with patch("subprocess.run") as mock_run: + # Mock 'source list' + list_result = MagicMock() + list_result.returncode = 0 + list_result.stdout = "[]" + + # Mock 'source add' + add_result = MagicMock() + add_result.returncode = 0 + + mock_run.side_effect = [list_result, add_result] + + sync_to_notebooklm(digest_path, date) + + # 1. Check 'source list' call + # Expected: [nlm, "source", "list", "--json", "--", notebook_id] + list_args = mock_run.call_args_list[0][0][0] + assert list_args[0] == mock_nlm_path + assert list_args[1:3] == ["source", "list"] + assert "--json" in list_args + assert "--" in list_args + # "--" should be before the notebook_id + dash_idx = list_args.index("--") + id_idx = list_args.index(notebook_id) + assert dash_idx < id_idx + + # 2. Check 'source add' call + # Expected: [nlm, "source", "add", "--title", title, "--file", str(digest_path), "--wait", "--", notebook_id] + add_args = mock_run.call_args_list[1][0][0] + assert add_args[0] == mock_nlm_path + assert add_args[1:3] == ["source", "add"] + assert "--title" in add_args + assert "--file" in add_args + assert str(digest_path) in add_args + assert "--text" not in add_args # Vulnerable --text should be gone + assert "--wait" in add_args + assert "--" in add_args + + dash_idx = add_args.index("--") + id_idx = add_args.index(notebook_id) + assert dash_idx < id_idx + +def test_security_delete_injection(mock_nlm_path): + """Test that source_id in delete is also handled safely with --.""" + notebook_id = "normal-id" + source_id = "--delete-everything" + + with patch.dict(os.environ, {"NOTEBOOKLM_ID": notebook_id}): + with patch("shutil.which", return_value=mock_nlm_path): + with patch("subprocess.run") as mock_run: + # Mock 'source list' finding the malicious source_id + list_result = MagicMock() + list_result.returncode = 0 + list_result.stdout = json.dumps([{"id": source_id, "title": "Daily Trading Digest (2026-03-19)"}]) + + # Mock 'source delete' + delete_result = MagicMock() + delete_result.returncode = 0 + + # Mock 'source add' + add_result = MagicMock() + add_result.returncode = 0 + + mock_run.side_effect = [list_result, delete_result, add_result] + + sync_to_notebooklm(Path("test.md"), "2026-03-19") + + # Check 'source delete' call + # Expected: [nlm, "source", "delete", "-y", "--", notebook_id, source_id] + delete_args = mock_run.call_args_list[1][0][0] + assert delete_args[1:3] == ["source", "delete"] + assert "-y" in delete_args + assert "--" in delete_args + + dash_idx = delete_args.index("--") + id_idx = delete_args.index(notebook_id) + sid_idx = delete_args.index(source_id) + assert dash_idx < id_idx + assert dash_idx < sid_idx diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index 65abd9ea..31c90093 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -1,6 +1,5 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time -import json from tradingagents.agents.utils.core_stock_tools import get_stock_data from tradingagents.agents.utils.technical_indicators_tools import get_indicators from tradingagents.agents.utils.fundamental_data_tools import get_macro_regime diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index 2c393659..a49cd12d 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -217,18 +217,10 @@ def _get_stock_stats_bulk( df[indicator] # This triggers stockstats to calculate the indicator # Create a dictionary mapping date strings to indicator values - result_dict = {} - date_index_strs = df.index.strftime("%Y-%m-%d") - for date_str, (_, row) in zip(date_index_strs, df.iterrows()): - indicator_value = row[indicator] - - # Handle NaN/None values - if pd.isna(indicator_value): - result_dict[date_str] = "N/A" - else: - result_dict[date_str] = str(indicator_value) - - return result_dict + # Optimized: vectorized operations for performance using correct DatetimeIndex + series = df[indicator].copy() + series.index = series.index.strftime("%Y-%m-%d") + return series.fillna("N/A").astype(str).to_dict() diff --git a/tradingagents/notebook_sync.py b/tradingagents/notebook_sync.py index 7610d5af..77674c19 100644 --- a/tradingagents/notebook_sync.py +++ b/tradingagents/notebook_sync.py @@ -51,7 +51,6 @@ def sync_to_notebooklm(digest_path: Path, date: str, notebook_id: str | None = N console.print("[yellow]Warning: nlm CLI not found — skipping NotebookLM sync[/yellow]") return - content = digest_path.read_text() title = f"Daily Trading Digest ({date})" # Find and delete existing source with the same title @@ -60,14 +59,15 @@ def sync_to_notebooklm(digest_path: Path, date: str, notebook_id: str | None = N _delete_source(nlm, notebook_id, existing_source_id) # Add as a new source - _add_source(nlm, notebook_id, content, title) + _add_source(nlm, notebook_id, digest_path, title) def _find_source(nlm: str, notebook_id: str, title: str) -> str | None: """Return the source ID for the daily digest, or None if not found.""" try: + # Use -- to separate options from positional arguments result = subprocess.run( - [nlm, "source", "list", notebook_id, "--json"], + [nlm, "source", "list", "--json", "--", notebook_id], capture_output=True, text=True, ) @@ -85,8 +85,9 @@ def _find_source(nlm: str, notebook_id: str, title: str) -> str | None: def _delete_source(nlm: str, notebook_id: str, source_id: str) -> None: """Delete an existing source.""" try: + # Use -- to separate options from positional arguments subprocess.run( - [nlm, "source", "delete", notebook_id, source_id, "-y"], + [nlm, "source", "delete", "-y", "--", notebook_id, source_id], capture_output=True, text=True, check=False, # Ignore non-zero exit since nlm sometimes fails even on success @@ -95,11 +96,13 @@ def _delete_source(nlm: str, notebook_id: str, source_id: str) -> None: pass -def _add_source(nlm: str, notebook_id: str, content: str, title: str) -> None: +def _add_source(nlm: str, notebook_id: str, digest_path: Path, title: str) -> None: """Add content as a new source.""" try: + # Use --file instead of --text to avoid ARG_MAX issues and shell injection. + # Use -- to separate options from positional arguments. result = subprocess.run( - [nlm, "source", "add", notebook_id, "--title", title, "--text", content, "--wait"], + [nlm, "source", "add", "--title", title, "--file", str(digest_path), "--wait", "--", notebook_id], capture_output=True, text=True, )