From 80e8fb0eee78eb11d9ebb47bd29bf2e1bf2df50b Mon Sep 17 00:00:00 2001 From: Joseph O'Brien <98370624+89jobrien@users.noreply.github.com> Date: Wed, 3 Dec 2025 02:11:29 -0500 Subject: [PATCH] feat: add centralized logging module with dual output support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add tradingagents/logging.py with JSONFormatter for structured file logs and RichHandler for colored console output. Migrate 59 print statements across 13 files to proper logger calls. Configure via environment variables (TRADINGAGENTS_LOG_LEVEL, etc.) with default_config fallback. - Rotating file handler: 10MB max, 5 backups to logs/tradingagents.log - Rich console handler with tracebacks for development - Logger hierarchy follows module paths (tradingagents.dataflows.*, etc.) - 20 feature-specific tests covering core, config, migration, integration 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- tests/test_logging.py | 169 ++++++++++++++++++ tests/test_logging_config.py | 121 +++++++++++++ tests/test_logging_integration.py | 169 ++++++++++++++++++ tests/test_logging_migration.py | 124 +++++++++++++ tradingagents/agents/utils/memory.py | 52 +----- .../dataflows/alpha_vantage_common.py | 45 +---- .../dataflows/alpha_vantage_indicator.py | 32 +--- tradingagents/dataflows/alpha_vantage_news.py | 11 +- tradingagents/dataflows/brave.py | 23 +-- tradingagents/dataflows/googlenews_utils.py | 22 +-- tradingagents/dataflows/interface.py | 43 ++--- tradingagents/dataflows/local.py | 103 ++--------- tradingagents/dataflows/tavily.py | 17 +- tradingagents/dataflows/utils.py | 5 +- tradingagents/dataflows/y_finance.py | 133 +++++--------- tradingagents/dataflows/yfin_utils.py | 26 +-- tradingagents/default_config.py | 4 + tradingagents/graph/trading_graph.py | 5 +- tradingagents/logging.py | 121 +++++++++++++ 19 files changed, 857 insertions(+), 368 deletions(-) create mode 100644 tests/test_logging.py create mode 100644 tests/test_logging_config.py create mode 100644 tests/test_logging_integration.py create mode 100644 tests/test_logging_migration.py create mode 100644 tradingagents/logging.py diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 00000000..65e09867 --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,169 @@ +import json +import logging +import os +import tempfile +import pytest +from unittest.mock import patch + + +class TestLoggingModule: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + tradingagents_logger = logging.getLogger("tradingagents") + for handler in tradingagents_logger.handlers[:]: + tradingagents_logger.removeHandler(handler) + tradingagents_logger.setLevel(logging.NOTSET) + + yield + + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + tradingagents_logger = logging.getLogger("tradingagents") + for handler in tradingagents_logger.handlers[:]: + tradingagents_logger.removeHandler(handler) + + def test_setup_logging_initializes_handlers_based_on_env_vars(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "DEBUG", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + 
root_logger = log_module.setup_logging() + + assert root_logger is not None + assert root_logger.name == "tradingagents" + assert root_logger.level == logging.DEBUG + + has_file_handler = any( + hasattr(h, "baseFilename") for h in root_logger.handlers + ) + assert has_file_handler, "File handler should be present when LOG_FILE=true" + + def test_get_logger_returns_properly_configured_logger_with_hierarchy(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "INFO", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + log_module.setup_logging() + child_logger = log_module.get_logger("tradingagents.dataflows.interface") + + assert child_logger.name == "tradingagents.dataflows.interface" + assert child_logger.parent.name == "tradingagents.dataflows" or child_logger.parent.name == "tradingagents" + + def test_json_file_handler_writes_valid_json_with_required_fields(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "DEBUG", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + logger = log_module.setup_logging() + logger.info("Test message for JSON validation") + + for handler in logger.handlers: + handler.flush() + + log_file_path = os.path.join(tmpdir, "tradingagents.log") + assert os.path.exists(log_file_path), f"Log file should exist at {log_file_path}" + + with open(log_file_path, "r") as f: + log_content = f.read().strip() + + assert log_content, "Log file should not be empty" + + log_entry = json.loads(log_content.split("\n")[0]) + + required_fields = ["timestamp", "level", "logger", "message", "filename", "funcName", "lineno"] + for field in required_fields: + assert field in log_entry, f"JSON log should contain '{field}' field" + + assert "T" in log_entry["timestamp"], "Timestamp should be in ISO 8601 format" + assert log_entry["level"] == "INFO" + assert log_entry["message"] == "Test message for JSON validation" + + def test_log_rotation_triggers_at_configured_file_size(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "DEBUG", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + logger = log_module.setup_logging() + + file_handler = None + for handler in logger.handlers: + if hasattr(handler, "maxBytes"): + file_handler = handler + break + + assert file_handler is not None, "RotatingFileHandler should be configured" + assert file_handler.maxBytes == 10 * 1024 * 1024, "Max file size should be 10MB" + assert file_handler.backupCount == 5, "Backup count should be 5" + + def test_console_handler_disabled_when_env_var_false(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "INFO", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + 
import tradingagents.logging as log_module + importlib.reload(log_module) + + logger = log_module.setup_logging() + + from rich.logging import RichHandler + has_rich_handler = any(isinstance(h, RichHandler) for h in logger.handlers) + assert not has_rich_handler, "RichHandler should NOT be present when LOG_CONSOLE=false" + + def test_console_handler_enabled_when_env_var_true(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "INFO", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "true", + "TRADINGAGENTS_LOG_FILE": "false", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + logger = log_module.setup_logging() + + from rich.logging import RichHandler + has_rich_handler = any(isinstance(h, RichHandler) for h in logger.handlers) + assert has_rich_handler, "RichHandler should be present when LOG_CONSOLE=true" diff --git a/tests/test_logging_config.py b/tests/test_logging_config.py new file mode 100644 index 00000000..79c29e61 --- /dev/null +++ b/tests/test_logging_config.py @@ -0,0 +1,121 @@ +import logging +import os +import tempfile +import pytest +from unittest.mock import patch + + +class TestLoggingConfigIntegration: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + tradingagents_logger = logging.getLogger("tradingagents") + for handler in tradingagents_logger.handlers[:]: + tradingagents_logger.removeHandler(handler) + tradingagents_logger.setLevel(logging.NOTSET) + + yield + + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + tradingagents_logger = logging.getLogger("tradingagents") + for handler in tradingagents_logger.handlers[:]: + tradingagents_logger.removeHandler(handler) + + def test_default_config_values_used_when_env_vars_not_set(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars_to_remove = [ + "TRADINGAGENTS_LOG_LEVEL", + "TRADINGAGENTS_LOG_DIR", + "TRADINGAGENTS_LOG_CONSOLE", + "TRADINGAGENTS_LOG_FILE", + ] + clean_env = {k: v for k, v in os.environ.items() if k not in env_vars_to_remove} + + with patch.dict(os.environ, clean_env, clear=True): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + from tradingagents.default_config import DEFAULT_CONFIG + + expected_level = getattr(logging, DEFAULT_CONFIG.get("log_level", "INFO").upper()) + + logger = log_module.setup_logging() + + assert logger.level == expected_level + + def test_env_vars_override_default_config_values(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "WARNING", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + logger = log_module.setup_logging() + + assert logger.level == logging.WARNING + + def test_boolean_parsing_for_log_console_and_file(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_cases = [ + ("true", True), + ("false", False), + ("1", True), + ("0", False), + ("True", True), + ("False", False), + ("TRUE", True), + ("FALSE", False), + ("yes", True), + ("no", False), + ] + + for bool_str, expected in test_cases: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "INFO", + 
"TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": bool_str, + "TRADINGAGENTS_LOG_FILE": "false", + } + + tradingagents_logger = logging.getLogger("tradingagents") + for handler in tradingagents_logger.handlers[:]: + tradingagents_logger.removeHandler(handler) + + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + logger = log_module.setup_logging() + + from rich.logging import RichHandler + has_rich_handler = any(isinstance(h, RichHandler) for h in logger.handlers) + + assert has_rich_handler == expected, f"TRADINGAGENTS_LOG_CONSOLE={bool_str} should result in RichHandler present={expected}" + + def test_invalid_log_level_falls_back_to_info(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "INVALID_LEVEL", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + logger = log_module.setup_logging() + + assert logger.level == logging.INFO, "Invalid log level should fall back to INFO" diff --git a/tests/test_logging_integration.py b/tests/test_logging_integration.py new file mode 100644 index 00000000..47ea66d8 --- /dev/null +++ b/tests/test_logging_integration.py @@ -0,0 +1,169 @@ +import logging +import os +import tempfile +import pytest +from unittest.mock import patch + + +class TestLoggingIntegration: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + tradingagents_logger = logging.getLogger("tradingagents") + for handler in tradingagents_logger.handlers[:]: + tradingagents_logger.removeHandler(handler) + tradingagents_logger.setLevel(logging.NOTSET) + + yield + + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + tradingagents_logger = logging.getLogger("tradingagents") + for handler in tradingagents_logger.handlers[:]: + tradingagents_logger.removeHandler(handler) + + def test_logging_initialization_from_module_import(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "INFO", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + log_module.setup_logging() + + interface_logger = log_module.get_logger("tradingagents.dataflows.interface") + + assert interface_logger is not None + assert interface_logger.name == "tradingagents.dataflows.interface" + + interface_logger.info("Test message from interface logger") + + log_file = os.path.join(tmpdir, "tradingagents.log") + assert os.path.exists(log_file) + + with open(log_file, "r") as f: + content = f.read() + assert "Test message from interface logger" in content + + def test_rich_handler_does_not_break_cli_live_displays(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "INFO", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "true", + "TRADINGAGENTS_LOG_FILE": "false", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + 
logger = log_module.setup_logging() + + from rich.logging import RichHandler + rich_handlers = [h for h in logger.handlers if isinstance(h, RichHandler)] + assert len(rich_handlers) == 1 + + rich_handler = rich_handlers[0] + assert rich_handler.console is not None + assert rich_handler.console.file is not None + + def test_log_file_creation_and_format(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "DEBUG", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import json + import tradingagents.logging as log_module + importlib.reload(log_module) + + logger = log_module.setup_logging() + + logger.debug("Debug message") + logger.info("Info message") + logger.warning("Warning message") + logger.error("Error message") + + for handler in logger.handlers: + handler.flush() + + log_file = os.path.join(tmpdir, "tradingagents.log") + assert os.path.exists(log_file) + + with open(log_file, "r") as f: + lines = f.readlines() + + assert len(lines) >= 4 + + for line in lines: + log_entry = json.loads(line) + assert "timestamp" in log_entry + assert "level" in log_entry + assert "logger" in log_entry + assert "message" in log_entry + + def test_logger_hierarchy_inherits_parent_configuration(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "WARNING", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + root_logger = log_module.setup_logging() + + child_logger = log_module.get_logger("tradingagents.dataflows.interface") + grandchild_logger = log_module.get_logger("tradingagents.dataflows.interface.submodule") + + assert root_logger.level == logging.WARNING + + child_logger.info("This should not be logged") + child_logger.warning("This should be logged") + + for handler in root_logger.handlers: + handler.flush() + + log_file = os.path.join(tmpdir, "tradingagents.log") + with open(log_file, "r") as f: + content = f.read() + + assert "This should not be logged" not in content + assert "This should be logged" in content + + def test_lazy_initialization_pattern(self): + with tempfile.TemporaryDirectory() as tmpdir: + env_vars = { + "TRADINGAGENTS_LOG_LEVEL": "INFO", + "TRADINGAGENTS_LOG_DIR": tmpdir, + "TRADINGAGENTS_LOG_CONSOLE": "false", + "TRADINGAGENTS_LOG_FILE": "true", + } + with patch.dict(os.environ, env_vars, clear=False): + import importlib + import tradingagents.logging as log_module + importlib.reload(log_module) + + log_module._logging_initialized = False + + logger = log_module.get_logger("tradingagents.test") + + assert log_module._logging_initialized is True + assert logger is not None diff --git a/tests/test_logging_migration.py b/tests/test_logging_migration.py new file mode 100644 index 00000000..bbcf17d4 --- /dev/null +++ b/tests/test_logging_migration.py @@ -0,0 +1,124 @@ +import ast +import os +import pytest + + +class TestLoggingMigration: + def test_no_print_statements_in_interface_py(self): + file_path = os.path.join( + os.path.dirname(__file__), + "..", + "tradingagents", + "dataflows", + "interface.py", + ) + with open(file_path, "r") as f: + content = f.read() + + tree = ast.parse(content) + + print_calls = [] + for node in 
ast.walk(tree): + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name) and node.func.id == "print": + print_calls.append(node.lineno) + + assert len(print_calls) == 0, f"Found print statements at lines: {print_calls}" + + def test_no_print_statements_in_brave_py(self): + file_path = os.path.join( + os.path.dirname(__file__), + "..", + "tradingagents", + "dataflows", + "brave.py", + ) + with open(file_path, "r") as f: + content = f.read() + + tree = ast.parse(content) + + print_calls = [] + for node in ast.walk(tree): + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name) and node.func.id == "print": + print_calls.append(node.lineno) + + assert len(print_calls) == 0, f"Found print statements at lines: {print_calls}" + + def test_no_print_statements_in_tavily_py(self): + file_path = os.path.join( + os.path.dirname(__file__), + "..", + "tradingagents", + "dataflows", + "tavily.py", + ) + with open(file_path, "r") as f: + content = f.read() + + tree = ast.parse(content) + + print_calls = [] + for node in ast.walk(tree): + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name) and node.func.id == "print": + print_calls.append(node.lineno) + + assert len(print_calls) == 0, f"Found print statements at lines: {print_calls}" + + def test_no_print_statements_in_migrated_dataflow_files(self): + dataflow_files = [ + "alpha_vantage_news.py", + "y_finance.py", + "local.py", + "yfin_utils.py", + "googlenews_utils.py", + "utils.py", + "alpha_vantage_common.py", + "alpha_vantage_indicator.py", + ] + + dataflows_dir = os.path.join( + os.path.dirname(__file__), + "..", + "tradingagents", + "dataflows", + ) + + all_print_calls = {} + + for filename in dataflow_files: + file_path = os.path.join(dataflows_dir, filename) + if not os.path.exists(file_path): + continue + + with open(file_path, "r") as f: + content = f.read() + + tree = ast.parse(content) + + print_calls = [] + for node in ast.walk(tree): + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name) and node.func.id == "print": + print_calls.append(node.lineno) + + if print_calls: + all_print_calls[filename] = print_calls + + assert len(all_print_calls) == 0, f"Found print statements in: {all_print_calls}" + + def test_logger_import_exists_in_interface_py(self): + file_path = os.path.join( + os.path.dirname(__file__), + "..", + "tradingagents", + "dataflows", + "interface.py", + ) + with open(file_path, "r") as f: + content = f.read() + + assert "import logging" in content, "interface.py should import logging" + assert "logger = logging.getLogger(__name__)" in content, "interface.py should define logger" diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 9a410183..892a2109 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,7 +1,10 @@ +import logging import chromadb from chromadb.config import Settings from openai import OpenAI +logger = logging.getLogger(__name__) + class FinancialSituationMemory: def __init__(self, name, config): @@ -14,15 +17,13 @@ class FinancialSituationMemory: self.situation_collection = self.chroma_client.get_or_create_collection(name=name) def get_embedding(self, text): - """Get OpenAI embedding for a text""" - + response = self.client.embeddings.create( model=self.embedding, input=text ) return response.data[0].embedding def add_situations(self, situations_and_advice): - """Add financial situations and their corresponding advice. 
Parameter is a list of tuples (situation, rec)""" situations = [] advice = [] @@ -45,7 +46,6 @@ class FinancialSituationMemory: ) def get_memories(self, current_situation, n_matches=1): - """Find matching recommendations using OpenAI embeddings""" query_embedding = self.get_embedding(current_situation) results = self.situation_collection.query( @@ -65,47 +65,3 @@ class FinancialSituationMemory: ) return matched_results - - -# if __name__ == "__main__": -# # Example usage -# matcher = FinancialSituationMemory() -# example_data = [ -# ( -# "High inflation rate with rising interest rates and declining consumer spending", -# "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.", -# ), -# ( -# "Tech sector showing high volatility with increasing institutional selling pressure", -# "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.", -# ), -# ( -# "Strong dollar affecting emerging markets with increasing forex volatility", -# "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.", -# ), -# ( -# "Market showing signs of sector rotation with rising yields", -# "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.", -# ), -# ] - -# # Add the example situations and recommendations -# matcher.add_situations(example_data) - -# # Example query -# current_situation = """ -# Market showing increased volatility in tech sector, with institutional investors -# reducing positions and rising interest rates affecting growth stock valuations -# """ - -# try: -# recommendations = matcher.get_memories(current_situation, n_matches=2) - -# for i, rec in enumerate(recommendations, 1): -# print(f"\nMatch {i}:") -# print(f"Similarity Score: {rec['similarity_score']:.2f}") -# print(f"Matched Situation: {rec['matched_situation']}") -# print(f"Recommendation: {rec['recommendation']}") - -# except Exception as e: -# print(f"Error during recommendation: {str(e)}") diff --git a/tradingagents/dataflows/alpha_vantage_common.py b/tradingagents/dataflows/alpha_vantage_common.py index 409ff29e..44585b01 100644 --- a/tradingagents/dataflows/alpha_vantage_common.py +++ b/tradingagents/dataflows/alpha_vantage_common.py @@ -1,3 +1,4 @@ +import logging import os import requests import pandas as pd @@ -5,22 +6,20 @@ import json from datetime import datetime from io import StringIO +logger = logging.getLogger(__name__) + API_BASE_URL = "https://www.alphavantage.co/query" def get_api_key() -> str: - """Retrieve the API key for Alpha Vantage from environment variables.""" api_key = os.getenv("ALPHA_VANTAGE_API_KEY") if not api_key: raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.") return api_key def format_datetime_for_api(date_input) -> str: - """Convert various date formats to YYYYMMDDTHHMM format required by Alpha Vantage API.""" if isinstance(date_input, str): - # If already in correct format, return as-is if len(date_input) == 13 and 'T' in date_input: return date_input - # Try to parse common date formats try: dt = datetime.strptime(date_input, "%Y-%m-%d") return dt.strftime("%Y%m%dT0000") @@ -36,48 +35,36 @@ def format_datetime_for_api(date_input) -> str: raise ValueError(f"Date must be string or datetime object, got {type(date_input)}") class AlphaVantageRateLimitError(Exception): - """Exception raised when Alpha Vantage API rate limit is exceeded.""" pass 
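Raising AlphaVantageRateLimitError, rather than returning an error string, is what lets route_to_vendor in interface.py treat rate limiting as a fall-through to the next vendor instead of a hard failure. A minimal caller sketch, assuming only the names defined in this module (fetch_news_or_none and the NEWS_SENTIMENT function name are illustrative):

import logging

from tradingagents.dataflows.alpha_vantage_common import (
    AlphaVantageRateLimitError,
    _make_api_request,
)

logger = logging.getLogger(__name__)


def fetch_news_or_none(ticker: str) -> str | dict | None:
    # Illustrative only: mirrors how route_to_vendor reacts to a rate
    # limit -- log a warning and let the caller try the next vendor.
    try:
        return _make_api_request("NEWS_SENTIMENT", {"tickers": ticker})
    except AlphaVantageRateLimitError as e:
        logger.warning("Alpha Vantage rate limit exceeded: %s", e)
        return None
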
def _make_api_request(function_name: str, params: dict) -> dict | str: - """Helper function to make API requests and handle responses. - - Raises: - AlphaVantageRateLimitError: When API rate limit is exceeded - """ - # Create a copy of params to avoid modifying the original api_params = params.copy() api_params.update({ "function": function_name, "apikey": get_api_key(), "source": "trading_agents", }) - - # Handle entitlement parameter if present in params or global variable + current_entitlement = globals().get('_current_entitlement') entitlement = api_params.get("entitlement") or current_entitlement - + if entitlement: api_params["entitlement"] = entitlement elif "entitlement" in api_params: - # Remove entitlement if it's None or empty api_params.pop("entitlement", None) - + response = requests.get(API_BASE_URL, params=api_params) response.raise_for_status() response_text = response.text - - # Check if response is JSON (error responses are typically JSON) + try: response_json = json.loads(response_text) - # Check for rate limit error if "Information" in response_json: info_message = response_json["Information"] if "rate limit" in info_message.lower() or "api key" in info_message.lower(): raise AlphaVantageRateLimitError(f"Alpha Vantage rate limit exceeded: {info_message}") except json.JSONDecodeError: - # Response is not JSON (likely CSV data), which is normal pass return response_text @@ -85,38 +72,22 @@ def _make_api_request(function_name: str, params: dict) -> dict | str: def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> str: - """ - Filter CSV data to include only rows within the specified date range. - - Args: - csv_data: CSV string from Alpha Vantage API - start_date: Start date in yyyy-mm-dd format - end_date: End date in yyyy-mm-dd format - - Returns: - Filtered CSV string - """ if not csv_data or csv_data.strip() == "": return csv_data try: - # Parse CSV data df = pd.read_csv(StringIO(csv_data)) - # Assume the first column is the date column (timestamp) date_col = df.columns[0] df[date_col] = pd.to_datetime(df[date_col]) - # Filter by date range start_dt = pd.to_datetime(start_date) end_dt = pd.to_datetime(end_date) filtered_df = df[(df[date_col] >= start_dt) & (df[date_col] <= end_dt)] - # Convert back to CSV string return filtered_df.to_csv(index=False) except Exception as e: - # If filtering fails, return original data with a warning - print(f"Warning: Failed to filter CSV data by date range: {e}") + logger.warning("Failed to filter CSV data by date range: %s", e) return csv_data diff --git a/tradingagents/dataflows/alpha_vantage_indicator.py b/tradingagents/dataflows/alpha_vantage_indicator.py index 6225b9bb..913cc96c 100644 --- a/tradingagents/dataflows/alpha_vantage_indicator.py +++ b/tradingagents/dataflows/alpha_vantage_indicator.py @@ -1,5 +1,8 @@ +import logging from .alpha_vantage_common import _make_api_request +logger = logging.getLogger(__name__) + def get_indicator( symbol: str, indicator: str, @@ -9,21 +12,6 @@ def get_indicator( time_period: int = 14, series_type: str = "close" ) -> str: - """ - Returns Alpha Vantage technical indicator values over a time window. 
- - Args: - symbol: ticker symbol of the company - indicator: technical indicator to get the analysis and report of - curr_date: The current trading date you are trading on, YYYY-mm-dd - look_back_days: how many days to look back - interval: Time interval (daily, weekly, monthly) - time_period: Number of data points for calculation - series_type: The desired price type (close, open, high, low) - - Returns: - String containing indicator values and description - """ from datetime import datetime from dateutil.relativedelta import relativedelta @@ -65,15 +53,12 @@ def get_indicator( curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") before = curr_date_dt - relativedelta(days=look_back_days) - # Get the full data for the period instead of making individual calls _, required_series_type = supported_indicators[indicator] - # Use the provided series_type or fall back to the required one if required_series_type: series_type = required_series_type try: - # Get indicator data for the period if indicator == "close_50_sma": data = _make_api_request("SMA", { "symbol": symbol, @@ -143,25 +128,20 @@ def get_indicator( "datatype": "csv" }) elif indicator == "vwma": - # Alpha Vantage doesn't have direct VWMA, so we'll return an informative message - # In a real implementation, this would need to be calculated from OHLCV data return f"## VWMA (Volume Weighted Moving Average) for {symbol}:\n\nVWMA calculation requires OHLCV data and is not directly available from Alpha Vantage API.\nThis indicator would need to be calculated from the raw stock data using volume-weighted price averaging.\n\n{indicator_descriptions.get('vwma', 'No description available.')}" else: return f"Error: Indicator {indicator} not implemented yet." - # Parse CSV data and extract values for the date range lines = data.strip().split('\n') if len(lines) < 2: return f"Error: No data returned for {indicator}" - # Parse header and data header = [col.strip() for col in lines[0].split(',')] try: date_col_idx = header.index('time') except ValueError: return f"Error: 'time' column not found in data for {indicator}. 
Available columns: {header}" - # Map internal indicator names to expected CSV column names from Alpha Vantage col_name_map = { "macd": "MACD", "macds": "MACD_Signal", "macdh": "MACD_Hist", "boll": "Real Middle Band", "boll_ub": "Real Upper Band", "boll_lb": "Real Lower Band", @@ -172,7 +152,6 @@ def get_indicator( target_col_name = col_name_map.get(indicator) if not target_col_name: - # Default to the second column if no specific mapping exists value_col_idx = 1 else: try: @@ -188,17 +167,14 @@ def get_indicator( if len(values) > value_col_idx: try: date_str = values[date_col_idx].strip() - # Parse the date date_dt = datetime.strptime(date_str, "%Y-%m-%d") - # Check if date is in our range if before <= date_dt <= curr_date_dt: value = values[value_col_idx].strip() result_data.append((date_dt, value)) except (ValueError, IndexError): continue - # Sort by date and format output result_data.sort(key=lambda x: x[0]) ind_string = "" @@ -218,5 +194,5 @@ def get_indicator( return result_str except Exception as e: - print(f"Error getting Alpha Vantage indicator data for {indicator}: {e}") + logger.error("Error getting Alpha Vantage indicator data for %s: %s", indicator, e) return f"Error retrieving {indicator} data: {str(e)}" diff --git a/tradingagents/dataflows/alpha_vantage_news.py b/tradingagents/dataflows/alpha_vantage_news.py index 8968f06d..ee941c3e 100644 --- a/tradingagents/dataflows/alpha_vantage_news.py +++ b/tradingagents/dataflows/alpha_vantage_news.py @@ -1,8 +1,11 @@ import json +import logging from datetime import datetime, timedelta from typing import List, Dict, Any from .alpha_vantage_common import _make_api_request, format_datetime_for_api +logger = logging.getLogger(__name__) + def get_news(ticker, start_date, end_date) -> dict[str, str] | str: params = { "tickers": ticker, @@ -40,19 +43,19 @@ def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]: try: response = json.loads(response) except json.JSONDecodeError: - print(f"DEBUG: Alpha Vantage JSON decode failed") + logger.debug("Alpha Vantage JSON decode failed") return [] if not isinstance(response, dict): - print(f"DEBUG: Alpha Vantage response not a dict: {type(response)}") + logger.debug("Alpha Vantage response not a dict: %s", type(response)) return [] if "Information" in response: - print(f"DEBUG: Alpha Vantage info message: {response.get('Information')}") + logger.debug("Alpha Vantage info message: %s", response.get("Information")) feed = response.get("feed", []) if not feed: - print(f"DEBUG: Alpha Vantage feed empty. Keys in response: {list(response.keys())}") + logger.debug("Alpha Vantage feed empty. 
Keys in response: %s", list(response.keys())) articles = [] for item in feed: diff --git a/tradingagents/dataflows/brave.py b/tradingagents/dataflows/brave.py index b355f9a9..811f3d72 100644 --- a/tradingagents/dataflows/brave.py +++ b/tradingagents/dataflows/brave.py @@ -1,9 +1,12 @@ +import logging import os import time import requests from datetime import datetime, timedelta from typing import List, Dict, Any +logger = logging.getLogger(__name__) + BRAVE_SEARCH_URL = "https://api.search.brave.com/res/v1/news/search" DEFAULT_TIMEOUT = 30 MAX_RETRIES = 3 @@ -24,23 +27,23 @@ def _make_request_with_retry(url: str, headers: Dict, params: Dict, max_retries: response = requests.get(url, headers=headers, params=params, timeout=DEFAULT_TIMEOUT) if response.status_code == 429: retry_after = int(response.headers.get("Retry-After", RETRY_BACKOFF * (attempt + 1))) - print(f"DEBUG: Brave rate limited, waiting {retry_after}s before retry {attempt + 1}/{max_retries}") + logger.debug("Brave rate limited, waiting %ds before retry %d/%d", retry_after, attempt + 1, max_retries) time.sleep(retry_after) continue response.raise_for_status() return response except requests.exceptions.Timeout as e: last_exception = e - print(f"DEBUG: Brave request timeout, retry {attempt + 1}/{max_retries}") + logger.debug("Brave request timeout, retry %d/%d", attempt + 1, max_retries) time.sleep(RETRY_BACKOFF * (attempt + 1)) except requests.exceptions.ConnectionError as e: last_exception = e - print(f"DEBUG: Brave connection error, retry {attempt + 1}/{max_retries}") + logger.debug("Brave connection error, retry %d/%d", attempt + 1, max_retries) time.sleep(RETRY_BACKOFF * (attempt + 1)) except requests.exceptions.HTTPError as e: if e.response is not None and e.response.status_code >= 500: last_exception = e - print(f"DEBUG: Brave server error {e.response.status_code}, retry {attempt + 1}/{max_retries}") + logger.debug("Brave server error %d, retry %d/%d", e.response.status_code, attempt + 1, max_retries) time.sleep(RETRY_BACKOFF * (attempt + 1)) else: raise @@ -51,7 +54,7 @@ def get_bulk_news_brave(lookback_hours: int) -> List[Dict[str, Any]]: try: api_key = get_api_key() except ValueError as e: - print(f"DEBUG: Brave API key not configured: {e}") + logger.debug("Brave API key not configured: %s", e) return [] headers = { @@ -109,19 +112,19 @@ def get_bulk_news_brave(lookback_hours: int) -> List[Dict[str, Any]]: all_articles.append(article) except requests.exceptions.HTTPError as e: - print(f"DEBUG: Brave search HTTP error for '{query}': {e}") + logger.debug("Brave search HTTP error for '%s': %s", query, e) continue except requests.exceptions.Timeout as e: - print(f"DEBUG: Brave search timeout for '{query}': {e}") + logger.debug("Brave search timeout for '%s': %s", query, e) continue except requests.exceptions.RequestException as e: - print(f"DEBUG: Brave search request failed for '{query}': {e}") + logger.debug("Brave search request failed for '%s': %s", query, e) continue except Exception as e: - print(f"DEBUG: Brave search failed for query '{query}': {e}") + logger.debug("Brave search failed for query '%s': %s", query, e) continue - print(f"DEBUG: Brave returned {len(all_articles)} articles") + logger.debug("Brave returned %d articles", len(all_articles)) return all_articles diff --git a/tradingagents/dataflows/googlenews_utils.py b/tradingagents/dataflows/googlenews_utils.py index bdc6124d..c108aa3c 100644 --- a/tradingagents/dataflows/googlenews_utils.py +++ b/tradingagents/dataflows/googlenews_utils.py @@ -1,3 +1,4 @@ 
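The change below repeats the migration pattern used across the dataflow modules: a module-level logger named after the module path, and lazy %-style arguments in place of f-string print() calls, so messages are only formatted when the level is enabled. In isolation the pattern looks like this (extract_title is illustrative):

import logging

# __name__ resolves to e.g. "tradingagents.dataflows.googlenews_utils",
# so handlers configured on the "tradingagents" root logger apply.
logger = logging.getLogger(__name__)


def extract_title(item: dict) -> str | None:
    try:
        return item["title"]
    except KeyError as e:
        # Lazy formatting: the message is built only if DEBUG is enabled,
        # unlike the old print(f"...") calls, which always formatted.
        logger.debug("Error processing result: %s", e)
        return None
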
+import logging import json import requests from bs4 import BeautifulSoup @@ -12,9 +13,10 @@ from tenacity import ( retry_if_result, ) +logger = logging.getLogger(__name__) + def is_rate_limited(response): - """Check if the response indicates rate limiting (status code 429)""" return response.status_code == 429 @@ -24,20 +26,12 @@ def is_rate_limited(response): stop=stop_after_attempt(5), ) def make_request(url, headers): - """Make a request with retry logic for rate limiting""" - # Random delay before each request to avoid detection time.sleep(random.uniform(2, 6)) response = requests.get(url, headers=headers) return response def getNewsData(query, start_date, end_date): - """ - Scrape Google News search results for a given query and date range. - query: str - search query - start_date: str - start date in the format yyyy-mm-dd or mm/dd/yyyy - end_date: str - end date in the format yyyy-mm-dd or mm/dd/yyyy - """ if "-" in start_date: start_date = datetime.strptime(start_date, "%Y-%m-%d") start_date = start_date.strftime("%m/%d/%Y") @@ -69,7 +63,7 @@ def getNewsData(query, start_date, end_date): results_on_page = soup.select("div.SoaBEf") if not results_on_page: - break # No more results found + break for el in results_on_page: try: @@ -88,13 +82,9 @@ def getNewsData(query, start_date, end_date): } ) except Exception as e: - print(f"Error processing result: {e}") - # If one of the fields is not found, skip this result + logger.debug("Error processing result: %s", e) continue - # Update the progress bar with the current count of results scraped - - # Check for the "Next" link (pagination) next_link = soup.find("a", id="pnnext") if not next_link: break @@ -102,7 +92,7 @@ def getNewsData(query, start_date, end_date): page += 1 except Exception as e: - print(f"Failed after multiple retries: {e}") + logger.debug("Failed after multiple retries: %s", e) break return news_results diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index eda7a7e4..58576406 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -1,3 +1,4 @@ +import logging from typing import Annotated, List, Dict, Any, Optional from datetime import datetime, timedelta import threading @@ -25,6 +26,8 @@ from .config import get_config from tradingagents.agents.discovery import NewsArticle +logger = logging.getLogger(__name__) + TOOLS_CATEGORIES = { "core_stock_apis": { "description": "OHLCV stock price data", @@ -206,17 +209,17 @@ def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]: vendor_func = VENDOR_METHODS["get_bulk_news"][vendor] try: - print(f"DEBUG: Attempting bulk news from vendor '{vendor}'...") + logger.debug("Attempting bulk news from vendor '%s'...", vendor) result = vendor_func(lookback_hours) if result: - print(f"SUCCESS: Got {len(result)} articles from vendor '{vendor}'") + logger.info("Got %d articles from vendor '%s'", len(result), vendor) return result - print(f"DEBUG: Vendor '{vendor}' returned empty results, trying next...") + logger.debug("Vendor '%s' returned empty results, trying next...", vendor) except AlphaVantageRateLimitError as e: - print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded: {e}") + logger.warning("Alpha Vantage rate limit exceeded: %s", e) continue except Exception as e: - print(f"FAILED: Vendor '{vendor}' failed: {e}") + logger.error("Vendor '%s' failed: %s", vendor, e) continue return [] @@ -225,7 +228,7 @@ def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]: 
def get_bulk_news(lookback_period: str = "24h") -> List[NewsArticle]: cached = _get_cached_bulk_news(lookback_period) if cached is not None: - print(f"DEBUG: Returning cached bulk news for period '{lookback_period}'") + logger.debug("Returning cached bulk news for period '%s'", lookback_period) return cached raw_articles = _fetch_bulk_news_from_vendor(lookback_period) @@ -271,7 +274,7 @@ def route_to_vendor(method: str, *args, **kwargs): primary_str = " -> ".join(primary_vendors) fallback_str = " -> ".join(fallback_vendors) - print(f"DEBUG: {method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]") + logger.debug("%s - Primary: [%s] | Full fallback order: [%s]", method, primary_str, fallback_str) results = [] vendor_attempt_count = 0 @@ -281,7 +284,7 @@ def route_to_vendor(method: str, *args, **kwargs): for vendor in fallback_vendors: if vendor not in VENDOR_METHODS[method]: if vendor in primary_vendors: - print(f"INFO: Vendor '{vendor}' not supported for method '{method}', falling back to next vendor") + logger.info("Vendor '%s' not supported for method '%s', falling back to next vendor", vendor, method) continue vendor_impl = VENDOR_METHODS[method][vendor] @@ -292,48 +295,48 @@ def route_to_vendor(method: str, *args, **kwargs): any_primary_vendor_attempted = True vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK" - print(f"DEBUG: Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})") + logger.debug("Attempting %s vendor '%s' for %s (attempt #%d)", vendor_type, vendor, method, vendor_attempt_count) if isinstance(vendor_impl, list): vendor_methods = [(impl, vendor) for impl in vendor_impl] - print(f"DEBUG: Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions") + logger.debug("Vendor '%s' has multiple implementations: %d functions", vendor, len(vendor_methods)) else: vendor_methods = [(vendor_impl, vendor)] vendor_results = [] for impl_func, vendor_name in vendor_methods: try: - print(f"DEBUG: Calling {impl_func.__name__} from vendor '{vendor_name}'...") + logger.debug("Calling %s from vendor '%s'...", impl_func.__name__, vendor_name) result = impl_func(*args, **kwargs) vendor_results.append(result) - print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully") + logger.info("%s from vendor '%s' completed successfully", impl_func.__name__, vendor_name) except AlphaVantageRateLimitError as e: if vendor == "alpha_vantage": - print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor") - print(f"DEBUG: Rate limit details: {e}") + logger.warning("Alpha Vantage rate limit exceeded, falling back to next available vendor") + logger.debug("Rate limit details: %s", e) continue except Exception as e: - print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}") + logger.error("%s from vendor '%s' failed: %s", impl_func.__name__, vendor_name, e) continue if vendor_results: results.extend(vendor_results) successful_vendor = vendor result_summary = f"Got {len(vendor_results)} result(s)" - print(f"SUCCESS: Vendor '{vendor}' succeeded - {result_summary}") + logger.info("Vendor '%s' succeeded - %s", vendor, result_summary) if len(primary_vendors) == 1: - print(f"DEBUG: Stopping after successful vendor '{vendor}' (single-vendor config)") + logger.debug("Stopping after successful vendor '%s' (single-vendor config)", vendor) break else: - print(f"FAILED: Vendor '{vendor}' produced no results") + logger.error("Vendor '%s' produced no 
results", vendor) if not results: - print(f"FAILURE: All {vendor_attempt_count} vendor attempts failed for method '{method}'") + logger.error("All %d vendor attempts failed for method '%s'", vendor_attempt_count, method) raise RuntimeError(f"All vendor implementations failed for method '{method}'") else: - print(f"FINAL: Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)") + logger.info("Method '%s' completed with %d result(s) from %d vendor attempt(s)", method, len(results), vendor_attempt_count) if len(results) == 1: return results[0] diff --git a/tradingagents/dataflows/local.py b/tradingagents/dataflows/local.py index 502bc43a..1c126ee2 100644 --- a/tradingagents/dataflows/local.py +++ b/tradingagents/dataflows/local.py @@ -1,3 +1,4 @@ +import logging from typing import Annotated import pandas as pd import os @@ -8,17 +9,17 @@ import json from .reddit_utils import fetch_top_from_category from tqdm import tqdm +logger = logging.getLogger(__name__) + def get_YFin_data_window( symbol: Annotated[str, "ticker symbol of the company"], curr_date: Annotated[str, "Start date in yyyy-mm-dd format"], look_back_days: Annotated[int, "how many days to look back"], ) -> str: - # calculate past days date_obj = datetime.strptime(curr_date, "%Y-%m-%d") before = date_obj - relativedelta(days=look_back_days) start_date = before.strftime("%Y-%m-%d") - # read in data data = pd.read_csv( os.path.join( DATA_DIR, @@ -26,18 +27,14 @@ def get_YFin_data_window( ) ) - # Extract just the date part for comparison data["DateOnly"] = data["Date"].str[:10] - # Filter data between the start and end dates (inclusive) filtered_data = data[ (data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date) ] - # Drop the temporary column we created filtered_data = filtered_data.drop("DateOnly", axis=1) - # Set pandas display options to show the full DataFrame with pd.option_context( "display.max_rows", None, "display.max_columns", None, "display.width", None ): @@ -53,7 +50,6 @@ def get_YFin_data( start_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"], ) -> str: - # read in data data = pd.read_csv( os.path.join( DATA_DIR, @@ -66,18 +62,14 @@ def get_YFin_data( f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25" ) - # Extract just the date part for comparison data["DateOnly"] = data["Date"].str[:10] - # Filter data between the start and end dates (inclusive) filtered_data = data[ (data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date) ] - # Drop the temporary column we created filtered_data = filtered_data.drop("DateOnly", axis=1) - # remove the index from the dataframe filtered_data = filtered_data.reset_index(drop=True) return filtered_data @@ -87,17 +79,6 @@ def get_finnhub_news( start_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"], ): - """ - Retrieve news about a company within a time frame - - Args - query (str): Search query or ticker symbol - start_date (str): Start date in yyyy-mm-dd format - end_date (str): End date in yyyy-mm-dd format - Returns - str: dataframe containing the news of the company in the time frame - - """ result = get_data_in_range(query, start_date, end_date, "news_data", DATA_DIR) @@ -121,17 +102,9 @@ def get_finnhub_company_insider_sentiment( ticker: Annotated[str, "ticker symbol for the company"], curr_date: Annotated[str, "current date you are trading at, 
yyyy-mm-dd"], ): - """ - Retrieve insider sentiment about a company (retrieved from public SEC information) for the past 15 days - Args: - ticker (str): ticker symbol of the company - curr_date (str): current date you are trading on, yyyy-mm-dd - Returns: - str: a report of the sentiment in the past 15 days starting at curr_date - """ date_obj = datetime.strptime(curr_date, "%Y-%m-%d") - before = date_obj - relativedelta(days=15) # Default 15 days lookback + before = date_obj - relativedelta(days=15) before = before.strftime("%Y-%m-%d") data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR) @@ -158,17 +131,9 @@ def get_finnhub_company_insider_transactions( ticker: Annotated[str, "ticker symbol"], curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], ): - """ - Retrieve insider transcaction information about a company (retrieved from public SEC information) for the past 15 days - Args: - ticker (str): ticker symbol of the company - curr_date (str): current date you are trading at, yyyy-mm-dd - Returns: - str: a report of the company's insider transaction/trading informtaion in the past 15 days - """ date_obj = datetime.strptime(curr_date, "%Y-%m-%d") - before = date_obj - relativedelta(days=15) # Default 15 days lookback + before = date_obj - relativedelta(days=15) before = before.strftime("%Y-%m-%d") data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR) @@ -192,15 +157,6 @@ def get_finnhub_company_insider_transactions( ) def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None): - """ - Gets finnhub data saved and processed on disk. - Args: - start_date (str): Start date in YYYY-MM-DD format. - end_date (str): End date in YYYY-MM-DD format. - data_type (str): Type of data from finnhub to fetch. Can be insider_trans, SEC_filings, news_data, insider_senti, or fin_as_reported. - data_dir (str): Directory where the data is saved. - period (str): Default to none, if there is a period specified, should be annual or quarterly. 
- """ if period: data_path = os.path.join( @@ -217,7 +173,6 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period= data = open(data_path, "r") data = json.load(data) - # filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD) filtered_data = {} for key, value in data.items(): if start_date <= key <= end_date and len(value) > 0: @@ -243,25 +198,19 @@ def get_simfin_balance_sheet( ) df = pd.read_csv(data_path, sep=";") - # Convert date strings to datetime objects and remove any time components df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize() df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize() - # Convert the current date to datetime and normalize curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize() - # Filter the DataFrame for the given ticker and for reports that were published on or before the current date filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)] - # Check if there are any available reports; if not, return a notification if filtered_df.empty: - print("No balance sheet available before the given current date.") + logger.info("No balance sheet available before the given current date.") return "" - # Get the most recent balance sheet by selecting the row with the latest Publish Date latest_balance_sheet = filtered_df.loc[filtered_df["Publish Date"].idxmax()] - # drop the SimFinID column latest_balance_sheet = latest_balance_sheet.drop("SimFinId") return ( @@ -290,25 +239,19 @@ def get_simfin_cashflow( ) df = pd.read_csv(data_path, sep=";") - # Convert date strings to datetime objects and remove any time components df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize() df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize() - # Convert the current date to datetime and normalize curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize() - # Filter the DataFrame for the given ticker and for reports that were published on or before the current date filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)] - # Check if there are any available reports; if not, return a notification if filtered_df.empty: - print("No cash flow statement available before the given current date.") + logger.info("No cash flow statement available before the given current date.") return "" - # Get the most recent cash flow statement by selecting the row with the latest Publish Date latest_cash_flow = filtered_df.loc[filtered_df["Publish Date"].idxmax()] - # drop the SimFinID column latest_cash_flow = latest_cash_flow.drop("SimFinId") return ( @@ -337,25 +280,19 @@ def get_simfin_income_statements( ) df = pd.read_csv(data_path, sep=";") - # Convert date strings to datetime objects and remove any time components df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize() df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize() - # Convert the current date to datetime and normalize curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize() - # Filter the DataFrame for the given ticker and for reports that were published on or before the current date filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)] - # Check if there are any available reports; if not, return a notification if filtered_df.empty: - print("No income statement available before the given current date.") + logger.info("No income 
statement available before the given current date.") return "" - # Get the most recent income statement by selecting the row with the latest Publish Date latest_income = filtered_df.loc[filtered_df["Publish Date"].idxmax()] - # drop the SimFinID column latest_income = latest_income.drop("SimFinId") return ( @@ -370,22 +307,12 @@ def get_reddit_global_news( look_back_days: Annotated[int, "Number of days to look back"] = 7, limit: Annotated[int, "Maximum number of articles to return"] = 5, ) -> str: - """ - Retrieve the latest top reddit news - Args: - curr_date: Current date in yyyy-mm-dd format - look_back_days: Number of days to look back (default 7) - limit: Maximum number of articles to return (default 5) - Returns: - str: A formatted string containing the latest news articles posts on reddit - """ curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") before = curr_date_dt - relativedelta(days=look_back_days) before = before.strftime("%Y-%m-%d") posts = [] - # iterate from before to curr_date curr_iter_date = datetime.strptime(before, "%Y-%m-%d") total_iterations = (curr_date_dt - curr_iter_date).days + 1 @@ -423,21 +350,11 @@ def get_reddit_company_news( start_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"], ) -> str: - """ - Retrieve the latest top reddit news - Args: - query: Search query or ticker symbol - start_date: Start date in yyyy-mm-dd format - end_date: End date in yyyy-mm-dd format - Returns: - str: A formatted string containing news articles posts on reddit - """ start_date_dt = datetime.strptime(start_date, "%Y-%m-%d") end_date_dt = datetime.strptime(end_date, "%Y-%m-%d") posts = [] - # iterate from start_date to end_date curr_date = start_date_dt total_iterations = (end_date_dt - curr_date).days + 1 @@ -451,7 +368,7 @@ def get_reddit_company_news( fetch_result = fetch_top_from_category( "company_news", curr_date_str, - 10, # max limit per day + 10, query, data_path=os.path.join(DATA_DIR, "reddit_data"), ) @@ -472,4 +389,4 @@ def get_reddit_company_news( else: news_str += f"### {post['title']}\n\n{post['content']}\n\n" - return f"##{query} News Reddit, from {start_date} to {end_date}:\n\n{news_str}" \ No newline at end of file + return f"##{query} News Reddit, from {start_date} to {end_date}:\n\n{news_str}" diff --git a/tradingagents/dataflows/tavily.py b/tradingagents/dataflows/tavily.py index 560202a0..4651c0ad 100644 --- a/tradingagents/dataflows/tavily.py +++ b/tradingagents/dataflows/tavily.py @@ -1,8 +1,11 @@ +import logging import os import time from datetime import datetime, timedelta from typing import List, Dict, Any +logger = logging.getLogger(__name__) + try: from tavily import TavilyClient TAVILY_AVAILABLE = True @@ -37,17 +40,17 @@ def _search_with_retry(client, query: str, search_depth: str, topic: str, time_r error_str = str(e).lower() if "rate" in error_str or "limit" in error_str or "429" in error_str: wait_time = RETRY_BACKOFF * (attempt + 1) * 2 - print(f"DEBUG: Tavily rate limited, waiting {wait_time}s before retry {attempt + 1}/{max_retries}") + logger.debug("Tavily rate limited, waiting %ds before retry %d/%d", wait_time, attempt + 1, max_retries) time.sleep(wait_time) last_exception = e elif "timeout" in error_str or "timed out" in error_str: wait_time = RETRY_BACKOFF * (attempt + 1) - print(f"DEBUG: Tavily timeout, waiting {wait_time}s before retry {attempt + 1}/{max_retries}") + logger.debug("Tavily timeout, waiting %ds before retry %d/%d", wait_time, attempt + 1, max_retries) 
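+            # Linear backoff: each retry waits RETRY_BACKOFF * (attempt + 1)
+            # seconds; the rate-limit branch above doubles that wait.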
time.sleep(wait_time) last_exception = e elif "connection" in error_str or "network" in error_str: wait_time = RETRY_BACKOFF * (attempt + 1) - print(f"DEBUG: Tavily connection error, waiting {wait_time}s before retry {attempt + 1}/{max_retries}") + logger.debug("Tavily connection error, waiting %ds before retry %d/%d", wait_time, attempt + 1, max_retries) time.sleep(wait_time) last_exception = e else: @@ -57,13 +60,13 @@ def _search_with_retry(client, query: str, search_depth: str, topic: str, time_r def get_bulk_news_tavily(lookback_hours: int) -> List[Dict[str, Any]]: if not TAVILY_AVAILABLE: - print("DEBUG: Tavily library not installed") + logger.debug("Tavily library not installed") return [] try: client = TavilyClient(api_key=get_api_key()) except ValueError as e: - print(f"DEBUG: Tavily API key not configured: {e}") + logger.debug("Tavily API key not configured: %s", e) return [] queries = [ @@ -121,8 +124,8 @@ def get_bulk_news_tavily(lookback_hours: int) -> List[Dict[str, Any]]: all_articles.append(article) except Exception as e: - print(f"DEBUG: Tavily search failed for query '{query}': {e}") + logger.debug("Tavily search failed for query '%s': %s", query, e) continue - print(f"DEBUG: Tavily returned {len(all_articles)} articles") + logger.debug("Tavily returned %d articles", len(all_articles)) return all_articles diff --git a/tradingagents/dataflows/utils.py b/tradingagents/dataflows/utils.py index 4523de19..3132ecae 100644 --- a/tradingagents/dataflows/utils.py +++ b/tradingagents/dataflows/utils.py @@ -1,15 +1,18 @@ +import logging import os import json import pandas as pd from datetime import date, timedelta, datetime from typing import Annotated +logger = logging.getLogger(__name__) + SavePathType = Annotated[str, "File path to save data. 
diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py
index da7273d5..8bb4927d 100644
--- a/tradingagents/dataflows/y_finance.py
+++ b/tradingagents/dataflows/y_finance.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Annotated
 from datetime import datetime
 from dateutil.relativedelta import relativedelta
@@ -5,6 +6,8 @@ import yfinance as yf
 import os
 from .stockstats_utils import StockstatsUtils

+logger = logging.getLogger(__name__)
+
 def get_YFin_data_online(
     symbol: Annotated[str, "ticker symbol of the company"],
     start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
@@ -14,32 +17,25 @@ def get_YFin_data_online(
     datetime.strptime(start_date, "%Y-%m-%d")
     datetime.strptime(end_date, "%Y-%m-%d")

-    # Create ticker object
     ticker = yf.Ticker(symbol.upper())

-    # Fetch historical data for the specified date range
     data = ticker.history(start=start_date, end=end_date)

-    # Check if data is empty
     if data.empty:
         return (
             f"No data found for symbol '{symbol}' between {start_date} and {end_date}"
         )

-    # Remove timezone info from index for cleaner output
     if data.index.tz is not None:
         data.index = data.index.tz_localize(None)

-    # Round numerical values to 2 decimal places for cleaner display
     numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
     for col in numeric_columns:
         if col in data.columns:
             data[col] = data[col].round(2)

-    # Convert DataFrame to CSV string
     csv_string = data.to_csv()

-    # Add header information
     header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
     header += f"# Total records: {len(data)}\n"
     header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
@@ -56,7 +52,6 @@ def get_stock_stats_indicators_window(
 ) -> str:

     best_ind_params = {
-        # Moving Averages
         "close_50_sma": (
             "50 SMA: A medium-term trend indicator. "
             "Usage: Identify trend direction and serve as dynamic support/resistance. "
@@ -72,7 +67,6 @@
             "Usage: Capture quick shifts in momentum and potential entry points. "
             "Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals."
         ),
-        # MACD Related
         "macd": (
             "MACD: Computes momentum via differences of EMAs. "
             "Usage: Look for crossovers and divergence as signals of trend changes. "
@@ -88,13 +82,11 @@
             "Usage: Visualize momentum strength and spot divergence early. "
             "Tips: Can be volatile; complement with additional filters in fast-moving markets."
         ),
-        # Momentum Indicators
         "rsi": (
             "RSI: Measures momentum to flag overbought/oversold conditions. "
             "Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. "
             "Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis."
         ),
-        # Volatility Indicators
         "boll": (
             "Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. "
             "Usage: Acts as a dynamic benchmark for price movement. "
@@ -115,7 +107,6 @@
             "Usage: Set stop-loss levels and adjust position sizes based on current market volatility. "
             "Tips: It's a reactive measure, so use it as part of a broader risk management strategy."
         ),
-        # Volume-Based Indicators
         "vwma": (
             "VWMA: A moving average weighted by volume. "
             "Usage: Confirm trends by integrating price action with volume data. "
@@ -137,34 +128,29 @@ def get_stock_stats_indicators_window(
     curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
     before = curr_date_dt - relativedelta(days=look_back_days)

-    # Optimized: Get stock data once and calculate indicators for all dates
     try:
         indicator_data = _get_stock_stats_bulk(symbol, indicator, curr_date)
-
-        # Generate the date range we need
+
         current_dt = curr_date_dt
         date_values = []
-
+
         while current_dt >= before:
             date_str = current_dt.strftime('%Y-%m-%d')
-
-            # Look up the indicator value for this date
+
             if date_str in indicator_data:
                 indicator_value = indicator_data[date_str]
             else:
                 indicator_value = "N/A: Not a trading day (weekend or holiday)"
-
+
             date_values.append((date_str, indicator_value))
             current_dt = current_dt - relativedelta(days=1)
-
-        # Build the result string
+
         ind_string = ""
         for date_str, value in date_values:
             ind_string += f"{date_str}: {value}\n"
-
+
     except Exception as e:
-        print(f"Error getting bulk stockstats data: {e}")
-        # Fallback to original implementation if bulk method fails
+        logger.error("Error getting bulk stockstats data: %s", e)
         ind_string = ""
         curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
         while curr_date_dt >= before:
@@ -189,21 +175,15 @@ def _get_stock_stats_bulk(
     indicator: Annotated[str, "technical indicator to calculate"],
     curr_date: Annotated[str, "current date for reference"]
 ) -> dict:
-    """
-    Optimized bulk calculation of stock stats indicators.
-    Fetches data once and calculates indicator for all available dates.
-    Returns dict mapping date strings to indicator values.
-    """
     from .config import get_config
     import pandas as pd
     from stockstats import wrap
     import os
-
+
     config = get_config()
     online = config["data_vendors"]["technical_indicators"] != "local"
-
+
     if not online:
-        # Local data path
         try:
             data = pd.read_csv(
                 os.path.join(
@@ -215,22 +195,21 @@ def _get_stock_stats_bulk(
         except FileNotFoundError:
             raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
     else:
-        # Online data fetching with caching
         today_date = pd.Timestamp.today()
         curr_date_dt = pd.to_datetime(curr_date)
-
+
         end_date = today_date
         start_date = today_date - pd.DateOffset(years=15)
         start_date_str = start_date.strftime("%Y-%m-%d")
         end_date_str = end_date.strftime("%Y-%m-%d")
-
+
         os.makedirs(config["data_cache_dir"], exist_ok=True)
-
+
         data_file = os.path.join(
             config["data_cache_dir"],
             f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
         )
-
+
         if os.path.exists(data_file):
             data = pd.read_csv(data_file)
             data["Date"] = pd.to_datetime(data["Date"])
@@ -245,25 +224,22 @@ def _get_stock_stats_bulk(
             )
             data = data.reset_index()
             data.to_csv(data_file, index=False)
-
+
     df = wrap(data)
     df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
-
-    # Calculate the indicator for all rows at once
-    df[indicator]  # This triggers stockstats to calculate the indicator
-
-    # Create a dictionary mapping date strings to indicator values
+
+    df[indicator]
+
     result_dict = {}
     for _, row in df.iterrows():
         date_str = row["Date"]
         indicator_value = row[indicator]
-
-        # Handle NaN/None values
+
         if pd.isna(indicator_value):
             result_dict[date_str] = "N/A"
         else:
             result_dict[date_str] = str(indicator_value)
-
+
     return result_dict
@@ -285,8 +261,9 @@ def get_stockstats_indicator(
             curr_date,
         )
     except Exception as e:
-        print(
-            f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
+        logger.error(
+            "Error getting stockstats indicator data for indicator %s on %s: %s",
+            indicator, curr_date, e
        )
        return ""
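The bulk path above leans on a stockstats side effect: subscripting the wrapped frame materializes the whole indicator column at once, so the per-date loop becomes dictionary lookups instead of repeated fetches. A compressed sketch of that idea (hypothetical frame with a Date column; not part of the patch):

    import pandas as pd
    from stockstats import wrap

    def bulk_indicator(data: pd.DataFrame, indicator: str) -> dict:
        # wrap() once; accessing df[indicator] computes the full column
        # as a side effect, mirroring _get_stock_stats_bulk above.
        df = wrap(data)
        df[indicator]
        # Per-date reads are now O(1) lookups instead of refetches.
        return {row["Date"]: row[indicator] for _, row in df.iterrows()}
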
@@ -298,27 +275,24 @@ def get_balance_sheet(
     freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
     curr_date: Annotated[str, "current date (not used for yfinance)"] = None
 ):
-    """Get balance sheet data from yfinance."""
     try:
         ticker_obj = yf.Ticker(ticker.upper())
-
+
         if freq.lower() == "quarterly":
             data = ticker_obj.quarterly_balance_sheet
         else:
             data = ticker_obj.balance_sheet
-
+
         if data.empty:
             return f"No balance sheet data found for symbol '{ticker}'"
-
-        # Convert to CSV string for consistency with other functions
+
         csv_string = data.to_csv()
-
-        # Add header information
+
         header = f"# Balance Sheet data for {ticker.upper()} ({freq})\n"
         header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
-
+
         return header + csv_string
-
+
     except Exception as e:
         return f"Error retrieving balance sheet for {ticker}: {str(e)}"
@@ -328,27 +302,24 @@ def get_cashflow(
     freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
     curr_date: Annotated[str, "current date (not used for yfinance)"] = None
 ):
-    """Get cash flow data from yfinance."""
     try:
         ticker_obj = yf.Ticker(ticker.upper())
-
+
         if freq.lower() == "quarterly":
             data = ticker_obj.quarterly_cashflow
         else:
             data = ticker_obj.cashflow
-
+
         if data.empty:
             return f"No cash flow data found for symbol '{ticker}'"
-
-        # Convert to CSV string for consistency with other functions
+
         csv_string = data.to_csv()
-
-        # Add header information
+
         header = f"# Cash Flow data for {ticker.upper()} ({freq})\n"
         header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
-
+
         return header + csv_string
-
+
     except Exception as e:
         return f"Error retrieving cash flow for {ticker}: {str(e)}"
@@ -358,27 +329,24 @@ def get_income_statement(
     freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
     curr_date: Annotated[str, "current date (not used for yfinance)"] = None
 ):
-    """Get income statement data from yfinance."""
     try:
         ticker_obj = yf.Ticker(ticker.upper())
-
+
         if freq.lower() == "quarterly":
             data = ticker_obj.quarterly_income_stmt
         else:
             data = ticker_obj.income_stmt
-
+
         if data.empty:
             return f"No income statement data found for symbol '{ticker}'"
-
-        # Convert to CSV string for consistency with other functions
+
         csv_string = data.to_csv()
-
-        # Add header information
+
         header = f"# Income Statement data for {ticker.upper()} ({freq})\n"
         header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
-
+
         return header + csv_string
-
+
     except Exception as e:
         return f"Error retrieving income statement for {ticker}: {str(e)}"
@@ -386,22 +354,19 @@ def get_insider_transactions(
     ticker: Annotated[str, "ticker symbol of the company"]
 ):
-    """Get insider transactions data from yfinance."""
     try:
         ticker_obj = yf.Ticker(ticker.upper())
         data = ticker_obj.insider_transactions
-
+
         if data is None or data.empty:
             return f"No insider transactions data found for symbol '{ticker}'"
-
-        # Convert to CSV string for consistency with other functions
+
         csv_string = data.to_csv()
-
-        # Add header information
+
         header = f"# Insider Transactions data for {ticker.upper()}\n"
         header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
-
+
         return header + csv_string
-
+
     except Exception as e:
-        return f"Error retrieving insider transactions for {ticker}: {str(e)}"
\ No newline at end of file
+        return f"Error retrieving insider transactions for {ticker}: {str(e)}"
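All four statement helpers above share one contract: on success they return a CSV string prefixed with "#" comment lines, and failures come back as error strings rather than raised exceptions. An illustrative call (hypothetical ticker, not part of the patch):

    # freq defaults to "quarterly"; "annual" selects the yearly statements.
    report = get_balance_sheet("AAPL", freq="annual")
    print(report.splitlines()[0])  # e.g. "# Balance Sheet data for AAPL (annual)"
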
diff --git a/tradingagents/dataflows/yfin_utils.py b/tradingagents/dataflows/yfin_utils.py
index bd7ca324..fa860f36 100644
--- a/tradingagents/dataflows/yfin_utils.py
+++ b/tradingagents/dataflows/yfin_utils.py
@@ -1,5 +1,4 @@
-# gets data/stats
-
+import logging
 import yfinance as yf
 from typing import Annotated, Callable, Any, Optional
 from pandas import DataFrame
@@ -8,9 +7,10 @@
 from functools import wraps

 from .utils import save_output, SavePathType, decorate_all_methods

+logger = logging.getLogger(__name__)
+

 def init_ticker(func: Callable) -> Callable:
-    """Decorator to initialize yf.Ticker and pass it to the function."""

     @wraps(func)
     def wrapper(symbol: Annotated[str, "ticker symbol"], *args, **kwargs) -> Any:
@@ -33,19 +33,15 @@ class YFinanceUtils:
         ],
         save_path: SavePathType = None,
     ) -> DataFrame:
-        """retrieve stock price data for designated ticker symbol"""
         ticker = symbol
-        # add one day to the end_date so that the data range is inclusive
         end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
         end_date = end_date.strftime("%Y-%m-%d")
         stock_data = ticker.history(start=start_date, end=end_date)
-        # save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
         return stock_data

     def get_stock_info(
         symbol: Annotated[str, "ticker symbol"],
     ) -> dict:
-        """Fetches and returns latest stock information."""
         ticker = symbol
         stock_info = ticker.info
         return stock_info
@@ -54,7 +50,6 @@ class YFinanceUtils:
         symbol: Annotated[str, "ticker symbol"],
         save_path: Optional[str] = None,
     ) -> DataFrame:
-        """Fetches and returns company information as a DataFrame."""
         ticker = symbol
         info = ticker.info
         company_info = {
@@ -67,50 +62,43 @@ class YFinanceUtils:
         company_info_df = DataFrame([company_info])
         if save_path:
             company_info_df.to_csv(save_path)
-            print(f"Company info for {ticker.ticker} saved to {save_path}")
+            logger.info("Company info for %s saved to %s", ticker.ticker, save_path)
         return company_info_df

     def get_stock_dividends(
         symbol: Annotated[str, "ticker symbol"],
         save_path: Optional[str] = None,
     ) -> DataFrame:
-        """Fetches and returns the latest dividends data as a DataFrame."""
         ticker = symbol
         dividends = ticker.dividends
         if save_path:
             dividends.to_csv(save_path)
-            print(f"Dividends for {ticker.ticker} saved to {save_path}")
+            logger.info("Dividends for %s saved to %s", ticker.ticker, save_path)
         return dividends

     def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
-        """Fetches and returns the latest income statement of the company as a DataFrame."""
         ticker = symbol
         income_stmt = ticker.financials
         return income_stmt

     def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
-        """Fetches and returns the latest balance sheet of the company as a DataFrame."""
         ticker = symbol
         balance_sheet = ticker.balance_sheet
         return balance_sheet

     def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
-        """Fetches and returns the latest cash flow statement of the company as a DataFrame."""
         ticker = symbol
         cash_flow = ticker.cashflow
         return cash_flow

     def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple:
-        """Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
         ticker = symbol
         recommendations = ticker.recommendations
         if recommendations.empty:
-            return None, 0  # No recommendations available
+            return None, 0

-        # Assuming 'period' column exists and needs to be excluded
-        row_0 = recommendations.iloc[0, 1:]  # Exclude 'period' column if necessary
+        row_0 = recommendations.iloc[0, 1:]

-        # Find the maximum voting result
         max_votes = row_0.max()
         majority_voting_result = row_0[row_0 == max_votes].index.tolist()
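The recommendation tally at the end of YFinanceUtils reduces to a pandas max-and-filter over the most recent row (minus the leading period column). In isolation, with made-up vote counts:

    import pandas as pd

    # Hypothetical analyst votes for the latest period.
    row_0 = pd.Series({"strongBuy": 7, "buy": 21, "hold": 6, "sell": 1})
    max_votes = row_0.max()
    majority_voting_result = row_0[row_0 == max_votes].index.tolist()
    print(majority_voting_result, max_votes)  # ['buy'] 21
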
diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py
index f0d37cb0..baecdfe5 100644
--- a/tradingagents/default_config.py
+++ b/tradingagents/default_config.py
@@ -31,4 +31,8 @@ DEFAULT_CONFIG = {
     "bulk_news_vendor_order": ["tavily", "brave", "alpha_vantage", "openai", "google"],
     "bulk_news_timeout": 30,
     "bulk_news_max_retries": 3,
+    "log_level": os.getenv("TRADINGAGENTS_LOG_LEVEL", "INFO"),
+    "log_dir": os.getenv("TRADINGAGENTS_LOG_DIR", "./logs"),
+    "log_console_enabled": os.getenv("TRADINGAGENTS_LOG_CONSOLE", "true").lower() in ("true", "1", "yes", "on"),
+    "log_file_enabled": os.getenv("TRADINGAGENTS_LOG_FILE", "true").lower() in ("true", "1", "yes", "on"),
 }
diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py
index 435ad168..f6dd3e8f 100644
--- a/tradingagents/graph/trading_graph.py
+++ b/tradingagents/graph/trading_graph.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import signal
 import threading
@@ -54,6 +55,8 @@ from .propagation import Propagator
 from .reflection import Reflector
 from .signal_processing import SignalProcessor

+logger = logging.getLogger(__name__)
+

 class DiscoveryTimeoutException(Exception):
     pass
@@ -164,7 +167,7 @@ class TradingAgentsGraph:
                 if len(chunk["messages"]) == 0:
                     pass
                 else:
-                    chunk["messages"][-1].pretty_print()
+                    logger.debug("Agent message: %s", chunk["messages"][-1])

                 trace.append(chunk)

         final_state = trace[-1]
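Configuration resolution for the module below is two-layered: an explicit TRADINGAGENTS_* environment variable wins, otherwise the DEFAULT_CONFIG entry applies (and that entry itself reads the environment with a hard-coded fallback). A condensed sketch of the effective lookup (resolve() is hypothetical, not part of the patch):

    import os

    def resolve(env_var: str, config_key: str, fallback: str, config: dict) -> str:
        # Environment beats default_config, which beats the built-in fallback.
        value = os.getenv(env_var)
        if value is None:
            value = config.get(config_key, fallback)
        return value

    level = resolve("TRADINGAGENTS_LOG_LEVEL", "log_level", "INFO",
                    {"log_level": "WARNING"})  # env unset -> "WARNING"
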
os.getenv("TRADINGAGENTS_LOG_FILE") + if file_enabled_env is not None: + file_enabled = _parse_bool(file_enabled_env) + else: + file_enabled = _get_config_value("log_file_enabled", True) + + log_level = getattr(logging, log_level_str.upper(), logging.INFO) + + root_logger = logging.getLogger("tradingagents") + + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + + root_logger.setLevel(log_level) + + if file_enabled: + os.makedirs(log_dir, exist_ok=True) + log_file_path = os.path.join(log_dir, LOG_FILE_NAME) + + file_handler = logging.handlers.RotatingFileHandler( + log_file_path, + maxBytes=LOG_MAX_BYTES, + backupCount=LOG_BACKUP_COUNT, + ) + file_handler.setLevel(log_level) + file_handler.setFormatter(JSONFormatter()) + root_logger.addHandler(file_handler) + + if console_enabled: + from rich.logging import RichHandler + from rich.console import Console + + console = Console(stderr=True) + rich_handler = RichHandler( + console=console, + show_time=True, + show_level=True, + show_path=True, + rich_tracebacks=True, + ) + rich_handler.setLevel(log_level) + root_logger.addHandler(rich_handler) + + _logging_initialized = True + + return root_logger + + +def get_logger(name): + global _logging_initialized + + if not _logging_initialized: + setup_logging() + + return logging.getLogger(name)