merge: resolve conflicts with origin/main (PR #85 merged)
- cli/main.py: keep module-level rich.progress imports + result.elapsed_seconds (our review fixes); take main's extract_content_string (no ast.literal_eval) - y_finance.py: take main's vectorized _get_stock_stats_bulk (better perf); keep our logger.warning() fix in the fallback path - macro_bridge.py: keep our elapsed_seconds assignments (2 paths) Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> Agent-Logs-Url: https://github.com/aguzererler/TradingAgents/sessions/6e4151b2-17e3-473b-bf24-872a2656cd3f
This commit is contained in:
parent
9ff531f293
commit
a8b909e2ca
11
cli/main.py
11
cli/main.py
|
|
@ -900,8 +900,6 @@ def extract_content_string(content):
|
||||||
"""Extract string content from various message formats.
|
"""Extract string content from various message formats.
|
||||||
Returns None if no meaningful text content is found.
|
Returns None if no meaningful text content is found.
|
||||||
"""
|
"""
|
||||||
import ast
|
|
||||||
|
|
||||||
def is_empty(val):
|
def is_empty(val):
|
||||||
"""Check if value is empty using Python's truthiness."""
|
"""Check if value is empty using Python's truthiness."""
|
||||||
if val is None or val == "":
|
if val is None or val == "":
|
||||||
|
|
@ -910,10 +908,11 @@ def extract_content_string(content):
|
||||||
s = val.strip()
|
s = val.strip()
|
||||||
if not s:
|
if not s:
|
||||||
return True
|
return True
|
||||||
try:
|
# Check for common string representations of "empty" values
|
||||||
return not bool(ast.literal_eval(s))
|
# to avoid using unsafe ast.literal_eval
|
||||||
except (ValueError, SyntaxError):
|
if s.lower() in ("[]", "{}", "()", "none", "false", "0", "0.0", '""', "''"):
|
||||||
return False # Can't parse = real text
|
return True
|
||||||
|
return False
|
||||||
return not bool(val)
|
return not bool(val)
|
||||||
|
|
||||||
if is_empty(content):
|
if is_empty(content):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
import threading
|
||||||
|
import pytest
|
||||||
|
from cli.stats_handler import StatsCallbackHandler
|
||||||
|
from langchain_core.outputs import LLMResult, Generation
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
def test_stats_handler_initial_state():
|
||||||
|
handler = StatsCallbackHandler()
|
||||||
|
stats = handler.get_stats()
|
||||||
|
assert stats == {
|
||||||
|
"llm_calls": 0,
|
||||||
|
"tool_calls": 0,
|
||||||
|
"tokens_in": 0,
|
||||||
|
"tokens_out": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_stats_handler_on_llm_start():
|
||||||
|
handler = StatsCallbackHandler()
|
||||||
|
handler.on_llm_start(serialized={}, prompts=["test"])
|
||||||
|
assert handler.llm_calls == 1
|
||||||
|
assert handler.get_stats()["llm_calls"] == 1
|
||||||
|
|
||||||
|
def test_stats_handler_on_chat_model_start():
|
||||||
|
handler = StatsCallbackHandler()
|
||||||
|
handler.on_chat_model_start(serialized={}, messages=[[]])
|
||||||
|
assert handler.llm_calls == 1
|
||||||
|
assert handler.get_stats()["llm_calls"] == 1
|
||||||
|
|
||||||
|
def test_stats_handler_on_tool_start():
|
||||||
|
handler = StatsCallbackHandler()
|
||||||
|
handler.on_tool_start(serialized={}, input_str="test tool")
|
||||||
|
assert handler.tool_calls == 1
|
||||||
|
assert handler.get_stats()["tool_calls"] == 1
|
||||||
|
|
||||||
|
def test_stats_handler_on_llm_end_with_usage():
|
||||||
|
handler = StatsCallbackHandler()
|
||||||
|
|
||||||
|
# Mock usage metadata
|
||||||
|
usage_metadata = {"input_tokens": 10, "output_tokens": 20}
|
||||||
|
message = AIMessage(content="test response")
|
||||||
|
message.usage_metadata = usage_metadata
|
||||||
|
generation = Generation(message=message, text="test response")
|
||||||
|
response = LLMResult(generations=[[generation]])
|
||||||
|
|
||||||
|
handler.on_llm_end(response)
|
||||||
|
|
||||||
|
stats = handler.get_stats()
|
||||||
|
assert stats["tokens_in"] == 10
|
||||||
|
assert stats["tokens_out"] == 20
|
||||||
|
|
||||||
|
def test_stats_handler_on_llm_end_no_usage():
|
||||||
|
handler = StatsCallbackHandler()
|
||||||
|
|
||||||
|
# Generation without message/usage_metadata
|
||||||
|
generation = Generation(text="test response")
|
||||||
|
response = LLMResult(generations=[[generation]])
|
||||||
|
|
||||||
|
handler.on_llm_end(response)
|
||||||
|
|
||||||
|
stats = handler.get_stats()
|
||||||
|
assert stats["tokens_in"] == 0
|
||||||
|
assert stats["tokens_out"] == 0
|
||||||
|
|
||||||
|
def test_stats_handler_on_llm_end_empty_generations():
|
||||||
|
handler = StatsCallbackHandler()
|
||||||
|
response = LLMResult(generations=[[]])
|
||||||
|
handler.on_llm_end(response)
|
||||||
|
|
||||||
|
response_none = LLMResult(generations=[])
|
||||||
|
# on_llm_end does try response.generations[0][0], so generations=[] will trigger IndexError which is handled.
|
||||||
|
handler.on_llm_end(response_none)
|
||||||
|
|
||||||
|
assert handler.tokens_in == 0
|
||||||
|
assert handler.tokens_out == 0
|
||||||
|
|
||||||
|
def test_stats_handler_thread_safety():
|
||||||
|
handler = StatsCallbackHandler()
|
||||||
|
num_threads = 10
|
||||||
|
increments_per_thread = 100
|
||||||
|
|
||||||
|
def worker():
|
||||||
|
for _ in range(increments_per_thread):
|
||||||
|
handler.on_llm_start({}, [])
|
||||||
|
handler.on_tool_start({}, "")
|
||||||
|
|
||||||
|
# Mock usage metadata for on_llm_end
|
||||||
|
usage_metadata = {"input_tokens": 1, "output_tokens": 1}
|
||||||
|
message = AIMessage(content="x")
|
||||||
|
message.usage_metadata = usage_metadata
|
||||||
|
generation = Generation(message=message, text="x")
|
||||||
|
response = LLMResult(generations=[[generation]])
|
||||||
|
handler.on_llm_end(response)
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for _ in range(num_threads):
|
||||||
|
t = threading.Thread(target=worker)
|
||||||
|
threads.append(t)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
stats = handler.get_stats()
|
||||||
|
expected_calls = num_threads * increments_per_thread
|
||||||
|
assert stats["llm_calls"] == expected_calls
|
||||||
|
assert stats["tool_calls"] == expected_calls
|
||||||
|
assert stats["tokens_in"] == expected_calls
|
||||||
|
assert stats["tokens_out"] == expected_calls
|
||||||
|
|
@ -0,0 +1,35 @@
|
||||||
|
"""Unit tests for utility functions in finnhub_scanner.py."""
|
||||||
|
|
||||||
|
from tradingagents.dataflows.finnhub_scanner import _safe_fmt
|
||||||
|
|
||||||
|
def test_safe_fmt_none_returns_default_fallback():
|
||||||
|
assert _safe_fmt(None) == "N/A"
|
||||||
|
|
||||||
|
def test_safe_fmt_none_returns_custom_fallback():
|
||||||
|
assert _safe_fmt(None, fallback="Missing") == "Missing"
|
||||||
|
|
||||||
|
def test_safe_fmt_valid_float_returns_default_format():
|
||||||
|
assert _safe_fmt(123.456) == "$123.46"
|
||||||
|
|
||||||
|
def test_safe_fmt_valid_int_returns_default_format():
|
||||||
|
assert _safe_fmt(100) == "$100.00"
|
||||||
|
|
||||||
|
def test_safe_fmt_numeric_string_returns_default_format():
|
||||||
|
assert _safe_fmt("45.678") == "$45.68"
|
||||||
|
|
||||||
|
def test_safe_fmt_custom_format():
|
||||||
|
assert _safe_fmt(123.456, fmt="{:.3f}") == "123.456"
|
||||||
|
|
||||||
|
def test_safe_fmt_non_numeric_string_returns_original_string():
|
||||||
|
# float("abc") raises ValueError, should return "abc"
|
||||||
|
assert _safe_fmt("abc") == "abc"
|
||||||
|
|
||||||
|
def test_safe_fmt_unsupported_type_returns_str_representation():
|
||||||
|
# float([]) raises TypeError, should return "[]"
|
||||||
|
assert _safe_fmt([]) == "[]"
|
||||||
|
|
||||||
|
def test_safe_fmt_zero_returns_formatted_zero():
|
||||||
|
assert _safe_fmt(0) == "$0.00"
|
||||||
|
|
||||||
|
def test_safe_fmt_negative_number():
|
||||||
|
assert _safe_fmt(-1.23) == "$-1.23"
|
||||||
|
|
@ -61,18 +61,26 @@ def test_sync_performs_delete_then_add(mock_nlm_path):
|
||||||
# Check list call
|
# Check list call
|
||||||
args, kwargs = mock_run.call_args_list[0]
|
args, kwargs = mock_run.call_args_list[0]
|
||||||
assert "list" in args[0]
|
assert "list" in args[0]
|
||||||
|
assert "--json" in args[0]
|
||||||
|
assert "--" in args[0]
|
||||||
assert notebook_id in args[0]
|
assert notebook_id in args[0]
|
||||||
|
|
||||||
# Check delete call
|
# Check delete call
|
||||||
args, kwargs = mock_run.call_args_list[1]
|
args, kwargs = mock_run.call_args_list[1]
|
||||||
assert "delete" in args[0]
|
assert "delete" in args[0]
|
||||||
|
assert "-y" in args[0]
|
||||||
|
assert "--" in args[0]
|
||||||
|
assert notebook_id in args[0]
|
||||||
assert source_id in args[0]
|
assert source_id in args[0]
|
||||||
|
|
||||||
# Check add call
|
# Check add call
|
||||||
args, kwargs = mock_run.call_args_list[2]
|
args, kwargs = mock_run.call_args_list[2]
|
||||||
assert "add" in args[0]
|
assert "add" in args[0]
|
||||||
assert "--text" in args[0]
|
assert "--file" in args[0]
|
||||||
assert content in args[0]
|
assert str(digest_path) in args[0]
|
||||||
|
assert "--wait" in args[0]
|
||||||
|
assert "--" in args[0]
|
||||||
|
assert notebook_id in args[0]
|
||||||
|
|
||||||
def test_sync_adds_directly_when_none_exists(mock_nlm_path):
|
def test_sync_adds_directly_when_none_exists(mock_nlm_path):
|
||||||
"""Should add new source directly if no existing one is found."""
|
"""Should add new source directly if no existing one is found."""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,107 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from tradingagents.notebook_sync import sync_to_notebooklm
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_nlm_path(tmp_path):
|
||||||
|
nlm = tmp_path / "nlm"
|
||||||
|
nlm.touch(mode=0o755)
|
||||||
|
return str(nlm)
|
||||||
|
|
||||||
|
def test_security_argument_injection(mock_nlm_path, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that positional arguments starting with a hyphen are handled safely
|
||||||
|
and that content is passed via file to avoid ARG_MAX issues and injection.
|
||||||
|
"""
|
||||||
|
# Malicious notebook_id that looks like a flag
|
||||||
|
notebook_id = "--some-flag"
|
||||||
|
digest_path = tmp_path / "malicious.md"
|
||||||
|
digest_path.write_text("Some content")
|
||||||
|
date = "2026-03-19"
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"NOTEBOOKLM_ID": notebook_id}):
|
||||||
|
with patch("shutil.which", return_value=mock_nlm_path):
|
||||||
|
with patch("subprocess.run") as mock_run:
|
||||||
|
# Mock 'source list'
|
||||||
|
list_result = MagicMock()
|
||||||
|
list_result.returncode = 0
|
||||||
|
list_result.stdout = "[]"
|
||||||
|
|
||||||
|
# Mock 'source add'
|
||||||
|
add_result = MagicMock()
|
||||||
|
add_result.returncode = 0
|
||||||
|
|
||||||
|
mock_run.side_effect = [list_result, add_result]
|
||||||
|
|
||||||
|
sync_to_notebooklm(digest_path, date)
|
||||||
|
|
||||||
|
# 1. Check 'source list' call
|
||||||
|
# Expected: [nlm, "source", "list", "--json", "--", notebook_id]
|
||||||
|
list_args = mock_run.call_args_list[0][0][0]
|
||||||
|
assert list_args[0] == mock_nlm_path
|
||||||
|
assert list_args[1:3] == ["source", "list"]
|
||||||
|
assert "--json" in list_args
|
||||||
|
assert "--" in list_args
|
||||||
|
# "--" should be before the notebook_id
|
||||||
|
dash_idx = list_args.index("--")
|
||||||
|
id_idx = list_args.index(notebook_id)
|
||||||
|
assert dash_idx < id_idx
|
||||||
|
|
||||||
|
# 2. Check 'source add' call
|
||||||
|
# Expected: [nlm, "source", "add", "--title", title, "--file", str(digest_path), "--wait", "--", notebook_id]
|
||||||
|
add_args = mock_run.call_args_list[1][0][0]
|
||||||
|
assert add_args[0] == mock_nlm_path
|
||||||
|
assert add_args[1:3] == ["source", "add"]
|
||||||
|
assert "--title" in add_args
|
||||||
|
assert "--file" in add_args
|
||||||
|
assert str(digest_path) in add_args
|
||||||
|
assert "--text" not in add_args # Vulnerable --text should be gone
|
||||||
|
assert "--wait" in add_args
|
||||||
|
assert "--" in add_args
|
||||||
|
|
||||||
|
dash_idx = add_args.index("--")
|
||||||
|
id_idx = add_args.index(notebook_id)
|
||||||
|
assert dash_idx < id_idx
|
||||||
|
|
||||||
|
def test_security_delete_injection(mock_nlm_path):
|
||||||
|
"""Test that source_id in delete is also handled safely with --."""
|
||||||
|
notebook_id = "normal-id"
|
||||||
|
source_id = "--delete-everything"
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"NOTEBOOKLM_ID": notebook_id}):
|
||||||
|
with patch("shutil.which", return_value=mock_nlm_path):
|
||||||
|
with patch("subprocess.run") as mock_run:
|
||||||
|
# Mock 'source list' finding the malicious source_id
|
||||||
|
list_result = MagicMock()
|
||||||
|
list_result.returncode = 0
|
||||||
|
list_result.stdout = json.dumps([{"id": source_id, "title": "Daily Trading Digest (2026-03-19)"}])
|
||||||
|
|
||||||
|
# Mock 'source delete'
|
||||||
|
delete_result = MagicMock()
|
||||||
|
delete_result.returncode = 0
|
||||||
|
|
||||||
|
# Mock 'source add'
|
||||||
|
add_result = MagicMock()
|
||||||
|
add_result.returncode = 0
|
||||||
|
|
||||||
|
mock_run.side_effect = [list_result, delete_result, add_result]
|
||||||
|
|
||||||
|
sync_to_notebooklm(Path("test.md"), "2026-03-19")
|
||||||
|
|
||||||
|
# Check 'source delete' call
|
||||||
|
# Expected: [nlm, "source", "delete", "-y", "--", notebook_id, source_id]
|
||||||
|
delete_args = mock_run.call_args_list[1][0][0]
|
||||||
|
assert delete_args[1:3] == ["source", "delete"]
|
||||||
|
assert "-y" in delete_args
|
||||||
|
assert "--" in delete_args
|
||||||
|
|
||||||
|
dash_idx = delete_args.index("--")
|
||||||
|
id_idx = delete_args.index(notebook_id)
|
||||||
|
sid_idx = delete_args.index(source_id)
|
||||||
|
assert dash_idx < id_idx
|
||||||
|
assert dash_idx < sid_idx
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
from tradingagents.agents.utils.core_stock_tools import get_stock_data
|
from tradingagents.agents.utils.core_stock_tools import get_stock_data
|
||||||
from tradingagents.agents.utils.technical_indicators_tools import get_indicators
|
from tradingagents.agents.utils.technical_indicators_tools import get_indicators
|
||||||
from tradingagents.agents.utils.fundamental_data_tools import get_macro_regime
|
from tradingagents.agents.utils.fundamental_data_tools import get_macro_regime
|
||||||
|
|
|
||||||
|
|
@ -217,18 +217,10 @@ def _get_stock_stats_bulk(
|
||||||
df[indicator] # This triggers stockstats to calculate the indicator
|
df[indicator] # This triggers stockstats to calculate the indicator
|
||||||
|
|
||||||
# Create a dictionary mapping date strings to indicator values
|
# Create a dictionary mapping date strings to indicator values
|
||||||
result_dict = {}
|
# Optimized: vectorized operations for performance using correct DatetimeIndex
|
||||||
date_index_strs = df.index.strftime("%Y-%m-%d")
|
series = df[indicator].copy()
|
||||||
for date_str, (_, row) in zip(date_index_strs, df.iterrows()):
|
series.index = series.index.strftime("%Y-%m-%d")
|
||||||
indicator_value = row[indicator]
|
return series.fillna("N/A").astype(str).to_dict()
|
||||||
|
|
||||||
# Handle NaN/None values
|
|
||||||
if pd.isna(indicator_value):
|
|
||||||
result_dict[date_str] = "N/A"
|
|
||||||
else:
|
|
||||||
result_dict[date_str] = str(indicator_value)
|
|
||||||
|
|
||||||
return result_dict
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,6 @@ def sync_to_notebooklm(digest_path: Path, date: str, notebook_id: str | None = N
|
||||||
console.print("[yellow]Warning: nlm CLI not found — skipping NotebookLM sync[/yellow]")
|
console.print("[yellow]Warning: nlm CLI not found — skipping NotebookLM sync[/yellow]")
|
||||||
return
|
return
|
||||||
|
|
||||||
content = digest_path.read_text()
|
|
||||||
title = f"Daily Trading Digest ({date})"
|
title = f"Daily Trading Digest ({date})"
|
||||||
|
|
||||||
# Find and delete existing source with the same title
|
# Find and delete existing source with the same title
|
||||||
|
|
@ -60,14 +59,15 @@ def sync_to_notebooklm(digest_path: Path, date: str, notebook_id: str | None = N
|
||||||
_delete_source(nlm, notebook_id, existing_source_id)
|
_delete_source(nlm, notebook_id, existing_source_id)
|
||||||
|
|
||||||
# Add as a new source
|
# Add as a new source
|
||||||
_add_source(nlm, notebook_id, content, title)
|
_add_source(nlm, notebook_id, digest_path, title)
|
||||||
|
|
||||||
|
|
||||||
def _find_source(nlm: str, notebook_id: str, title: str) -> str | None:
|
def _find_source(nlm: str, notebook_id: str, title: str) -> str | None:
|
||||||
"""Return the source ID for the daily digest, or None if not found."""
|
"""Return the source ID for the daily digest, or None if not found."""
|
||||||
try:
|
try:
|
||||||
|
# Use -- to separate options from positional arguments
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
[nlm, "source", "list", notebook_id, "--json"],
|
[nlm, "source", "list", "--json", "--", notebook_id],
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
)
|
)
|
||||||
|
|
@ -85,8 +85,9 @@ def _find_source(nlm: str, notebook_id: str, title: str) -> str | None:
|
||||||
def _delete_source(nlm: str, notebook_id: str, source_id: str) -> None:
|
def _delete_source(nlm: str, notebook_id: str, source_id: str) -> None:
|
||||||
"""Delete an existing source."""
|
"""Delete an existing source."""
|
||||||
try:
|
try:
|
||||||
|
# Use -- to separate options from positional arguments
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
[nlm, "source", "delete", notebook_id, source_id, "-y"],
|
[nlm, "source", "delete", "-y", "--", notebook_id, source_id],
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
check=False, # Ignore non-zero exit since nlm sometimes fails even on success
|
check=False, # Ignore non-zero exit since nlm sometimes fails even on success
|
||||||
|
|
@ -95,11 +96,13 @@ def _delete_source(nlm: str, notebook_id: str, source_id: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _add_source(nlm: str, notebook_id: str, content: str, title: str) -> None:
|
def _add_source(nlm: str, notebook_id: str, digest_path: Path, title: str) -> None:
|
||||||
"""Add content as a new source."""
|
"""Add content as a new source."""
|
||||||
try:
|
try:
|
||||||
|
# Use --file instead of --text to avoid ARG_MAX issues and shell injection.
|
||||||
|
# Use -- to separate options from positional arguments.
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
[nlm, "source", "add", notebook_id, "--title", title, "--text", content, "--wait"],
|
[nlm, "source", "add", "--title", title, "--file", str(digest_path), "--wait", "--", notebook_id],
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue