feat: add centralized logging module with dual output support

Add tradingagents/logging.py with JSONFormatter for structured file logs
and RichHandler for colored console output. Migrate 59 print statements
across 13 files to proper logger calls. Configure via environment
variables (TRADINGAGENTS_LOG_LEVEL, etc.) with default_config fallback.

- Rotating file handler: 10MB max, 5 backups to logs/tradingagents.log
- Rich console handler with tracebacks for development
- Logger hierarchy follows module paths (tradingagents.dataflows.*, etc.)
- 20 feature-specific tests covering core, config, migration, integration

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Joseph O'Brien · 2025-12-03 02:11:29 -05:00
commit 80e8fb0eee (parent 3c85b21e0b)
19 changed files with 857 additions and 368 deletions
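The new tradingagents/logging.py module itself is not shown in this view. As a reading aid, here is a minimal sketch consistent with what the tests below assert (the env var names, the "tradingagents" root logger, the 10MB/5-backup rotation, the JSON fields, invalid-level fallback to INFO, and lazy init via _logging_initialized). The JSONFormatter internals, the _parse_bool helper, and the default values for TRADINGAGENTS_LOG_FILE and TRADINGAGENTS_LOG_CONSOLE are assumptions, not the committed code:

import json
import logging
import os
from datetime import datetime, timezone
from logging.handlers import RotatingFileHandler

from rich.logging import RichHandler

from tradingagents.default_config import DEFAULT_CONFIG

_logging_initialized = False


class JSONFormatter(logging.Formatter):
    """Serialize each record as one JSON object per line (structured file logs)."""

    def format(self, record):
        entry = {
            # ISO 8601 timestamp, as the tests require ("T" separator present)
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "filename": record.filename,
            "funcName": record.funcName,
            "lineno": record.lineno,
        }
        return json.dumps(entry)


def _parse_bool(value: str) -> bool:
    # Accept true/false, 1/0, yes/no in any case (see the boolean-parsing test).
    return value.strip().lower() in ("true", "1", "yes")


def setup_logging() -> logging.Logger:
    """Configure and return the package root logger ("tradingagents")."""
    global _logging_initialized
    level_name = os.getenv(
        "TRADINGAGENTS_LOG_LEVEL", DEFAULT_CONFIG.get("log_level", "INFO")
    ).upper()
    # Invalid level names fall back to INFO.
    level = getattr(logging, level_name, logging.INFO)
    if not isinstance(level, int):
        level = logging.INFO

    logger = logging.getLogger("tradingagents")
    logger.setLevel(level)

    if _parse_bool(os.getenv("TRADINGAGENTS_LOG_FILE", "true")):  # default assumed
        log_dir = os.getenv("TRADINGAGENTS_LOG_DIR", "logs")
        os.makedirs(log_dir, exist_ok=True)
        file_handler = RotatingFileHandler(
            os.path.join(log_dir, "tradingagents.log"),
            maxBytes=10 * 1024 * 1024,  # 10MB max
            backupCount=5,
        )
        file_handler.setFormatter(JSONFormatter())
        logger.addHandler(file_handler)

    if _parse_bool(os.getenv("TRADINGAGENTS_LOG_CONSOLE", "true")):  # default assumed
        # Rich console handler with tracebacks for development.
        logger.addHandler(RichHandler(rich_tracebacks=True))

    _logging_initialized = True
    return logger


def get_logger(name: str) -> logging.Logger:
    """Return a named logger, initializing handlers on first use."""
    if not _logging_initialized:
        setup_logging()
    return logging.getLogger(name)

Call sites then follow the stdlib pattern visible in the diffs below: import logging and logger = logging.getLogger(__name__) at module top, so the logger hierarchy tracks module paths.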

tests/test_logging.py (new file, 169 lines)

@@ -0,0 +1,169 @@
import json
import logging
import os
import tempfile
import pytest
from unittest.mock import patch


class TestLoggingModule:
    @pytest.fixture(autouse=True)
    def setup_and_teardown(self):
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
        tradingagents_logger = logging.getLogger("tradingagents")
        for handler in tradingagents_logger.handlers[:]:
            tradingagents_logger.removeHandler(handler)
        tradingagents_logger.setLevel(logging.NOTSET)
        yield
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
        tradingagents_logger = logging.getLogger("tradingagents")
        for handler in tradingagents_logger.handlers[:]:
            tradingagents_logger.removeHandler(handler)

    def test_setup_logging_initializes_handlers_based_on_env_vars(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "DEBUG",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                root_logger = log_module.setup_logging()
                assert root_logger is not None
                assert root_logger.name == "tradingagents"
                assert root_logger.level == logging.DEBUG
                has_file_handler = any(
                    hasattr(h, "baseFilename") for h in root_logger.handlers
                )
                assert has_file_handler, "File handler should be present when LOG_FILE=true"

    def test_get_logger_returns_properly_configured_logger_with_hierarchy(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "INFO",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                log_module.setup_logging()
                child_logger = log_module.get_logger("tradingagents.dataflows.interface")
                assert child_logger.name == "tradingagents.dataflows.interface"
                assert child_logger.parent.name == "tradingagents.dataflows" or child_logger.parent.name == "tradingagents"

    def test_json_file_handler_writes_valid_json_with_required_fields(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "DEBUG",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                logger = log_module.setup_logging()
                logger.info("Test message for JSON validation")
                for handler in logger.handlers:
                    handler.flush()
                log_file_path = os.path.join(tmpdir, "tradingagents.log")
                assert os.path.exists(log_file_path), f"Log file should exist at {log_file_path}"
                with open(log_file_path, "r") as f:
                    log_content = f.read().strip()
                assert log_content, "Log file should not be empty"
                log_entry = json.loads(log_content.split("\n")[0])
                required_fields = ["timestamp", "level", "logger", "message", "filename", "funcName", "lineno"]
                for field in required_fields:
                    assert field in log_entry, f"JSON log should contain '{field}' field"
                assert "T" in log_entry["timestamp"], "Timestamp should be in ISO 8601 format"
                assert log_entry["level"] == "INFO"
                assert log_entry["message"] == "Test message for JSON validation"

    def test_log_rotation_triggers_at_configured_file_size(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "DEBUG",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                logger = log_module.setup_logging()
                file_handler = None
                for handler in logger.handlers:
                    if hasattr(handler, "maxBytes"):
                        file_handler = handler
                        break
                assert file_handler is not None, "RotatingFileHandler should be configured"
                assert file_handler.maxBytes == 10 * 1024 * 1024, "Max file size should be 10MB"
                assert file_handler.backupCount == 5, "Backup count should be 5"

    def test_console_handler_disabled_when_env_var_false(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "INFO",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                logger = log_module.setup_logging()
                from rich.logging import RichHandler
                has_rich_handler = any(isinstance(h, RichHandler) for h in logger.handlers)
                assert not has_rich_handler, "RichHandler should NOT be present when LOG_CONSOLE=false"

    def test_console_handler_enabled_when_env_var_true(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "INFO",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "true",
                "TRADINGAGENTS_LOG_FILE": "false",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                logger = log_module.setup_logging()
                from rich.logging import RichHandler
                has_rich_handler = any(isinstance(h, RichHandler) for h in logger.handlers)
                assert has_rich_handler, "RichHandler should be present when LOG_CONSOLE=true"

New file (121 lines; filename not shown in this view)

@@ -0,0 +1,121 @@
import logging
import os
import tempfile
import pytest
from unittest.mock import patch


class TestLoggingConfigIntegration:
    @pytest.fixture(autouse=True)
    def setup_and_teardown(self):
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
        tradingagents_logger = logging.getLogger("tradingagents")
        for handler in tradingagents_logger.handlers[:]:
            tradingagents_logger.removeHandler(handler)
        tradingagents_logger.setLevel(logging.NOTSET)
        yield
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
        tradingagents_logger = logging.getLogger("tradingagents")
        for handler in tradingagents_logger.handlers[:]:
            tradingagents_logger.removeHandler(handler)

    def test_default_config_values_used_when_env_vars_not_set(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars_to_remove = [
                "TRADINGAGENTS_LOG_LEVEL",
                "TRADINGAGENTS_LOG_DIR",
                "TRADINGAGENTS_LOG_CONSOLE",
                "TRADINGAGENTS_LOG_FILE",
            ]
            clean_env = {k: v for k, v in os.environ.items() if k not in env_vars_to_remove}
            with patch.dict(os.environ, clean_env, clear=True):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                from tradingagents.default_config import DEFAULT_CONFIG
                expected_level = getattr(logging, DEFAULT_CONFIG.get("log_level", "INFO").upper())
                logger = log_module.setup_logging()
                assert logger.level == expected_level

    def test_env_vars_override_default_config_values(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "WARNING",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                logger = log_module.setup_logging()
                assert logger.level == logging.WARNING

    def test_boolean_parsing_for_log_console_and_file(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_cases = [
                ("true", True),
                ("false", False),
                ("1", True),
                ("0", False),
                ("True", True),
                ("False", False),
                ("TRUE", True),
                ("FALSE", False),
                ("yes", True),
                ("no", False),
            ]
            for bool_str, expected in test_cases:
                env_vars = {
                    "TRADINGAGENTS_LOG_LEVEL": "INFO",
                    "TRADINGAGENTS_LOG_DIR": tmpdir,
                    "TRADINGAGENTS_LOG_CONSOLE": bool_str,
                    "TRADINGAGENTS_LOG_FILE": "false",
                }
                tradingagents_logger = logging.getLogger("tradingagents")
                for handler in tradingagents_logger.handlers[:]:
                    tradingagents_logger.removeHandler(handler)
                with patch.dict(os.environ, env_vars, clear=False):
                    import importlib
                    import tradingagents.logging as log_module
                    importlib.reload(log_module)
                    logger = log_module.setup_logging()
                    from rich.logging import RichHandler
                    has_rich_handler = any(isinstance(h, RichHandler) for h in logger.handlers)
                    assert has_rich_handler == expected, f"TRADINGAGENTS_LOG_CONSOLE={bool_str} should result in RichHandler present={expected}"

    def test_invalid_log_level_falls_back_to_info(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "INVALID_LEVEL",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                logger = log_module.setup_logging()
                assert logger.level == logging.INFO, "Invalid log level should fall back to INFO"

New file (169 lines; filename not shown in this view)

@@ -0,0 +1,169 @@
import logging
import os
import tempfile
import pytest
from unittest.mock import patch


class TestLoggingIntegration:
    @pytest.fixture(autouse=True)
    def setup_and_teardown(self):
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
        tradingagents_logger = logging.getLogger("tradingagents")
        for handler in tradingagents_logger.handlers[:]:
            tradingagents_logger.removeHandler(handler)
        tradingagents_logger.setLevel(logging.NOTSET)
        yield
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
        tradingagents_logger = logging.getLogger("tradingagents")
        for handler in tradingagents_logger.handlers[:]:
            tradingagents_logger.removeHandler(handler)

    def test_logging_initialization_from_module_import(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "INFO",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                log_module.setup_logging()
                interface_logger = log_module.get_logger("tradingagents.dataflows.interface")
                assert interface_logger is not None
                assert interface_logger.name == "tradingagents.dataflows.interface"
                interface_logger.info("Test message from interface logger")
                log_file = os.path.join(tmpdir, "tradingagents.log")
                assert os.path.exists(log_file)
                with open(log_file, "r") as f:
                    content = f.read()
                assert "Test message from interface logger" in content

    def test_rich_handler_does_not_break_cli_live_displays(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "INFO",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "true",
                "TRADINGAGENTS_LOG_FILE": "false",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                logger = log_module.setup_logging()
                from rich.logging import RichHandler
                rich_handlers = [h for h in logger.handlers if isinstance(h, RichHandler)]
                assert len(rich_handlers) == 1
                rich_handler = rich_handlers[0]
                assert rich_handler.console is not None
                assert rich_handler.console.file is not None

    def test_log_file_creation_and_format(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "DEBUG",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import json
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                logger = log_module.setup_logging()
                logger.debug("Debug message")
                logger.info("Info message")
                logger.warning("Warning message")
                logger.error("Error message")
                for handler in logger.handlers:
                    handler.flush()
                log_file = os.path.join(tmpdir, "tradingagents.log")
                assert os.path.exists(log_file)
                with open(log_file, "r") as f:
                    lines = f.readlines()
                assert len(lines) >= 4
                for line in lines:
                    log_entry = json.loads(line)
                    assert "timestamp" in log_entry
                    assert "level" in log_entry
                    assert "logger" in log_entry
                    assert "message" in log_entry

    def test_logger_hierarchy_inherits_parent_configuration(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "WARNING",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                root_logger = log_module.setup_logging()
                child_logger = log_module.get_logger("tradingagents.dataflows.interface")
                grandchild_logger = log_module.get_logger("tradingagents.dataflows.interface.submodule")
                assert root_logger.level == logging.WARNING
                child_logger.info("This should not be logged")
                child_logger.warning("This should be logged")
                for handler in root_logger.handlers:
                    handler.flush()
                log_file = os.path.join(tmpdir, "tradingagents.log")
                with open(log_file, "r") as f:
                    content = f.read()
                assert "This should not be logged" not in content
                assert "This should be logged" in content

    def test_lazy_initialization_pattern(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            env_vars = {
                "TRADINGAGENTS_LOG_LEVEL": "INFO",
                "TRADINGAGENTS_LOG_DIR": tmpdir,
                "TRADINGAGENTS_LOG_CONSOLE": "false",
                "TRADINGAGENTS_LOG_FILE": "true",
            }
            with patch.dict(os.environ, env_vars, clear=False):
                import importlib
                import tradingagents.logging as log_module
                importlib.reload(log_module)
                log_module._logging_initialized = False
                logger = log_module.get_logger("tradingagents.test")
                assert log_module._logging_initialized is True
                assert logger is not None

New file (124 lines; filename not shown in this view)

@@ -0,0 +1,124 @@
import ast
import os
import pytest


class TestLoggingMigration:
    def test_no_print_statements_in_interface_py(self):
        file_path = os.path.join(
            os.path.dirname(__file__),
            "..",
            "tradingagents",
            "dataflows",
            "interface.py",
        )
        with open(file_path, "r") as f:
            content = f.read()
        tree = ast.parse(content)
        print_calls = []
        for node in ast.walk(tree):
            if isinstance(node, ast.Call):
                if isinstance(node.func, ast.Name) and node.func.id == "print":
                    print_calls.append(node.lineno)
        assert len(print_calls) == 0, f"Found print statements at lines: {print_calls}"

    def test_no_print_statements_in_brave_py(self):
        file_path = os.path.join(
            os.path.dirname(__file__),
            "..",
            "tradingagents",
            "dataflows",
            "brave.py",
        )
        with open(file_path, "r") as f:
            content = f.read()
        tree = ast.parse(content)
        print_calls = []
        for node in ast.walk(tree):
            if isinstance(node, ast.Call):
                if isinstance(node.func, ast.Name) and node.func.id == "print":
                    print_calls.append(node.lineno)
        assert len(print_calls) == 0, f"Found print statements at lines: {print_calls}"

    def test_no_print_statements_in_tavily_py(self):
        file_path = os.path.join(
            os.path.dirname(__file__),
            "..",
            "tradingagents",
            "dataflows",
            "tavily.py",
        )
        with open(file_path, "r") as f:
            content = f.read()
        tree = ast.parse(content)
        print_calls = []
        for node in ast.walk(tree):
            if isinstance(node, ast.Call):
                if isinstance(node.func, ast.Name) and node.func.id == "print":
                    print_calls.append(node.lineno)
        assert len(print_calls) == 0, f"Found print statements at lines: {print_calls}"

    def test_no_print_statements_in_migrated_dataflow_files(self):
        dataflow_files = [
            "alpha_vantage_news.py",
            "y_finance.py",
            "local.py",
            "yfin_utils.py",
            "googlenews_utils.py",
            "utils.py",
            "alpha_vantage_common.py",
            "alpha_vantage_indicator.py",
        ]
        dataflows_dir = os.path.join(
            os.path.dirname(__file__),
            "..",
            "tradingagents",
            "dataflows",
        )
        all_print_calls = {}
        for filename in dataflow_files:
            file_path = os.path.join(dataflows_dir, filename)
            if not os.path.exists(file_path):
                continue
            with open(file_path, "r") as f:
                content = f.read()
            tree = ast.parse(content)
            print_calls = []
            for node in ast.walk(tree):
                if isinstance(node, ast.Call):
                    if isinstance(node.func, ast.Name) and node.func.id == "print":
                        print_calls.append(node.lineno)
            if print_calls:
                all_print_calls[filename] = print_calls
        assert len(all_print_calls) == 0, f"Found print statements in: {all_print_calls}"

    def test_logger_import_exists_in_interface_py(self):
        file_path = os.path.join(
            os.path.dirname(__file__),
            "..",
            "tradingagents",
            "dataflows",
            "interface.py",
        )
        with open(file_path, "r") as f:
            content = f.read()
        assert "import logging" in content, "interface.py should import logging"
        assert "logger = logging.getLogger(__name__)" in content, "interface.py should define logger"

Modified file (filename not shown in this view)

@@ -1,7 +1,10 @@
+import logging
+
 import chromadb
 from chromadb.config import Settings
 from openai import OpenAI

+logger = logging.getLogger(__name__)
+

 class FinancialSituationMemory:
     def __init__(self, name, config):
@@ -14,7 +17,6 @@ class FinancialSituationMemory:
         self.situation_collection = self.chroma_client.get_or_create_collection(name=name)

     def get_embedding(self, text):
-        """Get OpenAI embedding for a text"""
         response = self.client.embeddings.create(
             model=self.embedding, input=text
@@ -22,7 +24,6 @@ class FinancialSituationMemory:
         return response.data[0].embedding

     def add_situations(self, situations_and_advice):
-        """Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
         situations = []
         advice = []
@@ -45,7 +46,6 @@ class FinancialSituationMemory:
         )

     def get_memories(self, current_situation, n_matches=1):
-        """Find matching recommendations using OpenAI embeddings"""
         query_embedding = self.get_embedding(current_situation)

         results = self.situation_collection.query(
@@ -65,47 +65,3 @@ class FinancialSituationMemory:
         )

         return matched_results
-
-
-# if __name__ == "__main__":
-#     # Example usage
-#     matcher = FinancialSituationMemory()
-#     example_data = [
-#         (
-#             "High inflation rate with rising interest rates and declining consumer spending",
-#             "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
-#         ),
-#         (
-#             "Tech sector showing high volatility with increasing institutional selling pressure",
-#             "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
-#         ),
-#         (
-#             "Strong dollar affecting emerging markets with increasing forex volatility",
-#             "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
-#         ),
-#         (
-#             "Market showing signs of sector rotation with rising yields",
-#             "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
-#         ),
-#     ]
-
-#     # Add the example situations and recommendations
-#     matcher.add_situations(example_data)

-#     # Example query
-#     current_situation = """
-#     Market showing increased volatility in tech sector, with institutional investors
-#     reducing positions and rising interest rates affecting growth stock valuations
-#     """
-
-#     try:
-#         recommendations = matcher.get_memories(current_situation, n_matches=2)
-#         for i, rec in enumerate(recommendations, 1):
-#             print(f"\nMatch {i}:")
-#             print(f"Similarity Score: {rec['similarity_score']:.2f}")
-#             print(f"Matched Situation: {rec['matched_situation']}")
-#             print(f"Recommendation: {rec['recommendation']}")
-#     except Exception as e:
-#         print(f"Error during recommendation: {str(e)}")

Modified file (filename not shown in this view)

@@ -1,3 +1,4 @@
+import logging
 import os
 import requests
 import pandas as pd
@@ -5,22 +6,20 @@ import json
 from datetime import datetime
 from io import StringIO

+logger = logging.getLogger(__name__)
+
 API_BASE_URL = "https://www.alphavantage.co/query"


 def get_api_key() -> str:
-    """Retrieve the API key for Alpha Vantage from environment variables."""
     api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
     if not api_key:
         raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
     return api_key


 def format_datetime_for_api(date_input) -> str:
-    """Convert various date formats to YYYYMMDDTHHMM format required by Alpha Vantage API."""
     if isinstance(date_input, str):
-        # If already in correct format, return as-is
         if len(date_input) == 13 and 'T' in date_input:
             return date_input
-        # Try to parse common date formats
         try:
             dt = datetime.strptime(date_input, "%Y-%m-%d")
             return dt.strftime("%Y%m%dT0000")
@@ -36,16 +35,9 @@ def format_datetime_for_api(date_input) -> str:
     raise ValueError(f"Date must be string or datetime object, got {type(date_input)}")


 class AlphaVantageRateLimitError(Exception):
-    """Exception raised when Alpha Vantage API rate limit is exceeded."""
     pass


 def _make_api_request(function_name: str, params: dict) -> dict | str:
-    """Helper function to make API requests and handle responses.
-
-    Raises:
-        AlphaVantageRateLimitError: When API rate limit is exceeded
-    """
-    # Create a copy of params to avoid modifying the original
     api_params = params.copy()
     api_params.update({
         "function": function_name,
@@ -53,14 +45,12 @@ def _make_api_request(function_name: str, params: dict) -> dict | str:
         "source": "trading_agents",
     })

-    # Handle entitlement parameter if present in params or global variable
     current_entitlement = globals().get('_current_entitlement')
     entitlement = api_params.get("entitlement") or current_entitlement

     if entitlement:
         api_params["entitlement"] = entitlement
     elif "entitlement" in api_params:
-        # Remove entitlement if it's None or empty
         api_params.pop("entitlement", None)

     response = requests.get(API_BASE_URL, params=api_params)
@@ -68,16 +58,13 @@ def _make_api_request(function_name: str, params: dict) -> dict | str:
     response_text = response.text

-    # Check if response is JSON (error responses are typically JSON)
     try:
         response_json = json.loads(response_text)
-        # Check for rate limit error
         if "Information" in response_json:
             info_message = response_json["Information"]
             if "rate limit" in info_message.lower() or "api key" in info_message.lower():
                 raise AlphaVantageRateLimitError(f"Alpha Vantage rate limit exceeded: {info_message}")
     except json.JSONDecodeError:
-        # Response is not JSON (likely CSV data), which is normal
         pass

     return response_text
@@ -85,38 +72,22 @@ def _make_api_request(function_name: str, params: dict) -> dict | str:
 def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> str:
-    """
-    Filter CSV data to include only rows within the specified date range.
-
-    Args:
-        csv_data: CSV string from Alpha Vantage API
-        start_date: Start date in yyyy-mm-dd format
-        end_date: End date in yyyy-mm-dd format
-
-    Returns:
-        Filtered CSV string
-    """
     if not csv_data or csv_data.strip() == "":
         return csv_data

     try:
-        # Parse CSV data
         df = pd.read_csv(StringIO(csv_data))

-        # Assume the first column is the date column (timestamp)
         date_col = df.columns[0]
         df[date_col] = pd.to_datetime(df[date_col])

-        # Filter by date range
         start_dt = pd.to_datetime(start_date)
         end_dt = pd.to_datetime(end_date)
         filtered_df = df[(df[date_col] >= start_dt) & (df[date_col] <= end_dt)]

-        # Convert back to CSV string
         return filtered_df.to_csv(index=False)
     except Exception as e:
-        # If filtering fails, return original data with a warning
-        print(f"Warning: Failed to filter CSV data by date range: {e}")
+        logger.warning("Failed to filter CSV data by date range: %s", e)
         return csv_data

Modified file (filename not shown in this view)

@@ -1,5 +1,8 @@
+import logging
+
 from .alpha_vantage_common import _make_api_request

+logger = logging.getLogger(__name__)
+

 def get_indicator(
     symbol: str,
     indicator: str,
@@ -9,21 +12,6 @@ def get_indicator(
     time_period: int = 14,
     series_type: str = "close"
 ) -> str:
-    """
-    Returns Alpha Vantage technical indicator values over a time window.
-
-    Args:
-        symbol: ticker symbol of the company
-        indicator: technical indicator to get the analysis and report of
-        curr_date: The current trading date you are trading on, YYYY-mm-dd
-        look_back_days: how many days to look back
-        interval: Time interval (daily, weekly, monthly)
-        time_period: Number of data points for calculation
-        series_type: The desired price type (close, open, high, low)
-
-    Returns:
-        String containing indicator values and description
-    """
     from datetime import datetime
     from dateutil.relativedelta import relativedelta
@@ -65,15 +53,12 @@ def get_indicator(
     curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
     before = curr_date_dt - relativedelta(days=look_back_days)

-    # Get the full data for the period instead of making individual calls
     _, required_series_type = supported_indicators[indicator]

-    # Use the provided series_type or fall back to the required one
     if required_series_type:
         series_type = required_series_type

     try:
-        # Get indicator data for the period
         if indicator == "close_50_sma":
             data = _make_api_request("SMA", {
                 "symbol": symbol,
@@ -143,25 +128,20 @@ def get_indicator(
                 "datatype": "csv"
             })
         elif indicator == "vwma":
-            # Alpha Vantage doesn't have direct VWMA, so we'll return an informative message
-            # In a real implementation, this would need to be calculated from OHLCV data
             return f"## VWMA (Volume Weighted Moving Average) for {symbol}:\n\nVWMA calculation requires OHLCV data and is not directly available from Alpha Vantage API.\nThis indicator would need to be calculated from the raw stock data using volume-weighted price averaging.\n\n{indicator_descriptions.get('vwma', 'No description available.')}"
         else:
             return f"Error: Indicator {indicator} not implemented yet."

-        # Parse CSV data and extract values for the date range
         lines = data.strip().split('\n')
         if len(lines) < 2:
             return f"Error: No data returned for {indicator}"

-        # Parse header and data
         header = [col.strip() for col in lines[0].split(',')]

         try:
             date_col_idx = header.index('time')
         except ValueError:
             return f"Error: 'time' column not found in data for {indicator}. Available columns: {header}"

-        # Map internal indicator names to expected CSV column names from Alpha Vantage
         col_name_map = {
             "macd": "MACD", "macds": "MACD_Signal", "macdh": "MACD_Hist",
             "boll": "Real Middle Band", "boll_ub": "Real Upper Band", "boll_lb": "Real Lower Band",
@@ -172,7 +152,6 @@ def get_indicator(
         target_col_name = col_name_map.get(indicator)
         if not target_col_name:
-            # Default to the second column if no specific mapping exists
             value_col_idx = 1
         else:
             try:
@@ -188,17 +167,14 @@ def get_indicator(
             if len(values) > value_col_idx:
                 try:
                     date_str = values[date_col_idx].strip()
-                    # Parse the date
                     date_dt = datetime.strptime(date_str, "%Y-%m-%d")

-                    # Check if date is in our range
                     if before <= date_dt <= curr_date_dt:
                         value = values[value_col_idx].strip()
                         result_data.append((date_dt, value))
                 except (ValueError, IndexError):
                     continue

-        # Sort by date and format output
         result_data.sort(key=lambda x: x[0])

         ind_string = ""
@@ -218,5 +194,5 @@ def get_indicator(
         return result_str

     except Exception as e:
-        print(f"Error getting Alpha Vantage indicator data for {indicator}: {e}")
+        logger.error("Error getting Alpha Vantage indicator data for %s: %s", indicator, e)
         return f"Error retrieving {indicator} data: {str(e)}"

Modified file (filename not shown in this view)

@@ -1,8 +1,11 @@
 import json
+import logging
 from datetime import datetime, timedelta
 from typing import List, Dict, Any

 from .alpha_vantage_common import _make_api_request, format_datetime_for_api

+logger = logging.getLogger(__name__)
+

 def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
     params = {
         "tickers": ticker,
@@ -40,19 +43,19 @@ def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]:
     try:
         response = json.loads(response)
     except json.JSONDecodeError:
-        print(f"DEBUG: Alpha Vantage JSON decode failed")
+        logger.debug("Alpha Vantage JSON decode failed")
         return []

     if not isinstance(response, dict):
-        print(f"DEBUG: Alpha Vantage response not a dict: {type(response)}")
+        logger.debug("Alpha Vantage response not a dict: %s", type(response))
         return []

     if "Information" in response:
-        print(f"DEBUG: Alpha Vantage info message: {response.get('Information')}")
+        logger.debug("Alpha Vantage info message: %s", response.get("Information"))

     feed = response.get("feed", [])
     if not feed:
-        print(f"DEBUG: Alpha Vantage feed empty. Keys in response: {list(response.keys())}")
+        logger.debug("Alpha Vantage feed empty. Keys in response: %s", list(response.keys()))

     articles = []
     for item in feed:

Modified file (filename not shown in this view)

@@ -1,9 +1,12 @@
+import logging
 import os
 import time
 import requests
 from datetime import datetime, timedelta
 from typing import List, Dict, Any

+logger = logging.getLogger(__name__)
+
 BRAVE_SEARCH_URL = "https://api.search.brave.com/res/v1/news/search"
 DEFAULT_TIMEOUT = 30
 MAX_RETRIES = 3
@@ -24,23 +27,23 @@ def _make_request_with_retry(url: str, headers: Dict, params: Dict, max_retries:
            response = requests.get(url, headers=headers, params=params, timeout=DEFAULT_TIMEOUT)
            if response.status_code == 429:
                retry_after = int(response.headers.get("Retry-After", RETRY_BACKOFF * (attempt + 1)))
-                print(f"DEBUG: Brave rate limited, waiting {retry_after}s before retry {attempt + 1}/{max_retries}")
+                logger.debug("Brave rate limited, waiting %ds before retry %d/%d", retry_after, attempt + 1, max_retries)
                time.sleep(retry_after)
                continue
            response.raise_for_status()
            return response
        except requests.exceptions.Timeout as e:
            last_exception = e
-            print(f"DEBUG: Brave request timeout, retry {attempt + 1}/{max_retries}")
+            logger.debug("Brave request timeout, retry %d/%d", attempt + 1, max_retries)
            time.sleep(RETRY_BACKOFF * (attempt + 1))
        except requests.exceptions.ConnectionError as e:
            last_exception = e
-            print(f"DEBUG: Brave connection error, retry {attempt + 1}/{max_retries}")
+            logger.debug("Brave connection error, retry %d/%d", attempt + 1, max_retries)
            time.sleep(RETRY_BACKOFF * (attempt + 1))
        except requests.exceptions.HTTPError as e:
            if e.response is not None and e.response.status_code >= 500:
                last_exception = e
-                print(f"DEBUG: Brave server error {e.response.status_code}, retry {attempt + 1}/{max_retries}")
+                logger.debug("Brave server error %d, retry %d/%d", e.response.status_code, attempt + 1, max_retries)
                time.sleep(RETRY_BACKOFF * (attempt + 1))
            else:
                raise
@@ -51,7 +54,7 @@ def get_bulk_news_brave(lookback_hours: int) -> List[Dict[str, Any]]:
    try:
        api_key = get_api_key()
    except ValueError as e:
-        print(f"DEBUG: Brave API key not configured: {e}")
+        logger.debug("Brave API key not configured: %s", e)
        return []

    headers = {
@@ -109,19 +112,19 @@ def get_bulk_news_brave(lookback_hours: int) -> List[Dict[str, Any]]:
                all_articles.append(article)

        except requests.exceptions.HTTPError as e:
-            print(f"DEBUG: Brave search HTTP error for '{query}': {e}")
+            logger.debug("Brave search HTTP error for '%s': %s", query, e)
            continue
        except requests.exceptions.Timeout as e:
-            print(f"DEBUG: Brave search timeout for '{query}': {e}")
+            logger.debug("Brave search timeout for '%s': %s", query, e)
            continue
        except requests.exceptions.RequestException as e:
-            print(f"DEBUG: Brave search request failed for '{query}': {e}")
+            logger.debug("Brave search request failed for '%s': %s", query, e)
            continue
        except Exception as e:
-            print(f"DEBUG: Brave search failed for query '{query}': {e}")
+            logger.debug("Brave search failed for query '%s': %s", query, e)
            continue

-    print(f"DEBUG: Brave returned {len(all_articles)} articles")
+    logger.debug("Brave returned %d articles", len(all_articles))
    return all_articles

Modified file (filename not shown in this view)

@@ -1,3 +1,4 @@
+import logging
 import json
 import requests
 from bs4 import BeautifulSoup
@@ -12,9 +13,10 @@ from tenacity import (
     retry_if_result,
 )

+logger = logging.getLogger(__name__)
+

 def is_rate_limited(response):
-    """Check if the response indicates rate limiting (status code 429)"""
     return response.status_code == 429
@@ -24,20 +26,12 @@ def is_rate_limited(response):
     stop=stop_after_attempt(5),
 )
 def make_request(url, headers):
-    """Make a request with retry logic for rate limiting"""
-    # Random delay before each request to avoid detection
     time.sleep(random.uniform(2, 6))
     response = requests.get(url, headers=headers)
     return response


 def getNewsData(query, start_date, end_date):
-    """
-    Scrape Google News search results for a given query and date range.
-    query: str - search query
-    start_date: str - start date in the format yyyy-mm-dd or mm/dd/yyyy
-    end_date: str - end date in the format yyyy-mm-dd or mm/dd/yyyy
-    """
     if "-" in start_date:
         start_date = datetime.strptime(start_date, "%Y-%m-%d")
         start_date = start_date.strftime("%m/%d/%Y")
@@ -69,7 +63,7 @@ def getNewsData(query, start_date, end_date):
            results_on_page = soup.select("div.SoaBEf")

            if not results_on_page:
-                break  # No more results found
+                break

            for el in results_on_page:
                try:
@@ -88,13 +82,9 @@ def getNewsData(query, start_date, end_date):
                        }
                    )
                except Exception as e:
-                    print(f"Error processing result: {e}")
-                    # If one of the fields is not found, skip this result
+                    logger.debug("Error processing result: %s", e)
                    continue

-            # Update the progress bar with the current count of results scraped
-
-            # Check for the "Next" link (pagination)
            next_link = soup.find("a", id="pnnext")
            if not next_link:
                break
@@ -102,7 +92,7 @@ def getNewsData(query, start_date, end_date):
            page += 1

        except Exception as e:
-            print(f"Failed after multiple retries: {e}")
+            logger.debug("Failed after multiple retries: %s", e)
            break

    return news_results

Modified file (filename not shown in this view)

@@ -1,3 +1,4 @@
+import logging
 from typing import Annotated, List, Dict, Any, Optional
 from datetime import datetime, timedelta
 import threading
@@ -25,6 +26,8 @@ from .config import get_config
 from tradingagents.agents.discovery import NewsArticle

+logger = logging.getLogger(__name__)
+
 TOOLS_CATEGORIES = {
     "core_stock_apis": {
         "description": "OHLCV stock price data",
@@ -206,17 +209,17 @@ def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]:
        vendor_func = VENDOR_METHODS["get_bulk_news"][vendor]
        try:
-            print(f"DEBUG: Attempting bulk news from vendor '{vendor}'...")
+            logger.debug("Attempting bulk news from vendor '%s'...", vendor)
            result = vendor_func(lookback_hours)
            if result:
-                print(f"SUCCESS: Got {len(result)} articles from vendor '{vendor}'")
+                logger.info("Got %d articles from vendor '%s'", len(result), vendor)
                return result
-            print(f"DEBUG: Vendor '{vendor}' returned empty results, trying next...")
+            logger.debug("Vendor '%s' returned empty results, trying next...", vendor)
        except AlphaVantageRateLimitError as e:
-            print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded: {e}")
+            logger.warning("Alpha Vantage rate limit exceeded: %s", e)
            continue
        except Exception as e:
-            print(f"FAILED: Vendor '{vendor}' failed: {e}")
+            logger.error("Vendor '%s' failed: %s", vendor, e)
            continue

    return []
@@ -225,7 +228,7 @@ def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]:
 def get_bulk_news(lookback_period: str = "24h") -> List[NewsArticle]:
    cached = _get_cached_bulk_news(lookback_period)
    if cached is not None:
-        print(f"DEBUG: Returning cached bulk news for period '{lookback_period}'")
+        logger.debug("Returning cached bulk news for period '%s'", lookback_period)
        return cached

    raw_articles = _fetch_bulk_news_from_vendor(lookback_period)
@@ -271,7 +274,7 @@ def route_to_vendor(method: str, *args, **kwargs):
    primary_str = " -> ".join(primary_vendors)
    fallback_str = " -> ".join(fallback_vendors)
-    print(f"DEBUG: {method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]")
+    logger.debug("%s - Primary: [%s] | Full fallback order: [%s]", method, primary_str, fallback_str)

    results = []
    vendor_attempt_count = 0
@@ -281,7 +284,7 @@ def route_to_vendor(method: str, *args, **kwargs):
    for vendor in fallback_vendors:
        if vendor not in VENDOR_METHODS[method]:
            if vendor in primary_vendors:
-                print(f"INFO: Vendor '{vendor}' not supported for method '{method}', falling back to next vendor")
+                logger.info("Vendor '%s' not supported for method '%s', falling back to next vendor", vendor, method)
            continue

        vendor_impl = VENDOR_METHODS[method][vendor]
@@ -292,48 +295,48 @@ def route_to_vendor(method: str, *args, **kwargs):
        any_primary_vendor_attempted = True

        vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK"
-        print(f"DEBUG: Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})")
+        logger.debug("Attempting %s vendor '%s' for %s (attempt #%d)", vendor_type, vendor, method, vendor_attempt_count)

        if isinstance(vendor_impl, list):
            vendor_methods = [(impl, vendor) for impl in vendor_impl]
-            print(f"DEBUG: Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions")
+            logger.debug("Vendor '%s' has multiple implementations: %d functions", vendor, len(vendor_methods))
        else:
            vendor_methods = [(vendor_impl, vendor)]

        vendor_results = []
        for impl_func, vendor_name in vendor_methods:
            try:
-                print(f"DEBUG: Calling {impl_func.__name__} from vendor '{vendor_name}'...")
+                logger.debug("Calling %s from vendor '%s'...", impl_func.__name__, vendor_name)
                result = impl_func(*args, **kwargs)
                vendor_results.append(result)
-                print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully")
+                logger.info("%s from vendor '%s' completed successfully", impl_func.__name__, vendor_name)
            except AlphaVantageRateLimitError as e:
                if vendor == "alpha_vantage":
-                    print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor")
-                    print(f"DEBUG: Rate limit details: {e}")
+                    logger.warning("Alpha Vantage rate limit exceeded, falling back to next available vendor")
+                    logger.debug("Rate limit details: %s", e)
                continue
            except Exception as e:
-                print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}")
+                logger.error("%s from vendor '%s' failed: %s", impl_func.__name__, vendor_name, e)
                continue

        if vendor_results:
            results.extend(vendor_results)
            successful_vendor = vendor
            result_summary = f"Got {len(vendor_results)} result(s)"
-            print(f"SUCCESS: Vendor '{vendor}' succeeded - {result_summary}")
+            logger.info("Vendor '%s' succeeded - %s", vendor, result_summary)
            if len(primary_vendors) == 1:
-                print(f"DEBUG: Stopping after successful vendor '{vendor}' (single-vendor config)")
+                logger.debug("Stopping after successful vendor '%s' (single-vendor config)", vendor)
                break
        else:
-            print(f"FAILED: Vendor '{vendor}' produced no results")
+            logger.error("Vendor '%s' produced no results", vendor)

    if not results:
-        print(f"FAILURE: All {vendor_attempt_count} vendor attempts failed for method '{method}'")
+        logger.error("All %d vendor attempts failed for method '%s'", vendor_attempt_count, method)
        raise RuntimeError(f"All vendor implementations failed for method '{method}'")
    else:
-        print(f"FINAL: Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)")
+        logger.info("Method '%s' completed with %d result(s) from %d vendor attempt(s)", method, len(results), vendor_attempt_count)

    if len(results) == 1:
        return results[0]

View File

@ -1,3 +1,4 @@
import logging
from typing import Annotated from typing import Annotated
import pandas as pd import pandas as pd
import os import os
@ -8,17 +9,17 @@ import json
from .reddit_utils import fetch_top_from_category from .reddit_utils import fetch_top_from_category
from tqdm import tqdm from tqdm import tqdm
logger = logging.getLogger(__name__)
def get_YFin_data_window( def get_YFin_data_window(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
curr_date: Annotated[str, "Start date in yyyy-mm-dd format"], curr_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"], look_back_days: Annotated[int, "how many days to look back"],
) -> str: ) -> str:
# calculate past days
date_obj = datetime.strptime(curr_date, "%Y-%m-%d") date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=look_back_days) before = date_obj - relativedelta(days=look_back_days)
start_date = before.strftime("%Y-%m-%d") start_date = before.strftime("%Y-%m-%d")
# read in data
data = pd.read_csv( data = pd.read_csv(
os.path.join( os.path.join(
DATA_DIR, DATA_DIR,
@ -26,18 +27,14 @@ def get_YFin_data_window(
) )
) )
# Extract just the date part for comparison
data["DateOnly"] = data["Date"].str[:10] data["DateOnly"] = data["Date"].str[:10]
# Filter data between the start and end dates (inclusive)
filtered_data = data[ filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date) (data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date)
] ]
# Drop the temporary column we created
filtered_data = filtered_data.drop("DateOnly", axis=1) filtered_data = filtered_data.drop("DateOnly", axis=1)
# Set pandas display options to show the full DataFrame
with pd.option_context( with pd.option_context(
"display.max_rows", None, "display.max_columns", None, "display.width", None "display.max_rows", None, "display.max_columns", None, "display.width", None
): ):
@ -53,7 +50,6 @@ def get_YFin_data(
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str: ) -> str:
# read in data
data = pd.read_csv( data = pd.read_csv(
os.path.join( os.path.join(
DATA_DIR, DATA_DIR,
@ -66,18 +62,14 @@ def get_YFin_data(
f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25" f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
) )
# Extract just the date part for comparison
data["DateOnly"] = data["Date"].str[:10] data["DateOnly"] = data["Date"].str[:10]
# Filter data between the start and end dates (inclusive)
filtered_data = data[ filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date) (data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date)
] ]
# Drop the temporary column we created
filtered_data = filtered_data.drop("DateOnly", axis=1) filtered_data = filtered_data.drop("DateOnly", axis=1)
# remove the index from the dataframe
filtered_data = filtered_data.reset_index(drop=True) filtered_data = filtered_data.reset_index(drop=True)
return filtered_data return filtered_data
@ -87,17 +79,6 @@ def get_finnhub_news(
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
): ):
"""
Retrieve news about a company within a time frame
Args
query (str): Search query or ticker symbol
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns
str: dataframe containing the news of the company in the time frame
"""
result = get_data_in_range(query, start_date, end_date, "news_data", DATA_DIR) result = get_data_in_range(query, start_date, end_date, "news_data", DATA_DIR)
@ -121,17 +102,9 @@ def get_finnhub_company_insider_sentiment(
ticker: Annotated[str, "ticker symbol for the company"], ticker: Annotated[str, "ticker symbol for the company"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
): ):
"""
Retrieve insider sentiment about a company (retrieved from public SEC information) for the past 15 days
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading on, yyyy-mm-dd
Returns:
str: a report of the sentiment in the past 15 days starting at curr_date
"""
date_obj = datetime.strptime(curr_date, "%Y-%m-%d") date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=15) # Default 15 days lookback before = date_obj - relativedelta(days=15)
before = before.strftime("%Y-%m-%d") before = before.strftime("%Y-%m-%d")
data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR) data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR)
@ -158,17 +131,9 @@ def get_finnhub_company_insider_transactions(
ticker: Annotated[str, "ticker symbol"], ticker: Annotated[str, "ticker symbol"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
): ):
"""
Retrieve insider transcaction information about a company (retrieved from public SEC information) for the past 15 days
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's insider transaction/trading informtaion in the past 15 days
"""
date_obj = datetime.strptime(curr_date, "%Y-%m-%d") date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=15) # Default 15 days lookback before = date_obj - relativedelta(days=15)
before = before.strftime("%Y-%m-%d") before = before.strftime("%Y-%m-%d")
data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR) data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR)
@ -192,15 +157,6 @@ def get_finnhub_company_insider_transactions(
) )
def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None): def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None):
"""
Gets finnhub data saved and processed on disk.
Args:
start_date (str): Start date in YYYY-MM-DD format.
end_date (str): End date in YYYY-MM-DD format.
data_type (str): Type of data from finnhub to fetch. Can be insider_trans, SEC_filings, news_data, insider_senti, or fin_as_reported.
data_dir (str): Directory where the data is saved.
period (str): Default to none, if there is a period specified, should be annual or quarterly.
"""
if period: if period:
data_path = os.path.join( data_path = os.path.join(
@ -217,7 +173,6 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
data = open(data_path, "r") data = open(data_path, "r")
data = json.load(data) data = json.load(data)
# filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD)
filtered_data = {} filtered_data = {}
for key, value in data.items(): for key, value in data.items():
if start_date <= key <= end_date and len(value) > 0: if start_date <= key <= end_date and len(value) > 0:
@@ -243,25 +198,19 @@ def get_simfin_balance_sheet(
    )
    df = pd.read_csv(data_path, sep=";")
-    # Convert date strings to datetime objects and remove any time components
    df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
    df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
-    # Convert the current date to datetime and normalize
    curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
-    # Filter the DataFrame for the given ticker and for reports that were published on or before the current date
    filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
-    # Check if there are any available reports; if not, return a notification
    if filtered_df.empty:
-        print("No balance sheet available before the given current date.")
+        logger.info("No balance sheet available before the given current date.")
        return ""
-    # Get the most recent balance sheet by selecting the row with the latest Publish Date
    latest_balance_sheet = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
-    # drop the SimFinID column
    latest_balance_sheet = latest_balance_sheet.drop("SimFinId")
    return (
@@ -290,25 +239,19 @@ def get_simfin_cashflow(
    )
    df = pd.read_csv(data_path, sep=";")
-    # Convert date strings to datetime objects and remove any time components
    df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
    df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
-    # Convert the current date to datetime and normalize
    curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
-    # Filter the DataFrame for the given ticker and for reports that were published on or before the current date
    filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
-    # Check if there are any available reports; if not, return a notification
    if filtered_df.empty:
-        print("No cash flow statement available before the given current date.")
+        logger.info("No cash flow statement available before the given current date.")
        return ""
-    # Get the most recent cash flow statement by selecting the row with the latest Publish Date
    latest_cash_flow = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
-    # drop the SimFinID column
    latest_cash_flow = latest_cash_flow.drop("SimFinId")
    return (
@@ -337,25 +280,19 @@ def get_simfin_income_statements(
    )
    df = pd.read_csv(data_path, sep=";")
-    # Convert date strings to datetime objects and remove any time components
    df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
    df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
-    # Convert the current date to datetime and normalize
    curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
-    # Filter the DataFrame for the given ticker and for reports that were published on or before the current date
    filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
-    # Check if there are any available reports; if not, return a notification
    if filtered_df.empty:
-        print("No income statement available before the given current date.")
+        logger.info("No income statement available before the given current date.")
        return ""
-    # Get the most recent income statement by selecting the row with the latest Publish Date
    latest_income = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
-    # drop the SimFinID column
    latest_income = latest_income.drop("SimFinId")
    return (
@@ -370,22 +307,12 @@ def get_reddit_global_news(
    look_back_days: Annotated[int, "Number of days to look back"] = 7,
    limit: Annotated[int, "Maximum number of articles to return"] = 5,
) -> str:
-    """
-    Retrieve the latest top reddit news
-    Args:
-        curr_date: Current date in yyyy-mm-dd format
-        look_back_days: Number of days to look back (default 7)
-        limit: Maximum number of articles to return (default 5)
-    Returns:
-        str: A formatted string containing the latest news articles posts on reddit
-    """
    curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
    before = curr_date_dt - relativedelta(days=look_back_days)
    before = before.strftime("%Y-%m-%d")
    posts = []
-    # iterate from before to curr_date
    curr_iter_date = datetime.strptime(before, "%Y-%m-%d")
    total_iterations = (curr_date_dt - curr_iter_date).days + 1
@@ -423,21 +350,11 @@ def get_reddit_company_news(
    start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
    end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
-    """
-    Retrieve the latest top reddit news
-    Args:
-        query: Search query or ticker symbol
-        start_date: Start date in yyyy-mm-dd format
-        end_date: End date in yyyy-mm-dd format
-    Returns:
-        str: A formatted string containing news articles posts on reddit
-    """
    start_date_dt = datetime.strptime(start_date, "%Y-%m-%d")
    end_date_dt = datetime.strptime(end_date, "%Y-%m-%d")
    posts = []
-    # iterate from start_date to end_date
    curr_date = start_date_dt
    total_iterations = (end_date_dt - curr_date).days + 1
@@ -451,7 +368,7 @@ def get_reddit_company_news(
        fetch_result = fetch_top_from_category(
            "company_news",
            curr_date_str,
-            10,  # max limit per day
+            10,
            query,
            data_path=os.path.join(DATA_DIR, "reddit_data"),
        )

View File

@@ -1,8 +1,11 @@
+import logging
import os
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any

+logger = logging.getLogger(__name__)

try:
    from tavily import TavilyClient
    TAVILY_AVAILABLE = True
@@ -37,17 +40,17 @@ def _search_with_retry(client, query: str, search_depth: str, topic: str, time_r
            error_str = str(e).lower()
            if "rate" in error_str or "limit" in error_str or "429" in error_str:
                wait_time = RETRY_BACKOFF * (attempt + 1) * 2
-                print(f"DEBUG: Tavily rate limited, waiting {wait_time}s before retry {attempt + 1}/{max_retries}")
+                logger.debug("Tavily rate limited, waiting %ds before retry %d/%d", wait_time, attempt + 1, max_retries)
                time.sleep(wait_time)
                last_exception = e
            elif "timeout" in error_str or "timed out" in error_str:
                wait_time = RETRY_BACKOFF * (attempt + 1)
-                print(f"DEBUG: Tavily timeout, waiting {wait_time}s before retry {attempt + 1}/{max_retries}")
+                logger.debug("Tavily timeout, waiting %ds before retry %d/%d", wait_time, attempt + 1, max_retries)
                time.sleep(wait_time)
                last_exception = e
            elif "connection" in error_str or "network" in error_str:
                wait_time = RETRY_BACKOFF * (attempt + 1)
-                print(f"DEBUG: Tavily connection error, waiting {wait_time}s before retry {attempt + 1}/{max_retries}")
+                logger.debug("Tavily connection error, waiting %ds before retry %d/%d", wait_time, attempt + 1, max_retries)
                time.sleep(wait_time)
                last_exception = e
            else:
@@ -57,13 +60,13 @@ def _search_with_retry(client, query: str, search_depth: str, topic: str, time_r
def get_bulk_news_tavily(lookback_hours: int) -> List[Dict[str, Any]]:
    if not TAVILY_AVAILABLE:
-        print("DEBUG: Tavily library not installed")
+        logger.debug("Tavily library not installed")
        return []
    try:
        client = TavilyClient(api_key=get_api_key())
    except ValueError as e:
-        print(f"DEBUG: Tavily API key not configured: {e}")
+        logger.debug("Tavily API key not configured: %s", e)
        return []
    queries = [
@@ -121,8 +124,8 @@ def get_bulk_news_tavily(lookback_hours: int) -> List[Dict[str, Any]]:
            all_articles.append(article)
        except Exception as e:
-            print(f"DEBUG: Tavily search failed for query '{query}': {e}")
+            logger.debug("Tavily search failed for query '%s': %s", query, e)
            continue
-    print(f"DEBUG: Tavily returned {len(all_articles)} articles")
+    logger.debug("Tavily returned %d articles", len(all_articles))
    return all_articles
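A note on the pattern above: the migrated calls pass values as arguments ("%s"/"%d" placeholders) instead of pre-rendered f-strings. The logging module defers interpolation until a handler actually accepts the record, so a suppressed DEBUG message costs almost nothing. A standalone sketch of the difference (the logger name and the Expensive class are illustrative, not part of the diff):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tradingagents.example")  # hypothetical name

class Expensive:
    def __str__(self) -> str:
        # Stand-in for costly rendering work.
        return "rendered"

# %-style: the record is dropped at DEBUG before formatting, so
# Expensive.__str__ is never called.
logger.debug("value: %s", Expensive())

# f-string: interpolation happens at the call site, paying the
# formatting cost even though the message is discarded.
logger.debug(f"value: {Expensive()}")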

View File

@@ -1,15 +1,18 @@
+import logging
import os
import json
import pandas as pd
from datetime import date, timedelta, datetime
from typing import Annotated

+logger = logging.getLogger(__name__)

SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]

def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
    if save_path:
        data.to_csv(save_path)
-        print(f"{tag} saved to {save_path}")
+        logger.info("%s saved to %s", tag, save_path)

def get_current_date():

View File

@@ -1,3 +1,4 @@
+import logging
from typing import Annotated
from datetime import datetime
from dateutil.relativedelta import relativedelta
@@ -5,6 +6,8 @@ import yfinance as yf
import os
from .stockstats_utils import StockstatsUtils

+logger = logging.getLogger(__name__)

def get_YFin_data_online(
    symbol: Annotated[str, "ticker symbol of the company"],
    start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
@@ -14,32 +17,25 @@ def get_YFin_data_online(
    datetime.strptime(start_date, "%Y-%m-%d")
    datetime.strptime(end_date, "%Y-%m-%d")
-    # Create ticker object
    ticker = yf.Ticker(symbol.upper())
-    # Fetch historical data for the specified date range
    data = ticker.history(start=start_date, end=end_date)
-    # Check if data is empty
    if data.empty:
        return (
            f"No data found for symbol '{symbol}' between {start_date} and {end_date}"
        )
-    # Remove timezone info from index for cleaner output
    if data.index.tz is not None:
        data.index = data.index.tz_localize(None)
-    # Round numerical values to 2 decimal places for cleaner display
    numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
    for col in numeric_columns:
        if col in data.columns:
            data[col] = data[col].round(2)
-    # Convert DataFrame to CSV string
    csv_string = data.to_csv()
-    # Add header information
    header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
    header += f"# Total records: {len(data)}\n"
    header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
@@ -56,7 +52,6 @@ def get_stock_stats_indicators_window(
) -> str:
    best_ind_params = {
-        # Moving Averages
        "close_50_sma": (
            "50 SMA: A medium-term trend indicator. "
            "Usage: Identify trend direction and serve as dynamic support/resistance. "
@@ -72,7 +67,6 @@ def get_stock_stats_indicators_window(
            "Usage: Capture quick shifts in momentum and potential entry points. "
            "Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals."
        ),
-        # MACD Related
        "macd": (
            "MACD: Computes momentum via differences of EMAs. "
            "Usage: Look for crossovers and divergence as signals of trend changes. "
@@ -88,13 +82,11 @@ def get_stock_stats_indicators_window(
            "Usage: Visualize momentum strength and spot divergence early. "
            "Tips: Can be volatile; complement with additional filters in fast-moving markets."
        ),
-        # Momentum Indicators
        "rsi": (
            "RSI: Measures momentum to flag overbought/oversold conditions. "
            "Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. "
            "Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis."
        ),
-        # Volatility Indicators
        "boll": (
            "Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. "
            "Usage: Acts as a dynamic benchmark for price movement. "
@@ -115,7 +107,6 @@ def get_stock_stats_indicators_window(
            "Usage: Set stop-loss levels and adjust position sizes based on current market volatility. "
            "Tips: It's a reactive measure, so use it as part of a broader risk management strategy."
        ),
-        # Volume-Based Indicators
        "vwma": (
            "VWMA: A moving average weighted by volume. "
            "Usage: Confirm trends by integrating price action with volume data. "
@@ -137,18 +128,15 @@ def get_stock_stats_indicators_window(
    curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
    before = curr_date_dt - relativedelta(days=look_back_days)
-    # Optimized: Get stock data once and calculate indicators for all dates
    try:
        indicator_data = _get_stock_stats_bulk(symbol, indicator, curr_date)
-        # Generate the date range we need
        current_dt = curr_date_dt
        date_values = []
        while current_dt >= before:
            date_str = current_dt.strftime('%Y-%m-%d')
-            # Look up the indicator value for this date
            if date_str in indicator_data:
                indicator_value = indicator_data[date_str]
            else:
@@ -157,14 +145,12 @@ def get_stock_stats_indicators_window(
            date_values.append((date_str, indicator_value))
            current_dt = current_dt - relativedelta(days=1)
-        # Build the result string
        ind_string = ""
        for date_str, value in date_values:
            ind_string += f"{date_str}: {value}\n"
    except Exception as e:
-        print(f"Error getting bulk stockstats data: {e}")
+        logger.error("Error getting bulk stockstats data: %s", e)
-        # Fallback to original implementation if bulk method fails
        ind_string = ""
        curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
        while curr_date_dt >= before:
@@ -189,11 +175,6 @@ def _get_stock_stats_bulk(
    indicator: Annotated[str, "technical indicator to calculate"],
    curr_date: Annotated[str, "current date for reference"]
) -> dict:
-    """
-    Optimized bulk calculation of stock stats indicators.
-    Fetches data once and calculates indicator for all available dates.
-    Returns dict mapping date strings to indicator values.
-    """
    from .config import get_config
    import pandas as pd
    from stockstats import wrap
@@ -203,7 +184,6 @@ def _get_stock_stats_bulk(
    online = config["data_vendors"]["technical_indicators"] != "local"
    if not online:
-        # Local data path
        try:
            data = pd.read_csv(
                os.path.join(
@@ -215,7 +195,6 @@ def _get_stock_stats_bulk(
        except FileNotFoundError:
            raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
    else:
-        # Online data fetching with caching
        today_date = pd.Timestamp.today()
        curr_date_dt = pd.to_datetime(curr_date)
@@ -249,16 +228,13 @@ def _get_stock_stats_bulk(
    df = wrap(data)
    df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
-    # Calculate the indicator for all rows at once
-    df[indicator]  # This triggers stockstats to calculate the indicator
+    df[indicator]
-    # Create a dictionary mapping date strings to indicator values
    result_dict = {}
    for _, row in df.iterrows():
        date_str = row["Date"]
        indicator_value = row[indicator]
-        # Handle NaN/None values
        if pd.isna(indicator_value):
            result_dict[date_str] = "N/A"
        else:
@@ -285,8 +261,9 @@ def get_stockstats_indicator(
            curr_date,
        )
    except Exception as e:
-        print(
-            f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
+        logger.error(
+            "Error getting stockstats indicator data for indicator %s on %s: %s",
+            indicator, curr_date, e
        )
        return ""
@@ -298,7 +275,6 @@ def get_balance_sheet(
    freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
    curr_date: Annotated[str, "current date (not used for yfinance)"] = None
):
-    """Get balance sheet data from yfinance."""
    try:
        ticker_obj = yf.Ticker(ticker.upper())
@@ -310,10 +286,8 @@ def get_balance_sheet(
        if data.empty:
            return f"No balance sheet data found for symbol '{ticker}'"
-        # Convert to CSV string for consistency with other functions
        csv_string = data.to_csv()
-        # Add header information
        header = f"# Balance Sheet data for {ticker.upper()} ({freq})\n"
        header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
@@ -328,7 +302,6 @@ def get_cashflow(
    freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
    curr_date: Annotated[str, "current date (not used for yfinance)"] = None
):
-    """Get cash flow data from yfinance."""
    try:
        ticker_obj = yf.Ticker(ticker.upper())
@@ -340,10 +313,8 @@ def get_cashflow(
        if data.empty:
            return f"No cash flow data found for symbol '{ticker}'"
-        # Convert to CSV string for consistency with other functions
        csv_string = data.to_csv()
-        # Add header information
        header = f"# Cash Flow data for {ticker.upper()} ({freq})\n"
        header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
@@ -358,7 +329,6 @@ def get_income_statement(
    freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
    curr_date: Annotated[str, "current date (not used for yfinance)"] = None
):
-    """Get income statement data from yfinance."""
    try:
        ticker_obj = yf.Ticker(ticker.upper())
@@ -370,10 +340,8 @@ def get_income_statement(
        if data.empty:
            return f"No income statement data found for symbol '{ticker}'"
-        # Convert to CSV string for consistency with other functions
        csv_string = data.to_csv()
-        # Add header information
        header = f"# Income Statement data for {ticker.upper()} ({freq})\n"
        header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
@@ -386,7 +354,6 @@ def get_income_statement(
def get_insider_transactions(
    ticker: Annotated[str, "ticker symbol of the company"]
):
-    """Get insider transactions data from yfinance."""
    try:
        ticker_obj = yf.Ticker(ticker.upper())
        data = ticker_obj.insider_transactions
@@ -394,10 +361,8 @@ def get_insider_transactions(
        if data is None or data.empty:
            return f"No insider transactions data found for symbol '{ticker}'"
-        # Convert to CSV string for consistency with other functions
        csv_string = data.to_csv()
-        # Add header information
        header = f"# Insider Transactions data for {ticker.upper()}\n"
        header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"

View File

@@ -1,5 +1,4 @@
-# gets data/stats
+import logging
import yfinance as yf
from typing import Annotated, Callable, Any, Optional
from pandas import DataFrame
@@ -8,9 +7,10 @@ from functools import wraps
from .utils import save_output, SavePathType, decorate_all_methods

+logger = logging.getLogger(__name__)

def init_ticker(func: Callable) -> Callable:
-    """Decorator to initialize yf.Ticker and pass it to the function."""
    @wraps(func)
    def wrapper(symbol: Annotated[str, "ticker symbol"], *args, **kwargs) -> Any:
@@ -33,19 +33,15 @@ class YFinanceUtils:
        ],
        save_path: SavePathType = None,
    ) -> DataFrame:
-        """retrieve stock price data for designated ticker symbol"""
        ticker = symbol
-        # add one day to the end_date so that the data range is inclusive
        end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
        end_date = end_date.strftime("%Y-%m-%d")
        stock_data = ticker.history(start=start_date, end=end_date)
-        # save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
        return stock_data

    def get_stock_info(
        symbol: Annotated[str, "ticker symbol"],
    ) -> dict:
-        """Fetches and returns latest stock information."""
        ticker = symbol
        stock_info = ticker.info
        return stock_info
@@ -54,7 +50,6 @@ class YFinanceUtils:
        symbol: Annotated[str, "ticker symbol"],
        save_path: Optional[str] = None,
    ) -> DataFrame:
-        """Fetches and returns company information as a DataFrame."""
        ticker = symbol
        info = ticker.info
        company_info = {
@@ -67,50 +62,43 @@ class YFinanceUtils:
        company_info_df = DataFrame([company_info])
        if save_path:
            company_info_df.to_csv(save_path)
-            print(f"Company info for {ticker.ticker} saved to {save_path}")
+            logger.info("Company info for %s saved to %s", ticker.ticker, save_path)
        return company_info_df

    def get_stock_dividends(
        symbol: Annotated[str, "ticker symbol"],
        save_path: Optional[str] = None,
    ) -> DataFrame:
-        """Fetches and returns the latest dividends data as a DataFrame."""
        ticker = symbol
        dividends = ticker.dividends
        if save_path:
            dividends.to_csv(save_path)
-            print(f"Dividends for {ticker.ticker} saved to {save_path}")
+            logger.info("Dividends for %s saved to %s", ticker.ticker, save_path)
        return dividends

    def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
-        """Fetches and returns the latest income statement of the company as a DataFrame."""
        ticker = symbol
        income_stmt = ticker.financials
        return income_stmt

    def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
-        """Fetches and returns the latest balance sheet of the company as a DataFrame."""
        ticker = symbol
        balance_sheet = ticker.balance_sheet
        return balance_sheet

    def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
-        """Fetches and returns the latest cash flow statement of the company as a DataFrame."""
        ticker = symbol
        cash_flow = ticker.cashflow
        return cash_flow

    def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple:
-        """Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
        ticker = symbol
        recommendations = ticker.recommendations
        if recommendations.empty:
-            return None, 0  # No recommendations available
+            return None, 0
-        # Assuming 'period' column exists and needs to be excluded
-        row_0 = recommendations.iloc[0, 1:]  # Exclude 'period' column if necessary
+        row_0 = recommendations.iloc[0, 1:]
-        # Find the maximum voting result
        max_votes = row_0.max()
        majority_voting_result = row_0[row_0 == max_votes].index.tolist()
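The recommendation tally above reads the most recent row, drops the leading 'period' column, and returns the label(s) holding the highest vote count. A self-contained sketch of that selection; the column names mimic yfinance's recommendations summary but are assumptions here:

import pandas as pd

# One row per period; the first column labels the period, the rest are counts.
recommendations = pd.DataFrame(
    [{"period": "0m", "strongBuy": 7, "buy": 21, "hold": 6, "sell": 1, "strongSell": 0}]
)

row_0 = recommendations.iloc[0, 1:]   # exclude the 'period' column
max_votes = row_0.max()               # 21
majority = row_0[row_0 == max_votes].index.tolist()
print(majority, int(max_votes))       # ['buy'] 21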

View File

@@ -31,4 +31,8 @@ DEFAULT_CONFIG = {
    "bulk_news_vendor_order": ["tavily", "brave", "alpha_vantage", "openai", "google"],
    "bulk_news_timeout": 30,
    "bulk_news_max_retries": 3,
+    "log_level": os.getenv("TRADINGAGENTS_LOG_LEVEL", "INFO"),
+    "log_dir": os.getenv("TRADINGAGENTS_LOG_DIR", "./logs"),
+    "log_console_enabled": os.getenv("TRADINGAGENTS_LOG_CONSOLE", "true").lower() in ("true", "1", "yes", "on"),
+    "log_file_enabled": os.getenv("TRADINGAGENTS_LOG_FILE", "true").lower() in ("true", "1", "yes", "on"),
}
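These keys snapshot the environment at the moment default_config is imported; setup_logging (in the new module below) checks os.getenv first, so a live environment variable still wins over the snapshot. A quick sketch of the truthy parsing and precedence, standalone and not part of the diff:

import os

os.environ["TRADINGAGENTS_LOG_LEVEL"] = "DEBUG"
os.environ["TRADINGAGENTS_LOG_CONSOLE"] = "0"   # only true/1/yes/on count as enabled

log_level = os.getenv("TRADINGAGENTS_LOG_LEVEL", "INFO")
console = os.getenv("TRADINGAGENTS_LOG_CONSOLE", "true").lower() in ("true", "1", "yes", "on")

print(log_level)  # DEBUG  (environment overrides the "INFO" fallback)
print(console)    # False  ("0" is not in the accepted truthy set)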

View File

@@ -1,3 +1,4 @@
+import logging
import os
import signal
import threading
@@ -54,6 +55,8 @@ from .propagation import Propagator
from .reflection import Reflector
from .signal_processing import SignalProcessor

+logger = logging.getLogger(__name__)

class DiscoveryTimeoutException(Exception):
    pass
@@ -164,7 +167,7 @@ class TradingAgentsGraph:
                if len(chunk["messages"]) == 0:
                    pass
                else:
-                    chunk["messages"][-1].pretty_print()
+                    logger.debug("Agent message: %s", chunk["messages"][-1])
                trace.append(chunk)
            final_state = trace[-1]

121
tradingagents/logging.py Normal file
View File

@@ -0,0 +1,121 @@
import logging
import logging.handlers
import os
import json
from datetime import datetime
LOG_LEVEL_DEFAULT = "INFO"
LOG_DIR_DEFAULT = "./logs"
LOG_FILE_NAME = "tradingagents.log"
LOG_MAX_BYTES = 10 * 1024 * 1024
LOG_BACKUP_COUNT = 5
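# Rotation cap: one active file plus LOG_BACKUP_COUNT backups of LOG_MAX_BYTES
# each bounds on-disk logs at roughly 60 MB before the oldest backup is dropped.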
_logging_initialized = False
class JSONFormatter(logging.Formatter):
    def format(self, record):
        log_record = {
            "timestamp": datetime.utcnow().isoformat() + "Z",
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "filename": record.filename,
            "funcName": record.funcName,
            "lineno": record.lineno,
        }
        if record.exc_info:
            log_record["exception"] = self.formatException(record.exc_info)
        return json.dumps(log_record)
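# For reference, the file handler writes one such record per line; the values
# below are illustrative, not captured from a real run:
#   {"timestamp": "2025-12-03T07:11:29.000000Z", "level": "INFO",
#    "logger": "tradingagents.dataflows.utils", "message": "data saved to ./out.csv",
#    "filename": "utils.py", "funcName": "save_output", "lineno": 14}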
def _parse_bool(value):
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        return value.lower() in ("true", "1", "yes", "on")
    return bool(value)

def _get_config_value(key, default):
    try:
        from tradingagents.default_config import DEFAULT_CONFIG
        return DEFAULT_CONFIG.get(key, default)
    except ImportError:
        return default
def setup_logging():
    global _logging_initialized

    log_level_str = os.getenv("TRADINGAGENTS_LOG_LEVEL")
    if log_level_str is None:
        log_level_str = _get_config_value("log_level", LOG_LEVEL_DEFAULT)

    log_dir = os.getenv("TRADINGAGENTS_LOG_DIR")
    if log_dir is None:
        log_dir = _get_config_value("log_dir", LOG_DIR_DEFAULT)

    console_enabled_env = os.getenv("TRADINGAGENTS_LOG_CONSOLE")
    if console_enabled_env is not None:
        console_enabled = _parse_bool(console_enabled_env)
    else:
        console_enabled = _get_config_value("log_console_enabled", True)

    file_enabled_env = os.getenv("TRADINGAGENTS_LOG_FILE")
    if file_enabled_env is not None:
        file_enabled = _parse_bool(file_enabled_env)
    else:
        file_enabled = _get_config_value("log_file_enabled", True)

    log_level = getattr(logging, log_level_str.upper(), logging.INFO)

    root_logger = logging.getLogger("tradingagents")
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    root_logger.setLevel(log_level)

    if file_enabled:
        os.makedirs(log_dir, exist_ok=True)
        log_file_path = os.path.join(log_dir, LOG_FILE_NAME)
        file_handler = logging.handlers.RotatingFileHandler(
            log_file_path,
            maxBytes=LOG_MAX_BYTES,
            backupCount=LOG_BACKUP_COUNT,
        )
        file_handler.setLevel(log_level)
        file_handler.setFormatter(JSONFormatter())
        root_logger.addHandler(file_handler)

    if console_enabled:
        from rich.logging import RichHandler
        from rich.console import Console

        console = Console(stderr=True)
        rich_handler = RichHandler(
            console=console,
            show_time=True,
            show_level=True,
            show_path=True,
            rich_tracebacks=True,
        )
        rich_handler.setLevel(log_level)
        root_logger.addHandler(rich_handler)

    _logging_initialized = True
    return root_logger

def get_logger(name):
    global _logging_initialized
    if not _logging_initialized:
        setup_logging()
    return logging.getLogger(name)
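Taken together, a consumer module only needs get_logger; the first call lazily runs setup_logging. A minimal usage sketch (the module path and messages are illustrative):

# Hypothetical consumer, e.g. tradingagents/dataflows/example.py
from tradingagents.logging import get_logger

# Logger names follow module paths, so records propagate up to the
# "tradingagents" root logger, which owns the file and console handlers.
logger = get_logger("tradingagents.dataflows.example")

logger.info("fetched %d rows for %s", 42, "AAPL")  # Rich console + JSON file under default settings
try:
    1 / 0
except ZeroDivisionError:
    logger.exception("demo failure")  # traceback lands in the JSON "exception" field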