feat: Resolve "main_portfolio" alias in portfolio routes and improve LangGraph content extraction robustness with new unit tests.
This commit is contained in:
parent
6999da0827
commit
319168c74f
|
|
@ -10,6 +10,16 @@ import datetime
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/portfolios", tags=["portfolios"])
|
router = APIRouter(prefix="/api/portfolios", tags=["portfolios"])
|
||||||
|
|
||||||
|
def _resolve_portfolio_id(portfolio_id: str, db: SupabaseClient) -> str:
|
||||||
|
"""Resolves the 'main_portfolio' alias to the first available portfolio ID."""
|
||||||
|
if portfolio_id == "main_portfolio":
|
||||||
|
portfolios = db.list_portfolios()
|
||||||
|
if portfolios:
|
||||||
|
return portfolios[0].portfolio_id
|
||||||
|
else:
|
||||||
|
raise PortfolioNotFoundError("No portfolios found to resolve 'main_portfolio' alias.")
|
||||||
|
return portfolio_id
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
async def list_portfolios(
|
async def list_portfolios(
|
||||||
user: dict = Depends(get_current_user),
|
user: dict = Depends(get_current_user),
|
||||||
|
|
@ -25,6 +35,7 @@ async def get_portfolio(
|
||||||
db: SupabaseClient = Depends(get_db_client)
|
db: SupabaseClient = Depends(get_db_client)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
portfolio_id = _resolve_portfolio_id(portfolio_id, db)
|
||||||
portfolio = db.get_portfolio(portfolio_id)
|
portfolio = db.get_portfolio(portfolio_id)
|
||||||
return portfolio.to_dict()
|
return portfolio.to_dict()
|
||||||
except PortfolioNotFoundError:
|
except PortfolioNotFoundError:
|
||||||
|
|
@ -42,6 +53,7 @@ async def get_portfolio_summary(
|
||||||
date = datetime.datetime.now().strftime("%Y-%m-%d")
|
date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
portfolio_id = _resolve_portfolio_id(portfolio_id, db)
|
||||||
# 1. Sharpe & Drawdown from latest snapshot
|
# 1. Sharpe & Drawdown from latest snapshot
|
||||||
snapshot = db.get_latest_snapshot(portfolio_id)
|
snapshot = db.get_latest_snapshot(portfolio_id)
|
||||||
sharpe = 0.0
|
sharpe = 0.0
|
||||||
|
|
@ -94,6 +106,7 @@ async def get_latest_portfolio_state(
|
||||||
db: SupabaseClient = Depends(get_db_client)
|
db: SupabaseClient = Depends(get_db_client)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
portfolio_id = _resolve_portfolio_id(portfolio_id, db)
|
||||||
portfolio = db.get_portfolio(portfolio_id)
|
portfolio = db.get_portfolio(portfolio_id)
|
||||||
snapshot = db.get_latest_snapshot(portfolio_id)
|
snapshot = db.get_latest_snapshot(portfolio_id)
|
||||||
holdings = db.list_holdings(portfolio_id)
|
holdings = db.list_holdings(portfolio_id)
|
||||||
|
|
|
||||||
|
|
@ -190,6 +190,9 @@ class LangGraphEngine:
|
||||||
def _extract_content(obj: object) -> str:
|
def _extract_content(obj: object) -> str:
|
||||||
"""Safely extract text content from a LangChain message or plain object."""
|
"""Safely extract text content from a LangChain message or plain object."""
|
||||||
content = getattr(obj, "content", None)
|
content = getattr(obj, "content", None)
|
||||||
|
# Handle cases where .content might be a method instead of a property
|
||||||
|
if content is not None and callable(content):
|
||||||
|
content = None
|
||||||
return str(content) if content is not None else str(obj)
|
return str(content) if content is not None else str(obj)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -421,10 +424,19 @@ class LangGraphEngine:
|
||||||
# If .content was empty or the repr of the whole object, try harder
|
# If .content was empty or the repr of the whole object, try harder
|
||||||
if not raw or raw.startswith("<") or raw == str(output):
|
if not raw or raw.startswith("<") or raw == str(output):
|
||||||
# Some providers wrap in .text or .message
|
# Some providers wrap in .text or .message
|
||||||
|
potential_text = getattr(output, "text", "")
|
||||||
|
if callable(potential_text):
|
||||||
|
potential_text = ""
|
||||||
|
|
||||||
raw = (
|
raw = (
|
||||||
getattr(output, "text", "")
|
potential_text
|
||||||
or (output.get("content", "") if isinstance(output, dict) else "")
|
or (output.get("content", "") if isinstance(output, dict) else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Ensure raw is a string before subscripting
|
||||||
|
if not isinstance(raw, str):
|
||||||
|
raw = str(raw)
|
||||||
|
|
||||||
if raw:
|
if raw:
|
||||||
full_response = raw[:_MAX_FULL_LEN]
|
full_response = raw[:_MAX_FULL_LEN]
|
||||||
response_snippet = self._truncate(raw)
|
response_snippet = self._truncate(raw)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
{
|
||||||
|
"name": "TradingAgents",
|
||||||
|
"lockfileVersion": 3,
|
||||||
|
"requires": true,
|
||||||
|
"packages": {}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,68 @@
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
# Add project root to sys.path
|
||||||
|
sys.path.append("/Users/Ahmet/Repo/TradingAgents")
|
||||||
|
|
||||||
|
from agent_os.backend.services.langgraph_engine import LangGraphEngine
|
||||||
|
|
||||||
|
class TestLangGraphEngineExtraction(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.engine = LangGraphEngine()
|
||||||
|
|
||||||
|
def test_extract_content_string(self):
|
||||||
|
mock_obj = MagicMock()
|
||||||
|
mock_obj.content = "hello world"
|
||||||
|
self.assertEqual(self.engine._extract_content(mock_obj), "hello world")
|
||||||
|
|
||||||
|
def test_extract_content_method(self):
|
||||||
|
mock_obj = MagicMock()
|
||||||
|
# Mocking a method
|
||||||
|
def my_content():
|
||||||
|
return "should not be called"
|
||||||
|
mock_obj.content = my_content
|
||||||
|
# Should fall back to str(mock_obj)
|
||||||
|
result = self.engine._extract_content(mock_obj)
|
||||||
|
self.assertTrue(result.startswith("<MagicMock"))
|
||||||
|
|
||||||
|
def test_map_langgraph_event_llm_end_with_text_method(self):
|
||||||
|
# Mocking output object with a text method
|
||||||
|
mock_output = MagicMock()
|
||||||
|
def my_text():
|
||||||
|
return "bad"
|
||||||
|
mock_output.text = my_text
|
||||||
|
mock_output.content = None # Ensure it triggers fallback
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"event": "on_chat_model_end",
|
||||||
|
"run_id": "test_run",
|
||||||
|
"name": "test_node",
|
||||||
|
"data": {"output": mock_output},
|
||||||
|
"metadata": {"langgraph_node": "test_node"}
|
||||||
|
}
|
||||||
|
|
||||||
|
# This used to raise TypeError
|
||||||
|
result = self.engine._map_langgraph_event("run_123", event)
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
self.assertIsInstance(result["response"], str)
|
||||||
|
# It's okay if it's empty, as long as it didn't crash
|
||||||
|
|
||||||
|
def test_map_langgraph_event_llm_end_with_text_string(self):
|
||||||
|
mock_output = MagicMock()
|
||||||
|
mock_output.text = "good text"
|
||||||
|
mock_output.content = None
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"event": "on_chat_model_end",
|
||||||
|
"run_id": "test_run",
|
||||||
|
"name": "test_node",
|
||||||
|
"data": {"output": mock_output},
|
||||||
|
"metadata": {"langgraph_node": "test_node"}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = self.engine._map_langgraph_event("run_123", event)
|
||||||
|
self.assertEqual(result["response"], "good text")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue