feat: Resolve "main_portfolio" alias in portfolio routes and improve LangGraph content extraction robustness with new unit tests.
This commit is contained in:
parent
6999da0827
commit
319168c74f
|
|
@ -10,6 +10,16 @@ import datetime
|
|||
|
||||
router = APIRouter(prefix="/api/portfolios", tags=["portfolios"])
|
||||
|
||||
def _resolve_portfolio_id(portfolio_id: str, db: SupabaseClient) -> str:
|
||||
"""Resolves the 'main_portfolio' alias to the first available portfolio ID."""
|
||||
if portfolio_id == "main_portfolio":
|
||||
portfolios = db.list_portfolios()
|
||||
if portfolios:
|
||||
return portfolios[0].portfolio_id
|
||||
else:
|
||||
raise PortfolioNotFoundError("No portfolios found to resolve 'main_portfolio' alias.")
|
||||
return portfolio_id
|
||||
|
||||
@router.get("/")
|
||||
async def list_portfolios(
|
||||
user: dict = Depends(get_current_user),
|
||||
|
|
@ -25,6 +35,7 @@ async def get_portfolio(
|
|||
db: SupabaseClient = Depends(get_db_client)
|
||||
):
|
||||
try:
|
||||
portfolio_id = _resolve_portfolio_id(portfolio_id, db)
|
||||
portfolio = db.get_portfolio(portfolio_id)
|
||||
return portfolio.to_dict()
|
||||
except PortfolioNotFoundError:
|
||||
|
|
@ -42,6 +53,7 @@ async def get_portfolio_summary(
|
|||
date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
portfolio_id = _resolve_portfolio_id(portfolio_id, db)
|
||||
# 1. Sharpe & Drawdown from latest snapshot
|
||||
snapshot = db.get_latest_snapshot(portfolio_id)
|
||||
sharpe = 0.0
|
||||
|
|
@ -94,6 +106,7 @@ async def get_latest_portfolio_state(
|
|||
db: SupabaseClient = Depends(get_db_client)
|
||||
):
|
||||
try:
|
||||
portfolio_id = _resolve_portfolio_id(portfolio_id, db)
|
||||
portfolio = db.get_portfolio(portfolio_id)
|
||||
snapshot = db.get_latest_snapshot(portfolio_id)
|
||||
holdings = db.list_holdings(portfolio_id)
|
||||
|
|
|
|||
|
|
@ -190,6 +190,9 @@ class LangGraphEngine:
|
|||
def _extract_content(obj: object) -> str:
|
||||
"""Safely extract text content from a LangChain message or plain object."""
|
||||
content = getattr(obj, "content", None)
|
||||
# Handle cases where .content might be a method instead of a property
|
||||
if content is not None and callable(content):
|
||||
content = None
|
||||
return str(content) if content is not None else str(obj)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -421,10 +424,19 @@ class LangGraphEngine:
|
|||
# If .content was empty or the repr of the whole object, try harder
|
||||
if not raw or raw.startswith("<") or raw == str(output):
|
||||
# Some providers wrap in .text or .message
|
||||
potential_text = getattr(output, "text", "")
|
||||
if callable(potential_text):
|
||||
potential_text = ""
|
||||
|
||||
raw = (
|
||||
getattr(output, "text", "")
|
||||
potential_text
|
||||
or (output.get("content", "") if isinstance(output, dict) else "")
|
||||
)
|
||||
|
||||
# Ensure raw is a string before subscripting
|
||||
if not isinstance(raw, str):
|
||||
raw = str(raw)
|
||||
|
||||
if raw:
|
||||
full_response = raw[:_MAX_FULL_LEN]
|
||||
response_snippet = self._truncate(raw)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"name": "TradingAgents",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {}
|
||||
}
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Add project root to sys.path
|
||||
sys.path.append("/Users/Ahmet/Repo/TradingAgents")
|
||||
|
||||
from agent_os.backend.services.langgraph_engine import LangGraphEngine
|
||||
|
||||
class TestLangGraphEngineExtraction(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.engine = LangGraphEngine()
|
||||
|
||||
def test_extract_content_string(self):
|
||||
mock_obj = MagicMock()
|
||||
mock_obj.content = "hello world"
|
||||
self.assertEqual(self.engine._extract_content(mock_obj), "hello world")
|
||||
|
||||
def test_extract_content_method(self):
|
||||
mock_obj = MagicMock()
|
||||
# Mocking a method
|
||||
def my_content():
|
||||
return "should not be called"
|
||||
mock_obj.content = my_content
|
||||
# Should fall back to str(mock_obj)
|
||||
result = self.engine._extract_content(mock_obj)
|
||||
self.assertTrue(result.startswith("<MagicMock"))
|
||||
|
||||
def test_map_langgraph_event_llm_end_with_text_method(self):
|
||||
# Mocking output object with a text method
|
||||
mock_output = MagicMock()
|
||||
def my_text():
|
||||
return "bad"
|
||||
mock_output.text = my_text
|
||||
mock_output.content = None # Ensure it triggers fallback
|
||||
|
||||
event = {
|
||||
"event": "on_chat_model_end",
|
||||
"run_id": "test_run",
|
||||
"name": "test_node",
|
||||
"data": {"output": mock_output},
|
||||
"metadata": {"langgraph_node": "test_node"}
|
||||
}
|
||||
|
||||
# This used to raise TypeError
|
||||
result = self.engine._map_langgraph_event("run_123", event)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertIsInstance(result["response"], str)
|
||||
# It's okay if it's empty, as long as it didn't crash
|
||||
|
||||
def test_map_langgraph_event_llm_end_with_text_string(self):
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "good text"
|
||||
mock_output.content = None
|
||||
|
||||
event = {
|
||||
"event": "on_chat_model_end",
|
||||
"run_id": "test_run",
|
||||
"name": "test_node",
|
||||
"data": {"output": mock_output},
|
||||
"metadata": {"langgraph_node": "test_node"}
|
||||
}
|
||||
|
||||
result = self.engine._map_langgraph_event("run_123", event)
|
||||
self.assertEqual(result["response"], "good text")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue