diff --git a/cli/main.py b/cli/main.py index 1647ea01..0427cae2 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1722,20 +1722,22 @@ def run_portfolio(portfolio_id: str, date: str, macro_path: Path): repo = PortfolioRepository() - # Check if portfolio exists - portfolio = repo.get_portfolio(portfolio_id) - if not portfolio: + # Check if portfolio exists and fetch holdings + try: + portfolio, holdings = repo.get_portfolio_with_holdings(portfolio_id) + except Exception as e: console.print( - f"[yellow]Portfolio '{portfolio_id}' not found. Please ensure it is created in the database.[/yellow]" + f"[yellow]Failed to load portfolio '{portfolio_id}': {e}[/yellow]\n" + "Please ensure it is created in the database using 'python -m cli.main init-portfolio'." ) raise typer.Exit(1) - holdings = repo.get_holdings(portfolio_id) - - candidates = scan_summary.get("stocks_to_investigate", []) + # scan_summary["stocks_to_investigate"] is a list of dicts, we just want the tickers + candidate_dicts = scan_summary.get("stocks_to_investigate", []) + candidate_tickers = [c.get("ticker") for c in candidate_dicts if isinstance(c, dict) and "ticker" in c] holding_tickers = [h.ticker for h in holdings] - all_tickers = set(candidates + holding_tickers) + all_tickers = set(candidate_tickers + holding_tickers) console.print(f"[cyan]Fetching prices for {len(all_tickers)} tickers...[/cyan]") prices = {} @@ -1795,6 +1797,26 @@ def portfolio(): run_portfolio(portfolio_id, date, macro_path) +@app.command() +def init_portfolio( + name: str = typer.Option("My Portfolio", "--name", "-n", help="Name of the new portfolio"), + cash: float = typer.Option(100000.0, "--cash", "-c", help="Starting cash balance"), +): + """Create a completely new portfolio in the database and return its UUID.""" + from tradingagents.portfolio import PortfolioRepository + + console.print(f"[cyan]Initializing new portfolio '{name}' with ${cash:,.2f} cash...[/cyan]") + repo = PortfolioRepository() + try: + portfolio = repo.create_portfolio(name, initial_cash=cash) + console.print("[green]Portfolio created successfully![/green]") + console.print(f"\n[bold white]Your new Portfolio UUID is:[/bold white] [bold magenta]{portfolio.portfolio_id}[/bold magenta]") + console.print("\n[dim]Copy this UUID and paste it when the Portfolio Manager asks for 'Portfolio ID'.[/dim]\n") + except Exception as e: + console.print(f"[red]Failed to create portfolio: {e}[/red]") + raise typer.Exit(1) + + @app.command(name="check-portfolio") def check_portfolio( portfolio_id: str = typer.Option( diff --git a/pyproject.toml b/pyproject.toml index 86cd0ea6..6e88f4aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,9 @@ dependencies = [ "typing-extensions>=4.14.0", "yfinance>=0.2.63", "psycopg2-binary>=2.9.11", + "fastapi>=0.115.9", + "uvicorn>=0.34.3", + "websockets>=15.0.1", ] [project.scripts] diff --git a/test_agent_os_connection.py b/test_agent_os_connection.py new file mode 100644 index 00000000..c1c5d610 --- /dev/null +++ b/test_agent_os_connection.py @@ -0,0 +1,40 @@ +import asyncio +import websockets +import json + +async def test_ws(): + uri = "ws://localhost:8001/ws/stream/test_run" + try: + async with websockets.connect(uri) as websocket: + print("Connected to WebSocket") + while True: + try: + message = await asyncio.wait_for(websocket.recv(), timeout=5.0) + data = json.loads(message) + print(f"Received: {data['type']} from {data.get('agent', 'system')}") + if data['type'] == 'system' and 'completed' in data['message']: + break + except asyncio.TimeoutError: + print("Timeout waiting for message") + break + except Exception as e: + print(f"Connection failed: {e}") + +if __name__ == "__main__": + # We need to trigger a run first to make the ID valid in the store + import requests + try: + resp = requests.post("http://localhost:8001/api/run/scan", json={}) + run_id = resp.json()["run_id"] + print(f"Triggered run: {run_id}") + + # Now connect to the stream + uri = f"ws://localhost:8001/ws/stream/{run_id}" + async def run_test(): + async with websockets.connect(uri) as ws: + print("Stream connected") + async for msg in ws: + print(f"Msg: {msg[:100]}...") + asyncio.run(run_test()) + except Exception as e: + print(f"Error: {e}") diff --git a/uv.lock b/uv.lock index 64842bfe..af02086e 100644 --- a/uv.lock +++ b/uv.lock @@ -3628,6 +3628,7 @@ source = { editable = "." } dependencies = [ { name = "backtrader" }, { name = "chainlit" }, + { name = "fastapi" }, { name = "langchain-anthropic" }, { name = "langchain-core" }, { name = "langchain-experimental" }, @@ -3649,6 +3650,8 @@ dependencies = [ { name = "tqdm" }, { name = "typer" }, { name = "typing-extensions" }, + { name = "uvicorn" }, + { name = "websockets" }, { name = "yfinance" }, ] @@ -3664,6 +3667,7 @@ dev = [ requires-dist = [ { name = "backtrader", specifier = ">=1.9.78.123" }, { name = "chainlit", specifier = ">=2.5.5" }, + { name = "fastapi", specifier = ">=0.115.9" }, { name = "langchain-anthropic", specifier = ">=0.3.15" }, { name = "langchain-core", specifier = ">=0.3.81" }, { name = "langchain-experimental", specifier = ">=0.3.4" }, @@ -3685,6 +3689,8 @@ requires-dist = [ { name = "tqdm", specifier = ">=4.67.1" }, { name = "typer", specifier = ">=0.21.0" }, { name = "typing-extensions", specifier = ">=4.14.0" }, + { name = "uvicorn", specifier = ">=0.34.3" }, + { name = "websockets", specifier = ">=15.0.1" }, { name = "yfinance", specifier = ">=0.2.63" }, ]