This commit is contained in:
kimheesu 2025-07-07 14:22:27 +09:00
parent fbd96e9c18
commit ab1b0120c2
73 changed files with 35353 additions and 44998 deletions

21
.gitignore vendored
View File

@ -1,10 +1,11 @@
env/ env/
__pycache__/ __pycache__/
.DS_Store .DS_Store
*.csv *.csv
src/ src/
eval_results/ eval_results/
eval_data/ eval_data/
*.egg-info/ *.egg-info/
results/ results/
.env .env
tradingagents/dataflows/data_cache/

View File

@ -1 +1 @@
3.10 3.10

28
.vscode/launch.json vendored
View File

@ -1,15 +1,15 @@
{ {
// Use IntelliSense to learn about possible attributes. // Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes. // Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0", "version": "0.2.0",
"configurations": [ "configurations": [
{ {
"name": "Python Debugger: main.py", "name": "Python Debugger: main.py",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "${workspaceFolder}/main.py", "program": "${workspaceFolder}/main.py",
"console": "integratedTerminal" "console": "integratedTerminal"
} }
] ]
} }

430
README.md
View File

@ -1,215 +1,215 @@
<p align="center"> <p align="center">
<img src="assets/TauricResearch.png" style="width: 60%; height: auto;"> <img src="assets/TauricResearch.png" style="width: 60%; height: auto;">
</p> </p>
<div align="center" style="line-height: 1;"> <div align="center" style="line-height: 1;">
<a href="https://arxiv.org/abs/2412.20138" target="_blank"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2412.20138-B31B1B?logo=arxiv"/></a> <a href="https://arxiv.org/abs/2412.20138" target="_blank"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2412.20138-B31B1B?logo=arxiv"/></a>
<a href="https://discord.com/invite/hk9PGKShPK" target="_blank"><img alt="Discord" src="https://img.shields.io/badge/Discord-TradingResearch-7289da?logo=discord&logoColor=white&color=7289da"/></a> <a href="https://discord.com/invite/hk9PGKShPK" target="_blank"><img alt="Discord" src="https://img.shields.io/badge/Discord-TradingResearch-7289da?logo=discord&logoColor=white&color=7289da"/></a>
<a href="./assets/wechat.png" target="_blank"><img alt="WeChat" src="https://img.shields.io/badge/WeChat-TauricResearch-brightgreen?logo=wechat&logoColor=white"/></a> <a href="./assets/wechat.png" target="_blank"><img alt="WeChat" src="https://img.shields.io/badge/WeChat-TauricResearch-brightgreen?logo=wechat&logoColor=white"/></a>
<a href="https://x.com/TauricResearch" target="_blank"><img alt="X Follow" src="https://img.shields.io/badge/X-TauricResearch-white?logo=x&logoColor=white"/></a> <a href="https://x.com/TauricResearch" target="_blank"><img alt="X Follow" src="https://img.shields.io/badge/X-TauricResearch-white?logo=x&logoColor=white"/></a>
<br> <br>
<a href="https://github.com/TauricResearch/" target="_blank"><img alt="Community" src="https://img.shields.io/badge/Join_GitHub_Community-TauricResearch-14C290?logo=discourse"/></a> <a href="https://github.com/TauricResearch/" target="_blank"><img alt="Community" src="https://img.shields.io/badge/Join_GitHub_Community-TauricResearch-14C290?logo=discourse"/></a>
</div> </div>
<div align="center"> <div align="center">
<!-- Keep these links. Translations will automatically update with the README. --> <!-- Keep these links. Translations will automatically update with the README. -->
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=de">Deutsch</a> | <a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=de">Deutsch</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=es">Español</a> | <a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=es">Español</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=fr">français</a> | <a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=fr">français</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ja">日本語</a> | <a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ja">日本語</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ko">한국어</a> | <a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ko">한국어</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=pt">Português</a> | <a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=pt">Português</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ru">Русский</a> | <a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ru">Русский</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=zh">中文</a> <a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=zh">中文</a>
</div> </div>
--- ---
# TradingAgents: Multi-Agents LLM Financial Trading Framework # TradingAgents: Multi-Agents LLM Financial Trading Framework
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community. > 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
> >
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you! > So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
<div align="center"> <div align="center">
<a href="https://www.star-history.com/#TauricResearch/TradingAgents&Date"> <a href="https://www.star-history.com/#TauricResearch/TradingAgents&Date">
<picture> <picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date&theme=dark" /> <source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" /> <source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" />
<img alt="TradingAgents Star History" src="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" style="width: 80%; height: auto;" /> <img alt="TradingAgents Star History" src="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" style="width: 80%; height: auto;" />
</picture> </picture>
</a> </a>
</div> </div>
<div align="center"> <div align="center">
🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation) 🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
</div> </div>
## TradingAgents Framework ## TradingAgents Framework
TradingAgents is a multi-agent trading framework that mirrors the dynamics of real-world trading firms. By deploying specialized LLM-powered agents: from fundamental analysts, sentiment experts, and technical analysts, to trader, risk management team, the platform collaboratively evaluates market conditions and informs trading decisions. Moreover, these agents engage in dynamic discussions to pinpoint the optimal strategy. TradingAgents is a multi-agent trading framework that mirrors the dynamics of real-world trading firms. By deploying specialized LLM-powered agents: from fundamental analysts, sentiment experts, and technical analysts, to trader, risk management team, the platform collaboratively evaluates market conditions and informs trading decisions. Moreover, these agents engage in dynamic discussions to pinpoint the optimal strategy.
<p align="center"> <p align="center">
<img src="assets/schema.png" style="width: 100%; height: auto;"> <img src="assets/schema.png" style="width: 100%; height: auto;">
</p> </p>
> TradingAgents framework is designed for research purposes. Trading performance may vary based on many factors, including the chosen backbone language models, model temperature, trading periods, the quality of data, and other non-deterministic factors. [It is not intended as financial, investment, or trading advice.](https://tauric.ai/disclaimer/) > TradingAgents framework is designed for research purposes. Trading performance may vary based on many factors, including the chosen backbone language models, model temperature, trading periods, the quality of data, and other non-deterministic factors. [It is not intended as financial, investment, or trading advice.](https://tauric.ai/disclaimer/)
Our framework decomposes complex trading tasks into specialized roles. This ensures the system achieves a robust, scalable approach to market analysis and decision-making. Our framework decomposes complex trading tasks into specialized roles. This ensures the system achieves a robust, scalable approach to market analysis and decision-making.
### Analyst Team ### Analyst Team
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags. - Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
- Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood. - Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood.
- News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions. - News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions.
- Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements. - Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements.
<p align="center"> <p align="center">
<img src="assets/analyst.png" width="100%" style="display: inline-block; margin: 0 2%;"> <img src="assets/analyst.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p> </p>
### Researcher Team ### Researcher Team
- Comprises both bullish and bearish researchers who critically assess the insights provided by the Analyst Team. Through structured debates, they balance potential gains against inherent risks. - Comprises both bullish and bearish researchers who critically assess the insights provided by the Analyst Team. Through structured debates, they balance potential gains against inherent risks.
<p align="center"> <p align="center">
<img src="assets/researcher.png" width="70%" style="display: inline-block; margin: 0 2%;"> <img src="assets/researcher.png" width="70%" style="display: inline-block; margin: 0 2%;">
</p> </p>
### Trader Agent ### Trader Agent
- Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights. - Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights.
<p align="center"> <p align="center">
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;"> <img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
</p> </p>
### Risk Management and Portfolio Manager ### Risk Management and Portfolio Manager
- Continuously evaluates portfolio risk by assessing market volatility, liquidity, and other risk factors. The risk management team evaluates and adjusts trading strategies, providing assessment reports to the Portfolio Manager for final decision. - Continuously evaluates portfolio risk by assessing market volatility, liquidity, and other risk factors. The risk management team evaluates and adjusts trading strategies, providing assessment reports to the Portfolio Manager for final decision.
- The Portfolio Manager approves/rejects the transaction proposal. If approved, the order will be sent to the simulated exchange and executed. - The Portfolio Manager approves/rejects the transaction proposal. If approved, the order will be sent to the simulated exchange and executed.
<p align="center"> <p align="center">
<img src="assets/risk.png" width="70%" style="display: inline-block; margin: 0 2%;"> <img src="assets/risk.png" width="70%" style="display: inline-block; margin: 0 2%;">
</p> </p>
## Installation and CLI ## Installation and CLI
### Installation ### Installation
Clone TradingAgents: Clone TradingAgents:
```bash ```bash
git clone https://github.com/TauricResearch/TradingAgents.git git clone https://github.com/TauricResearch/TradingAgents.git
cd TradingAgents cd TradingAgents
``` ```
Create a virtual environment in any of your favorite environment managers: Create a virtual environment in any of your favorite environment managers:
```bash ```bash
conda create -n tradingagents python=3.13 conda create -n tradingagents python=3.13
conda activate tradingagents conda activate tradingagents
``` ```
Install dependencies: Install dependencies:
```bash ```bash
pip install -r requirements.txt pip install -r requirements.txt
``` ```
### Required APIs ### Required APIs
You will also need the FinnHub API for financial data. All of our code is implemented with the free tier. You will also need the FinnHub API for financial data. All of our code is implemented with the free tier.
```bash ```bash
export FINNHUB_API_KEY=$YOUR_FINNHUB_API_KEY export FINNHUB_API_KEY=$YOUR_FINNHUB_API_KEY
``` ```
You will need the OpenAI API or GEMINI API for all the agents. You will need the OpenAI API or GEMINI API for all the agents.
```bash ```bash
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
export GEMINI_API_KEY=$YOUR_GEMINI_API_KEY export GEMINI_API_KEY=$YOUR_GEMINI_API_KEY
export GOOGLE_API_KEY=$YOUR_GEMINI_API_KEY export GOOGLE_API_KEY=$YOUR_GEMINI_API_KEY
``` ```
### CLI Usage ### CLI Usage
You can also try out the CLI directly by running: You can also try out the CLI directly by running:
```bash ```bash
python -m cli.main python -m cli.main
``` ```
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc. You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
<p align="center"> <p align="center">
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;"> <img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p> </p>
An interface will appear showing results as they load, letting you track the agent's progress as it runs. An interface will appear showing results as they load, letting you track the agent's progress as it runs.
<p align="center"> <p align="center">
<img src="assets/cli/cli_news.png" width="100%" style="display: inline-block; margin: 0 2%;"> <img src="assets/cli/cli_news.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p> </p>
<p align="center"> <p align="center">
<img src="assets/cli/cli_transaction.png" width="100%" style="display: inline-block; margin: 0 2%;"> <img src="assets/cli/cli_transaction.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p> </p>
## TradingAgents Package ## TradingAgents Package
### Implementation Details ### Implementation Details
We built TradingAgents with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls. We built TradingAgents with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
### Python Usage ### Python Usage
To use TradingAgents inside your code, you can import the `tradingagents` module and initialize a `TradingAgentsGraph()` object. The `.propagate()` function will return a decision. You can run `main.py`, here's also a quick example: To use TradingAgents inside your code, you can import the `tradingagents` module and initialize a `TradingAgentsGraph()` object. The `.propagate()` function will return a decision. You can run `main.py`, here's also a quick example:
```python ```python
from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.default_config import DEFAULT_CONFIG
ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy()) ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy())
# forward propagate # forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10") _, decision = ta.propagate("NVDA", "2024-05-10")
print(decision) print(decision)
``` ```
You can also adjust the default configuration to set your own choice of LLMs, debate rounds, etc. You can also adjust the default configuration to set your own choice of LLMs, debate rounds, etc.
```python ```python
from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config # Create a custom config
config = DEFAULT_CONFIG.copy() config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds config["max_debate_rounds"] = 1 # Increase debate rounds
config["online_tools"] = True # Use online tools or cached data config["online_tools"] = True # Use online tools or cached data
# Initialize with custom config # Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config) ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate # forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10") _, decision = ta.propagate("NVDA", "2024-05-10")
print(decision) print(decision)
``` ```
> For `online_tools`, we recommend enabling them for experimentation, as they provide access to real-time data. The agents' offline tools rely on cached data from our **Tauric TradingDB**, a curated dataset we use for backtesting. We're currently in the process of refining this dataset, and we plan to release it soon alongside our upcoming projects. Stay tuned! > For `online_tools`, we recommend enabling them for experimentation, as they provide access to real-time data. The agents' offline tools rely on cached data from our **Tauric TradingDB**, a curated dataset we use for backtesting. We're currently in the process of refining this dataset, and we plan to release it soon alongside our upcoming projects. Stay tuned!
You can view the full list of configurations in `tradingagents/default_config.py`. You can view the full list of configurations in `tradingagents/default_config.py`.
## Contributing ## Contributing
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/). We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
## Citation ## Citation
Please reference our work if you find *TradingAgents* provides you with some help :) Please reference our work if you find *TradingAgents* provides you with some help :)
``` ```
@misc{xiao2025tradingagentsmultiagentsllmfinancial, @misc{xiao2025tradingagentsmultiagentsllmfinancial,
title={TradingAgents: Multi-Agents LLM Financial Trading Framework}, title={TradingAgents: Multi-Agents LLM Financial Trading Framework},
author={Yijia Xiao and Edward Sun and Di Luo and Wei Wang}, author={Yijia Xiao and Edward Sun and Di Luo and Wei Wang},
year={2025}, year={2025},
eprint={2412.20138}, eprint={2412.20138},
archivePrefix={arXiv}, archivePrefix={arXiv},
primaryClass={q-fin.TR}, primaryClass={q-fin.TR},
url={https://arxiv.org/abs/2412.20138}, url={https://arxiv.org/abs/2412.20138},
} }
``` ```

View File

@ -1,67 +1,67 @@
from typing import Generator, Optional from typing import Generator, Optional
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from jose import jwt, JWTError from jose import jwt, JWTError
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import Session from sqlmodel import Session
from app.core.config import settings from app.core.config import settings
from app.infrastructure.database import get_db from app.infrastructure.database import get_db
from app.domain.models import User from app.domain.models import User
from app.infrastructure.repositories.user import UserRepository from app.infrastructure.repositories.user import UserRepository
from app.core.services.trading_analysis import TradingAnalysisService from app.core.services.trading_analysis import TradingAnalysisService
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token") reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
class TokenData(BaseModel): class TokenData(BaseModel):
username: Optional[str] = None username: Optional[str] = None
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository: def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
return UserRepository(db) return UserRepository(db)
def get_user_from_token(token: str, db: Session) -> Optional[User]: def get_user_from_token(token: str, db: Session) -> Optional[User]:
try: try:
payload = jwt.decode( payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
) )
token_data = TokenData(username=payload.get("sub")) token_data = TokenData(username=payload.get("sub"))
except JWTError: except JWTError:
return None return None
user_repo = UserRepository(db) user_repo = UserRepository(db)
user = user_repo.get_by_email(email=token_data.username) user = user_repo.get_by_email(email=token_data.username)
return user return user
def get_current_user( def get_current_user(
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2) db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
) -> User: ) -> User:
user = get_user_from_token(token=token, db=db) user = get_user_from_token(token=token, db=db)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials", detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
return user return user
def get_current_active_user( def get_current_active_user(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> User: ) -> User:
if not current_user.is_active: if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user") raise HTTPException(status_code=400, detail="Inactive user")
return current_user return current_user
def get_current_active_superuser( def get_current_active_superuser(
current_user: User = Depends(get_current_active_user), current_user: User = Depends(get_current_active_user),
) -> User: ) -> User:
if not current_user.is_superuser: if not current_user.is_superuser:
raise HTTPException( raise HTTPException(
status_code=403, detail="The user doesn't have enough privileges" status_code=403, detail="The user doesn't have enough privileges"
) )
return current_user return current_user
def get_analysis_service( def get_analysis_service(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_active_user) user: User = Depends(get_current_active_user)
) -> TradingAnalysisService: ) -> TradingAnalysisService:
return TradingAnalysisService(user=user, db=db) return TradingAnalysisService(user=user, db=db)

View File

@ -1,94 +1,94 @@
from typing import Any, List from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
from app.api import deps from app.api import deps
from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate
from app.domain.models import User as UserModel from app.domain.models import User as UserModel
from app.core.services.trading_analysis import TradingAnalysisService from app.core.services.trading_analysis import TradingAnalysisService
from app.core.websocket_manager import WebSocketManager from app.core.websocket_manager import WebSocketManager
from sqlmodel import Session from sqlmodel import Session
from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS
router = APIRouter() router = APIRouter()
manager = WebSocketManager() manager = WebSocketManager()
@router.post("/start", response_model=AnalysisSession) @router.post("/start", response_model=AnalysisSession)
def start_analysis( def start_analysis(
*, *,
analysis_in: AnalysisSessionCreate, analysis_in: AnalysisSessionCreate,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
service: TradingAnalysisService = Depends(deps.get_analysis_service), service: TradingAnalysisService = Depends(deps.get_analysis_service),
) -> Any: ) -> Any:
""" """
Start a new analysis session. Start a new analysis session.
""" """
session = service.create_session(analysis_in=analysis_in) session = service.create_session(analysis_in=analysis_in)
background_tasks.add_task(service.run_analysis, session_id=session.id) background_tasks.add_task(service.run_analysis, session_id=session.id)
return session return session
@router.get("/history", response_model=List[AnalysisSession]) @router.get("/history", response_model=List[AnalysisSession])
def get_analysis_history( def get_analysis_history(
service: TradingAnalysisService = Depends(deps.get_analysis_service), service: TradingAnalysisService = Depends(deps.get_analysis_service),
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
) -> Any: ) -> Any:
""" """
Get analysis history for the current user. Get analysis history for the current user.
""" """
return service.get_user_sessions(skip=skip, limit=limit) return service.get_user_sessions(skip=skip, limit=limit)
@router.get("/options") @router.get("/options")
def get_analysis_options(): def get_analysis_options():
""" """
Get available options for analysis. Get available options for analysis.
""" """
return { return {
'analysts': [ 'analysts': [
{'value': 'market', 'label': 'Market Analyst'}, {'value': 'market', 'label': 'Market Analyst'},
{'value': 'social', 'label': 'Social Analyst'}, {'value': 'social', 'label': 'Social Analyst'},
{'value': 'news', 'label': 'News Analyst'}, {'value': 'news', 'label': 'News Analyst'},
{'value': 'fundamentals', 'label': 'Fundamentals Analyst'}, {'value': 'fundamentals', 'label': 'Fundamentals Analyst'},
], ],
'research_depths': [ 'research_depths': [
{'value': 1, 'label': 'Shallow'}, {'value': 1, 'label': 'Shallow'},
{'value': 3, 'label': 'Medium'}, {'value': 3, 'label': 'Medium'},
{'value': 5, 'label': 'Deep'}, {'value': 5, 'label': 'Deep'},
], ],
'llm_providers': [{'name': p[0], 'url': p[1]} for p in BASE_URLS], 'llm_providers': [{'name': p[0], 'url': p[1]} for p in BASE_URLS],
'shallow_thinkers': SHALLOW_AGENT_OPTIONS, 'shallow_thinkers': SHALLOW_AGENT_OPTIONS,
'deep_thinkers': DEEP_AGENT_OPTIONS, 'deep_thinkers': DEEP_AGENT_OPTIONS,
} }
@router.get("/{session_id}", response_model=AnalysisSession) @router.get("/{session_id}", response_model=AnalysisSession)
def get_analysis_session( def get_analysis_session(
session_id: int, session_id: int,
service: TradingAnalysisService = Depends(deps.get_analysis_service), service: TradingAnalysisService = Depends(deps.get_analysis_service),
) -> Any: ) -> Any:
""" """
Get a specific analysis session by ID. Get a specific analysis session by ID.
""" """
session = service.get_session(session_id=session_id) session = service.get_session(session_id=session_id)
if not session: if not session:
raise HTTPException(status_code=404, detail="Analysis session not found") raise HTTPException(status_code=404, detail="Analysis session not found")
return session return session
@router.websocket("/ws") @router.websocket("/ws")
async def websocket_endpoint( async def websocket_endpoint(
websocket: WebSocket, websocket: WebSocket,
token: str, token: str,
db: Session = Depends(deps.get_db) db: Session = Depends(deps.get_db)
): ):
""" """
WebSocket endpoint for real-time analysis updates. WebSocket endpoint for real-time analysis updates.
""" """
user = deps.get_user_from_token(token=token, db=db) user = deps.get_user_from_token(token=token, db=db)
if not user or not user.is_active: if not user or not user.is_active:
await websocket.close(code=1008) await websocket.close(code=1008)
return return
await manager.connect(user.id, websocket) await manager.connect(user.id, websocket)
try: try:
while True: while True:
# Keep the connection alive # Keep the connection alive
await websocket.receive_text() await websocket.receive_text()
except WebSocketDisconnect: except WebSocketDisconnect:
manager.disconnect(user.id, websocket) manager.disconnect(user.id, websocket)

View File

@ -1,35 +1,35 @@
from datetime import timedelta from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session from sqlmodel import Session
from app.api import deps from app.api import deps
from app.core.config import settings from app.core.config import settings
from app.core.schemas.token import Token from app.core.schemas.token import Token
from app.core import security from app.core import security
from app.infrastructure.repositories.user import UserRepository from app.infrastructure.repositories.user import UserRepository
router = APIRouter() router = APIRouter()
@router.post("/login/access-token", response_model=Token) @router.post("/login/access-token", response_model=Token)
def login_access_token( def login_access_token(
db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends() db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()
): ):
""" """
OAuth2 compatible token login, get an access token for future requests OAuth2 compatible token login, get an access token for future requests
""" """
user_repo = UserRepository(db) user_repo = UserRepository(db)
user = user_repo.get_by_email(email=form_data.username) user = user_repo.get_by_email(email=form_data.username)
if not user or not security.verify_password(form_data.password, user.hashed_password): if not user or not security.verify_password(form_data.password, user.hashed_password):
raise HTTPException(status_code=400, detail="Incorrect email or password") raise HTTPException(status_code=400, detail="Incorrect email or password")
elif not user.is_active: elif not user.is_active:
raise HTTPException(status_code=400, detail="Inactive user") raise HTTPException(status_code=400, detail="Inactive user")
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
return { return {
"access_token": security.create_access_token( "access_token": security.create_access_token(
user.email, expires_delta=access_token_expires user.email, expires_delta=access_token_expires
), ),
"token_type": "bearer", "token_type": "bearer",
} }

View File

@ -1,89 +1,89 @@
from typing import Any, List from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from app.api import deps from app.api import deps
from app.core.schemas.user import User, UserCreate, UserUpdate from app.core.schemas.user import User, UserCreate, UserUpdate
from app.domain.models import User as UserModel from app.domain.models import User as UserModel
from app.domain.repositories import IUserRepository from app.domain.repositories import IUserRepository
router = APIRouter() router = APIRouter()
@router.get("/", response_model=List[User]) @router.get("/", response_model=List[User])
def read_users( def read_users(
repo: IUserRepository = Depends(deps.get_user_repository), repo: IUserRepository = Depends(deps.get_user_repository),
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
current_user: UserModel = Depends(deps.get_current_active_superuser), current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any: ) -> Any:
""" """
Retrieve users. Retrieve users.
""" """
users = repo.get_multi(skip=skip, limit=limit) users = repo.get_multi(skip=skip, limit=limit)
return users return users
@router.post("/", response_model=User) @router.post("/", response_model=User)
def create_user( def create_user(
*, *,
repo: IUserRepository = Depends(deps.get_user_repository), repo: IUserRepository = Depends(deps.get_user_repository),
user_in: UserCreate, user_in: UserCreate,
current_user: UserModel = Depends(deps.get_current_active_superuser), current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any: ) -> Any:
""" """
Create new user. Create new user.
""" """
user = repo.get_by_email(email=user_in.email) user = repo.get_by_email(email=user_in.email)
if user: if user:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="The user with this username already exists in the system.", detail="The user with this username already exists in the system.",
) )
user = repo.create(obj_in=user_in) user = repo.create(obj_in=user_in)
return user return user
@router.get("/me", response_model=User) @router.get("/me", response_model=User)
def read_user_me( def read_user_me(
current_user: UserModel = Depends(deps.get_current_active_user), current_user: UserModel = Depends(deps.get_current_active_user),
) -> Any: ) -> Any:
""" """
Get current user. Get current user.
""" """
return current_user return current_user
@router.get("/{user_id}", response_model=User) @router.get("/{user_id}", response_model=User)
def read_user_by_id( def read_user_by_id(
user_id: int, user_id: int,
repo: IUserRepository = Depends(deps.get_user_repository), repo: IUserRepository = Depends(deps.get_user_repository),
current_user: UserModel = Depends(deps.get_current_active_user), current_user: UserModel = Depends(deps.get_current_active_user),
) -> Any: ) -> Any:
""" """
Get a specific user by id. Get a specific user by id.
""" """
user = repo.get(id=user_id) user = repo.get(id=user_id)
if not user: if not user:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
if user == current_user: if user == current_user:
return user return user
if not repo.is_superuser(user=current_user): if not repo.is_superuser(user=current_user):
raise HTTPException( raise HTTPException(
status_code=403, detail="The user doesn't have enough privileges" status_code=403, detail="The user doesn't have enough privileges"
) )
return user return user
@router.put("/{user_id}", response_model=User) @router.put("/{user_id}", response_model=User)
def update_user( def update_user(
*, *,
repo: IUserRepository = Depends(deps.get_user_repository), repo: IUserRepository = Depends(deps.get_user_repository),
user_id: int, user_id: int,
user_in: UserUpdate, user_in: UserUpdate,
current_user: UserModel = Depends(deps.get_current_active_superuser), current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any: ) -> Any:
""" """
Update a user. Update a user.
""" """
user = repo.get(id=user_id) user = repo.get(id=user_id)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="The user with this username does not exist in the system", detail="The user with this username does not exist in the system",
) )
user = repo.update(db_obj=user, obj_in=user_in) user = repo.update(db_obj=user, obj_in=user_in)
return user return user

View File

@ -1,7 +1,7 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api.endpoints import login, users, analysis from app.api.endpoints import login, users, analysis
api_router = APIRouter() api_router = APIRouter()
api_router.include_router(login.router, tags=["login"]) api_router.include_router(login.router, tags=["login"])
api_router.include_router(users.router, prefix="/users", tags=["users"]) api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"]) api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"])

View File

@ -1,26 +1,26 @@
import os import os
from pydantic import BaseSettings from pydantic import BaseSettings
from typing import List, Optional from typing import List, Optional
class Settings(BaseSettings): class Settings(BaseSettings):
PROJECT_NAME: str = "TradingAgents Backend" PROJECT_NAME: str = "TradingAgents Backend"
API_V1_STR: str = "/api/v1" API_V1_STR: str = "/api/v1"
# Security # Security
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key") SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
ALGORITHM: str = "HS256" ALGORITHM: str = "HS256"
# Database # Database
DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db") DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db")
# OpenAI # OpenAI
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
# CORS # CORS
CORS_ALLOWED_ORIGINS: List[str] = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(',') CORS_ALLOWED_ORIGINS: List[str] = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(',')
class Config: class Config:
case_sensitive = True case_sensitive = True
settings = Settings() settings = Settings()

View File

@ -1,4 +1,4 @@
from .user import User, UserCreate, UserUpdate from .user import User, UserCreate, UserUpdate
from .token import Token, TokenPayload from .token import Token, TokenPayload
from .profile import Profile, ProfileCreate, ProfileUpdate from .profile import Profile, ProfileCreate, ProfileUpdate
from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate

View File

@ -1,38 +1,38 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional
from datetime import date, datetime from datetime import date, datetime
from app.domain.models import AnalysisStatus from app.domain.models import AnalysisStatus
class AnalysisSessionBase(BaseModel): class AnalysisSessionBase(BaseModel):
ticker: str ticker: str
analysts_selected: List[str] analysts_selected: List[str]
research_depth: int research_depth: int
llm_provider: str llm_provider: str
backend_url: str backend_url: str
shallow_thinker: str shallow_thinker: str
deep_thinker: str deep_thinker: str
class AnalysisSessionCreate(AnalysisSessionBase): class AnalysisSessionCreate(AnalysisSessionBase):
pass pass
class AnalysisSessionUpdate(BaseModel): class AnalysisSessionUpdate(BaseModel):
status: Optional[AnalysisStatus] = None status: Optional[AnalysisStatus] = None
final_report: Optional[str] = None final_report: Optional[str] = None
error_message: Optional[str] = None error_message: Optional[str] = None
class AnalysisSessionInDBBase(AnalysisSessionBase): class AnalysisSessionInDBBase(AnalysisSessionBase):
id: int id: int
user_id: int user_id: int
analysis_date: date analysis_date: date
status: AnalysisStatus status: AnalysisStatus
final_report: Optional[str] = None final_report: Optional[str] = None
error_message: Optional[str] = None error_message: Optional[str] = None
created_at: datetime created_at: datetime
started_at: Optional[datetime] = None started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None completed_at: Optional[datetime] = None
class Config: class Config:
orm_mode = True orm_mode = True
class AnalysisSession(AnalysisSessionInDBBase): class AnalysisSession(AnalysisSessionInDBBase):
pass pass

View File

@ -1,20 +1,20 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
class ProfileBase(BaseModel): class ProfileBase(BaseModel):
default_ticker: str = "SPY" default_ticker: str = "SPY"
preferred_research_depth: int = 3 preferred_research_depth: int = 3
preferred_shallow_thinker: str = "gpt-4o-mini" preferred_shallow_thinker: str = "gpt-4o-mini"
preferred_deep_thinker: str = "gpt-4o" preferred_deep_thinker: str = "gpt-4o"
class ProfileCreate(ProfileBase): class ProfileCreate(ProfileBase):
pass pass
class ProfileUpdate(ProfileBase): class ProfileUpdate(ProfileBase):
openai_api_key: Optional[str] = None openai_api_key: Optional[str] = None
class Profile(ProfileBase): class Profile(ProfileBase):
has_openai_api_key: bool has_openai_api_key: bool
class Config: class Config:
orm_mode = True orm_mode = True

View File

@ -1,9 +1,9 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
token_type: str token_type: str
class TokenPayload(BaseModel): class TokenPayload(BaseModel):
sub: Optional[int] = None sub: Optional[int] = None

View File

@ -1,28 +1,28 @@
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
from typing import Optional from typing import Optional
class UserBase(BaseModel): class UserBase(BaseModel):
email: EmailStr email: EmailStr
username: str username: str
first_name: Optional[str] = None first_name: Optional[str] = None
last_name: Optional[str] = None last_name: Optional[str] = None
class UserCreate(UserBase): class UserCreate(UserBase):
password: str password: str
class UserUpdate(UserBase): class UserUpdate(UserBase):
pass pass
class UserInDBBase(UserBase): class UserInDBBase(UserBase):
id: int id: int
is_active: bool is_active: bool
is_superuser: bool is_superuser: bool
class Config: class Config:
orm_mode = True orm_mode = True
class User(UserInDBBase): class User(UserInDBBase):
pass pass
class UserInDB(UserInDBBase): class UserInDB(UserInDBBase):
hashed_password: str hashed_password: str

View File

@ -1,23 +1,23 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Union, Optional from typing import Any, Union, Optional
from jose import jwt from jose import jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from app.core.config import settings from app.core.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str: def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
if expires_delta: if expires_delta:
expire = datetime.utcnow() + expires_delta expire = datetime.utcnow() + expires_delta
else: else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode = {"exp": expire, "sub": str(subject)} to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt return encoded_jwt
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password) return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str: def get_password_hash(password: str) -> str:
return pwd_context.hash(password) return pwd_context.hash(password)

View File

@ -1,128 +1,128 @@
import asyncio import asyncio
import datetime import datetime
import json import json
from typing import Dict, List, Optional from typing import Dict, List, Optional
from sqlmodel import Session, select from sqlmodel import Session, select
from app.domain.models import User, AnalysisSession, AnalysisStatus from app.domain.models import User, AnalysisSession, AnalysisStatus
from app.core.schemas.analysis import AnalysisSessionCreate from app.core.schemas.analysis import AnalysisSessionCreate
from app.core.config import settings from app.core.config import settings
from cli.models import AnalystType from cli.models import AnalystType
from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.default_config import DEFAULT_CONFIG
from app.api.deps import get_db from app.api.deps import get_db
from app.core.websocket_manager import WebSocketManager from app.core.websocket_manager import WebSocketManager
class TradingAnalysisService: class TradingAnalysisService:
def __init__(self, user: User, db: Session): def __init__(self, user: User, db: Session):
self.user = user self.user = user
self.db = db self.db = db
self.websocket_manager = WebSocketManager() self.websocket_manager = WebSocketManager()
async def run_analysis(self, session_id: int): async def run_analysis(self, session_id: int):
"""분석 실행""" """분석 실행"""
session = self.get_session(session_id=session_id) session = self.get_session(session_id=session_id)
if not session: if not session:
return return
try: try:
session.status = AnalysisStatus.RUNNING session.status = AnalysisStatus.RUNNING
session.started_at = datetime.datetime.utcnow() session.started_at = datetime.datetime.utcnow()
self.db.add(session) self.db.add(session)
self.db.commit() self.db.commit()
self.db.refresh(session) self.db.refresh(session)
await self.websocket_manager.send_to_user( await self.websocket_manager.send_to_user(
self.user.id, self.user.id,
{ {
'type': 'analysis_started', 'type': 'analysis_started',
'session_id': session.id, 'session_id': session.id,
'message': '분석을 시작합니다...' 'message': '분석을 시작합니다...'
} }
) )
# Prepare config for TradingAgentsGraph # Prepare config for TradingAgentsGraph
config = DEFAULT_CONFIG.copy() config = DEFAULT_CONFIG.copy()
config.update({ config.update({
'openai_api_key': settings.OPENAI_API_KEY, 'openai_api_key': settings.OPENAI_API_KEY,
'llm_provider': session.llm_provider, 'llm_provider': session.llm_provider,
'backend_url': session.backend_url, 'backend_url': session.backend_url,
'shallow_thinking_model': session.shallow_thinker, 'shallow_thinking_model': session.shallow_thinker,
'deep_thinking_model': session.deep_thinker, 'deep_thinking_model': session.deep_thinker,
}) })
# Progress callback for websocket # Progress callback for websocket
async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0): async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0):
progress_percent = int((step / total) * 99) if total > 0 else 0 progress_percent = int((step / total) * 99) if total > 0 else 0
await self.websocket_manager.send_to_user(self.user.id, { await self.websocket_manager.send_to_user(self.user.id, {
'type': 'analysis_progress', 'type': 'analysis_progress',
'session_id': session.id, 'session_id': session.id,
'message_type': message_type, 'message_type': message_type,
'content': content, 'content': content,
'agent': agent, 'agent': agent,
'progress': progress_percent, 'progress': progress_percent,
}) })
trading_graph = TradingAgentsGraph( trading_graph = TradingAgentsGraph(
config=config, config=config,
selected_analysts=session.analysts_selected, selected_analysts=session.analysts_selected,
) )
input_data = { input_data = {
'company_of_interest': session.ticker, 'company_of_interest': session.ticker,
'trade_date': session.analysis_date.strftime('%Y-%m-%d'), 'trade_date': session.analysis_date.strftime('%Y-%m-%d'),
} }
final_state, result = await asyncio.to_thread( final_state, result = await asyncio.to_thread(
trading_graph.propagate, trading_graph.propagate,
input_data['company_of_interest'], input_data['company_of_interest'],
input_data['trade_date'] input_data['trade_date']
) )
session.status = AnalysisStatus.COMPLETED session.status = AnalysisStatus.COMPLETED
session.completed_at = datetime.datetime.utcnow() session.completed_at = datetime.datetime.utcnow()
session.final_report = json.dumps(final_state) # Store full state as JSON session.final_report = json.dumps(final_state) # Store full state as JSON
self.db.add(session) self.db.add(session)
self.db.commit() self.db.commit()
await self.websocket_manager.send_to_user( await self.websocket_manager.send_to_user(
self.user.id, self.user.id,
{ {
'type': 'analysis_completed', 'type': 'analysis_completed',
'session_id': session.id, 'session_id': session.id,
'message': '분석이 완료되었습니다.', 'message': '분석이 완료되었습니다.',
'result': result 'result': result
} }
) )
except Exception as e: except Exception as e:
session.status = AnalysisStatus.FAILED session.status = AnalysisStatus.FAILED
session.error_message = str(e) session.error_message = str(e)
self.db.add(session) self.db.add(session)
self.db.commit() self.db.commit()
await self.websocket_manager.send_to_user( await self.websocket_manager.send_to_user(
self.user.id, self.user.id,
{ {
'type': 'analysis_failed', 'type': 'analysis_failed',
'session_id': session.id, 'session_id': session.id,
'message': f'분석 중 오류가 발생했습니다: {str(e)}' 'message': f'분석 중 오류가 발생했습니다: {str(e)}'
} }
) )
def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession: def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
session = AnalysisSession( session = AnalysisSession(
**analysis_in.dict(), **analysis_in.dict(),
user_id=self.user.id, user_id=self.user.id,
analysis_date=datetime.date.today() analysis_date=datetime.date.today()
) )
self.db.add(session) self.db.add(session)
self.db.commit() self.db.commit()
self.db.refresh(session) self.db.refresh(session)
return session return session
def get_session(self, *, session_id: int) -> Optional[AnalysisSession]: def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id) statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id)
return self.db.exec(statement).first() return self.db.exec(statement).first()
def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]: def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit) statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit)
return self.db.exec(statement).all() return self.db.exec(statement).all()

View File

@ -1,23 +1,23 @@
from typing import Dict, List from typing import Dict, List
from fastapi import WebSocket from fastapi import WebSocket
class WebSocketManager: class WebSocketManager:
def __init__(self): def __init__(self):
self.active_connections: Dict[int, List[WebSocket]] = {} self.active_connections: Dict[int, List[WebSocket]] = {}
async def connect(self, user_id: int, websocket: WebSocket): async def connect(self, user_id: int, websocket: WebSocket):
await websocket.accept() await websocket.accept()
if user_id not in self.active_connections: if user_id not in self.active_connections:
self.active_connections[user_id] = [] self.active_connections[user_id] = []
self.active_connections[user_id].append(websocket) self.active_connections[user_id].append(websocket)
def disconnect(self, user_id: int, websocket: WebSocket): def disconnect(self, user_id: int, websocket: WebSocket):
if user_id in self.active_connections: if user_id in self.active_connections:
self.active_connections[user_id].remove(websocket) self.active_connections[user_id].remove(websocket)
if not self.active_connections[user_id]: if not self.active_connections[user_id]:
del self.active_connections[user_id] del self.active_connections[user_id]
async def send_to_user(self, user_id: int, message: dict): async def send_to_user(self, user_id: int, message: dict):
if user_id in self.active_connections: if user_id in self.active_connections:
for connection in self.active_connections[user_id]: for connection in self.active_connections[user_id]:
await connection.send_json(message) await connection.send_json(message)

View File

@ -1,56 +1,56 @@
from datetime import date, datetime from datetime import date, datetime
from typing import List, Optional from typing import List, Optional
from sqlmodel import Field, SQLModel, JSON, Column from sqlmodel import Field, SQLModel, JSON, Column
import enum import enum
class User(SQLModel, table=True): class User(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
email: str = Field(unique=True, index=True) email: str = Field(unique=True, index=True)
username: str = Field(unique=True, index=True) username: str = Field(unique=True, index=True)
hashed_password: str hashed_password: str
first_name: Optional[str] = None first_name: Optional[str] = None
last_name: Optional[str] = None last_name: Optional[str] = None
is_active: bool = Field(default=True) is_active: bool = Field(default=True)
is_superuser: bool = Field(default=False) is_superuser: bool = Field(default=False)
created_at: datetime = Field(default_factory=datetime.utcnow) created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow}) updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
class UserProfile(SQLModel, table=True): class UserProfile(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id", unique=True) user_id: int = Field(foreign_key="user.id", unique=True)
encrypted_openai_api_key: Optional[str] = None encrypted_openai_api_key: Optional[str] = None
default_ticker: str = Field(default="SPY") default_ticker: str = Field(default="SPY")
preferred_research_depth: int = Field(default=3) preferred_research_depth: int = Field(default=3)
preferred_shallow_thinker: str = Field(default="gpt-4o-mini") preferred_shallow_thinker: str = Field(default="gpt-4o-mini")
preferred_deep_thinker: str = Field(default="gpt-4o") preferred_deep_thinker: str = Field(default="gpt-4o")
created_at: datetime = Field(default_factory=datetime.utcnow) created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow}) updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
class AnalysisStatus(str, enum.Enum): class AnalysisStatus(str, enum.Enum):
PENDING = "pending" PENDING = "pending"
RUNNING = "running" RUNNING = "running"
COMPLETED = "completed" COMPLETED = "completed"
FAILED = "failed" FAILED = "failed"
CANCELLED = "cancelled" CANCELLED = "cancelled"
class AnalysisSession(SQLModel, table=True): class AnalysisSession(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id") user_id: int = Field(foreign_key="user.id")
ticker: str ticker: str
analysis_date: date analysis_date: date
analysts_selected: List[str] = Field(sa_column=Column(JSON)) analysts_selected: List[str] = Field(sa_column=Column(JSON))
research_depth: int research_depth: int
llm_provider: str llm_provider: str
backend_url: str backend_url: str
shallow_thinker: str shallow_thinker: str
deep_thinker: str deep_thinker: str
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING) status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
final_report: Optional[str] = None final_report: Optional[str] = None
error_message: Optional[str] = None error_message: Optional[str] = None
created_at: datetime = Field(default_factory=datetime.utcnow) created_at: datetime = Field(default_factory=datetime.utcnow)
started_at: Optional[datetime] = None started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None completed_at: Optional[datetime] = None

View File

@ -1,48 +1,48 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Optional, List from typing import Generic, TypeVar, Optional, List
from sqlmodel import SQLModel from sqlmodel import SQLModel
from app.core.schemas.user import UserCreate, UserUpdate from app.core.schemas.user import UserCreate, UserUpdate
from app.domain.models import User from app.domain.models import User
ModelType = TypeVar("ModelType", bound=SQLModel) ModelType = TypeVar("ModelType", bound=SQLModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel) CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel)
class IRepository(Generic[ModelType], ABC): class IRepository(Generic[ModelType], ABC):
@abstractmethod @abstractmethod
def get(self, id: int) -> Optional[ModelType]: def get(self, id: int) -> Optional[ModelType]:
pass pass
@abstractmethod @abstractmethod
def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]: def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
pass pass
@abstractmethod @abstractmethod
def create(self, *, obj_in: CreateSchemaType) -> ModelType: def create(self, *, obj_in: CreateSchemaType) -> ModelType:
pass pass
@abstractmethod @abstractmethod
def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType: def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
pass pass
@abstractmethod @abstractmethod
def remove(self, *, id: int) -> ModelType: def remove(self, *, id: int) -> ModelType:
pass pass
class IUserRepository(IRepository[User], ABC): class IUserRepository(IRepository[User], ABC):
@abstractmethod @abstractmethod
def get_by_email(self, *, email: str) -> Optional[User]: def get_by_email(self, *, email: str) -> Optional[User]:
pass pass
@abstractmethod @abstractmethod
def create(self, *, obj_in: UserCreate) -> User: def create(self, *, obj_in: UserCreate) -> User:
pass pass
@abstractmethod @abstractmethod
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User: def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
pass pass
@abstractmethod @abstractmethod
def is_superuser(self, *, user: User) -> bool: def is_superuser(self, *, user: User) -> bool:
pass pass

View File

@ -1,9 +1,9 @@
from sqlmodel import create_engine, Session from sqlmodel import create_engine, Session
from app.core.config import settings from app.core.config import settings
engine = create_engine(settings.DATABASE_URL, echo=True, connect_args={"check_same_thread": False}) engine = create_engine(settings.DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
def get_db(): def get_db():
with Session(engine) as session: with Session(engine) as session:
yield session yield session

View File

@ -1,53 +1,53 @@
from typing import Optional from typing import Optional
from sqlmodel import Session, select from sqlmodel import Session, select
from app.domain.models import User from app.domain.models import User
from app.core.schemas.user import UserCreate, UserUpdate from app.core.schemas.user import UserCreate, UserUpdate
from app.domain.repositories import IUserRepository from app.domain.repositories import IUserRepository
from app.core.security import get_password_hash from app.core.security import get_password_hash
class UserRepository(IUserRepository): class UserRepository(IUserRepository):
def __init__(self, db: Session): def __init__(self, db: Session):
self.db = db self.db = db
def get(self, id: int) -> Optional[User]: def get(self, id: int) -> Optional[User]:
return self.db.get(User, id) return self.db.get(User, id)
def get_by_email(self, *, email: str) -> Optional[User]: def get_by_email(self, *, email: str) -> Optional[User]:
statement = select(User).where(User.email == email) statement = select(User).where(User.email == email)
return self.db.exec(statement).first() return self.db.exec(statement).first()
def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]: def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]:
statement = select(User).offset(skip).limit(limit) statement = select(User).offset(skip).limit(limit)
return self.db.exec(statement).all() return self.db.exec(statement).all()
def create(self, *, obj_in: UserCreate) -> User: def create(self, *, obj_in: UserCreate) -> User:
db_obj = User( db_obj = User(
email=obj_in.email, email=obj_in.email,
username=obj_in.username, username=obj_in.username,
hashed_password=get_password_hash(obj_in.password), hashed_password=get_password_hash(obj_in.password),
first_name=obj_in.first_name, first_name=obj_in.first_name,
last_name=obj_in.last_name, last_name=obj_in.last_name,
) )
self.db.add(db_obj) self.db.add(db_obj)
self.db.commit() self.db.commit()
self.db.refresh(db_obj) self.db.refresh(db_obj)
return db_obj return db_obj
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User: def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
update_data = obj_in.dict(exclude_unset=True) update_data = obj_in.dict(exclude_unset=True)
for field, value in update_data.items(): for field, value in update_data.items():
setattr(db_obj, field, value) setattr(db_obj, field, value)
self.db.add(db_obj) self.db.add(db_obj)
self.db.commit() self.db.commit()
self.db.refresh(db_obj) self.db.refresh(db_obj)
return db_obj return db_obj
def remove(self, *, id: int) -> User: def remove(self, *, id: int) -> User:
db_obj = self.db.get(User, id) db_obj = self.db.get(User, id)
self.db.delete(db_obj) self.db.delete(db_obj)
self.db.commit() self.db.commit()
return db_obj return db_obj
def is_superuser(self, *, user: User) -> bool: def is_superuser(self, *, user: User) -> bool:
return user.is_superuser return user.is_superuser

View File

@ -1,36 +1,36 @@
import sys import sys
import os import os
from fastapi import FastAPI from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
# Add project root to path to allow importing tradingagents # Add project root to path to allow importing tradingagents
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
from app.api.router import api_router from app.api.router import api_router
from app.core.config import settings from app.core.config import settings
from app.infrastructure.database import engine from app.infrastructure.database import engine
from sqlmodel import SQLModel from sqlmodel import SQLModel
def create_tables(): def create_tables():
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)
app = FastAPI( app = FastAPI(
title=settings.PROJECT_NAME, title=settings.PROJECT_NAME,
openapi_url=f"{settings.API_V1_STR}/openapi.json" openapi_url=f"{settings.API_V1_STR}/openapi.json"
) )
@app.on_event("startup") @app.on_event("startup")
def on_startup(): def on_startup():
create_tables() create_tables()
# Set all CORS enabled origins # Set all CORS enabled origins
if settings.CORS_ALLOWED_ORIGINS: if settings.CORS_ALLOWED_ORIGINS:
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS], allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS],
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
app.include_router(api_router, prefix=settings.API_V1_STR) app.include_router(api_router, prefix=settings.API_V1_STR)

2
backend/.gitignore vendored
View File

@ -1,2 +1,2 @@
.env .env
wallet/ wallet/

View File

@ -1,247 +1,299 @@
from sqlmodel import Session import sys
from analysis.domain.repository.analysis_repo import IAnalysisRepository import os
from ulid import ULID sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))
from analysis.domain.analysis import Analysis as AnalysisVO
from analysis.interface.dto import TradingAnalysisRequest, AnalysisProgressUpdate import logging
from fastapi import HTTPException, status, BackgroundTasks from sqlmodel import Session
import asyncio from analysis.domain.repository.analysis_repo import IAnalysisRepository
from datetime import datetime from ulid import ULID
from analysis.domain.analysis import Analysis as AnalysisVO
from tradingagents.graph.trading_graph import TradingAgentsGraph from analysis.interface.dto import TradingAnalysisRequest, AnalysisProgressUpdate
from tradingagents.default_config import DEFAULT_CONFIG from fastapi import HTTPException, status, BackgroundTasks
import asyncio
from datetime import datetime
class AnalysisService: from tradingagents.graph.trading_graph import TradingAgentsGraph
def __init__( from tradingagents.default_config import DEFAULT_CONFIG
self, from analysis.application.websocket_manager import WebSocketManager
analysis_repo: IAnalysisRepository, from analysis.infra.db_models.analysis import AnalysisStatus
session: Session,
ulid: ULID logger = logging.getLogger(__name__)
):
self.analysis_repo = analysis_repo class AnalysisService:
self.session = session def __init__(
self.ulid = ulid self,
analysis_repo: IAnalysisRepository,
def get_analysis_list( session: Session,
self, ulid: ULID,
member_id: str websocket_manager: WebSocketManager
) -> list[AnalysisVO]: ):
analyses = self.analysis_repo.find_by_member_id(member_id) self.analysis_repo = analysis_repo
if not analyses: self.session = session
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found") self.ulid = ulid
return analyses self.websocket_manager = websocket_manager
def get_analysis_by_id( def get_analysis_list(
self, self,
analysis_id: str, member_id: str
member_id: str ) -> list[AnalysisVO]:
) -> AnalysisVO: analyses = self.analysis_repo.find_by_member_id(member_id)
analysis = self.analysis_repo.find_by_id(analysis_id) if not analyses:
if not analysis: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found") return analyses
if analysis.member_id != member_id: def get_analysis_by_id(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied") self,
analysis_id: str,
return analysis member_id: str
) -> AnalysisVO:
def create_analysis( analysis = self.analysis_repo.find_by_id(analysis_id)
self, if not analysis:
member_id: str, raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
request: TradingAnalysisRequest,
background_tasks: BackgroundTasks if analysis.member_id != member_id:
) -> AnalysisVO: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
# 분석 요청 생성
analysis_id = self.ulid.generate() return analysis
now = datetime.now()
def get_analysis_sessions_by_member(
analysis_vo = AnalysisVO( self,
id=analysis_id, member_id: str
member_id=member_id, ) -> list[AnalysisVO]:
ticker=request.ticker, analyses = self.analysis_repo.find_by_member_id(member_id)
analysis_date=request.analysis_date, if not analyses:
analysts_selected=[analyst.value for analyst in request.analysts], raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
research_depth=request.research_depth, return analyses
llm_provider=request.llm_provider,
backend_url=request.backend_url, def create_analysis(
shallow_thinker=request.shallow_thinker, self,
deep_thinker=request.deep_thinker, member_id: str,
status="pending", request: TradingAnalysisRequest,
created_at=now, background_tasks: BackgroundTasks
updated_at=now ) -> AnalysisVO:
) # 분석 요청 생성
analysis_id = self.ulid.generate()
saved_analysis = self.analysis_repo.save(analysis_vo) now = datetime.now()
self.session.commit()
analysis_vo = AnalysisVO(
# 백그라운드에서 분석 실행 id=analysis_id,
background_tasks.add_task(self._run_analysis, saved_analysis.id) member_id=member_id,
ticker=request.ticker,
return saved_analysis analysis_date=request.analysis_date,
analysts_selected=[analyst.value for analyst in request.analysts],
async def _run_analysis(self, analysis_id: str): research_depth=request.research_depth,
"""백그라운드에서 실제 분석을 실행하는 메서드""" llm_provider=request.llm_provider,
try: backend_url=request.backend_url,
# 분석 상태를 RUNNING으로 변경 shallow_thinker=request.shallow_thinker,
analysis = self.analysis_repo.find_by_id(analysis_id) deep_thinker=request.deep_thinker,
if analysis: status=AnalysisStatus.PENDING,
analysis.status = "running" created_at=now,
analysis.updated_at = datetime.now() updated_at=now
self.analysis_repo.update(analysis) )
self.session.commit()
saved_analysis = self.analysis_repo.save(analysis_vo)
# 분석 정보 조회 if not saved_analysis:
analysis = self.analysis_repo.find_by_id(analysis_id) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to save analysis")
if not analysis:
return self.session.commit()
# TradingAgentsGraph 설정 및 실행 # Register analysis with websocket manager
config = self._create_config(analysis) self.websocket_manager.register_analysis(saved_analysis.id, member_id)
# 분석 실행 (실제 구현) # 백그라운드에서 분석 실행
await self._execute_trading_analysis(analysis_id, analysis, config) background_tasks.add_task(self._run_analysis, saved_analysis.id)
# 분석 완료 상태로 변경 return saved_analysis
analysis = self.analysis_repo.find_by_id(analysis_id)
if analysis: async def _run_analysis(self, analysis_id: str):
analysis.status = "completed" """백그라운드에서 실제 분석을 실행하는 메서드"""
analysis.completed_at = datetime.now() try:
analysis.updated_at = datetime.now() analysis = AnalysisVO(
self.analysis_repo.update(analysis) id=analysis_id,
self.session.commit() status=AnalysisStatus.RUNNING,
updated_at=datetime.now()
except Exception as e: )
# 에러 발생 시 실패 상태로 변경
analysis = self.analysis_repo.find_by_id(analysis_id) analysis = self.analysis_repo.update(analysis)
if analysis: if not analysis:
analysis.status = "failed" raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
analysis.error_message = str(e)
analysis.completed_at = datetime.now() await self.websocket_manager.send_analysis_update(
analysis.updated_at = datetime.now() analysis_id=analysis_id,
self.analysis_repo.update(analysis) update_type="status_changed",
self.session.commit() data={"status": "running", "message": "Analysis started"}
)
def _create_config(self, analysis: AnalysisVO) -> dict:
"""분석 설정을 생성하는 메서드"""
config = DEFAULT_CONFIG.copy() if DEFAULT_CONFIG else {}
config.update({ # TradingAgentsGraph 설정 및 실행
"max_debate_rounds": analysis.research_depth, if analysis:
"max_risk_discuss_rounds": analysis.research_depth, config = self._create_config(analysis)
"quick_think_llm": analysis.shallow_thinker,
"deep_think_llm": analysis.deep_thinker, # 분석 실행 (실제 구현)
"backend_url": analysis.backend_url, await self._execute_trading_analysis(analysis_id, analysis, config)
"llm_provider": analysis.llm_provider.lower(),
}) # 완료 상태로 업데이트
return config completed_analysis = AnalysisVO(
id=analysis_id,
async def _execute_trading_analysis(self, analysis_id: str, analysis: AnalysisVO, config: dict): status=AnalysisStatus.COMPLETED,
"""실제 TradingAgentsGraph를 실행하는 메서드""" completed_at=datetime.now(),
try: updated_at=datetime.now()
# TradingAgentsGraph 초기화 )
graph = TradingAgentsGraph( self.analysis_repo.update(completed_analysis)
analysis.analysts_selected, self.session.commit()
config=config,
debug=True
) except Exception as e:
now = datetime.now()
# 초기 상태 생성 updates = AnalysisVO(
init_agent_state = graph.propagator.create_initial_state( status=AnalysisStatus.FAILED,
analysis.ticker, error_message=str(e),
analysis.analysis_date completed_at = now,
) updated_at = now
args = graph.propagator.get_graph_args() )
# 분석 실행 및 결과 처리 self.analysis_repo.update(updates)
trace = [] self.session.commit()
async for chunk in graph.graph.astream(init_agent_state, **args):
trace.append(chunk)
def _create_config(self, analysis: AnalysisVO) -> dict:
# 실시간으로 분석 결과 업데이트 """분석 설정을 생성하는 메서드"""
await self._process_analysis_chunk(analysis_id, chunk) config = {}
config.update({
# 최종 결과 처리 "max_debate_rounds": analysis.research_depth,
if trace: "max_risk_discuss_rounds": analysis.research_depth,
final_state = trace[-1] "quick_think_llm": analysis.shallow_thinker,
final_decision = graph.process_signal(final_state.get("final_trade_decision", "")) "deep_think_llm": analysis.deep_thinker,
"backend_url": analysis.backend_url,
# 최종 보고서 생성 "llm_provider": analysis.llm_provider.lower(),
final_report = self._generate_final_report(final_state) })
return config
# 최종 결과 저장
self.analysis_repo.update(analysis_id, { async def _execute_trading_analysis(self, analysis_id: str, analysis: AnalysisVO, config: dict):
"final_trade_decision": final_decision, """실제 TradingAgentsGraph를 실행하는 메서드"""
"final_report": final_report try:
}) logger.info(f"Starting trading analysis for {analysis_id} with ticker {analysis.ticker}")
self.session.commit() logger.info(f"Analysts selected: {analysis.analysts_selected}")
logger.info(f"Config: {config}")
except Exception as e:
raise Exception(f"Analysis execution failed: {str(e)}") # TradingAgentsGraph 초기화
graph = TradingAgentsGraph(
async def _process_analysis_chunk(self, analysis_id: str, chunk: dict): analysis.analysts_selected,
"""분석 중간 결과를 처리하고 저장하는 메서드""" config=config,
updates = {} debug=True
)
# 개별 분석가 보고서 업데이트 logger.info("TradingAgentsGraph initialized successfully")
if "market_report" in chunk and chunk["market_report"]:
updates["market_report"] = chunk["market_report"] # 초기 상태 생성
init_agent_state = graph.propagator.create_initial_state(
if "sentiment_report" in chunk and chunk["sentiment_report"]: analysis.ticker,
updates["sentiment_report"] = chunk["sentiment_report"] analysis.analysis_date
)
if "news_report" in chunk and chunk["news_report"]: args = graph.propagator.get_graph_args()
updates["news_report"] = chunk["news_report"]
# 분석 실행 및 결과 처리
if "fundamentals_report" in chunk and chunk["fundamentals_report"]: logger.info("Starting graph execution...")
updates["fundamentals_report"] = chunk["fundamentals_report"] trace = []
chunk_count = 0
# 팀별 의사결정 과정 업데이트 async for chunk in graph.graph.astream(init_agent_state, **args):
if "investment_debate_state" in chunk and chunk["investment_debate_state"]: chunk_count += 1
updates["investment_debate_state"] = chunk["investment_debate_state"] logger.info(f"Processing chunk {chunk_count}: {list(chunk.keys()) if chunk else 'Empty chunk'}")
trace.append(chunk)
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
updates["trader_investment_plan"] = chunk["trader_investment_plan"] # 실시간으로 분석 결과 업데이트
await self._process_analysis_chunk(analysis_id, chunk)
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
updates["risk_debate_state"] = chunk["risk_debate_state"] # 최종 결과 처리
if trace:
# 업데이트가 있는 경우 저장 final_state = trace[-1]
if updates: final_decision = graph.process_signal(final_state.get("final_trade_decision", ""))
self.analysis_repo.update(analysis_id, updates)
self.session.commit() # 최종 보고서 생성
final_report = self._generate_final_report(final_state)
def _generate_final_report(self, final_state: dict) -> str: analysis.final_trade_decision = final_decision
"""최종 통합 보고서를 생성하는 메서드""" analysis.final_report = final_report
report_parts = []
# 최종 결과 저장
# Analyst Team Reports updates = AnalysisVO(
if any(final_state.get(section) for section in ["market_report", "sentiment_report", "news_report", "fundamentals_report"]): id=analysis_id,
report_parts.append("## Analyst Team Reports") final_trade_decision=final_decision,
final_report=final_report
if final_state.get("market_report"): )
report_parts.append(f"### Market Analysis\n{final_state['market_report']}") self.analysis_repo.update(updates)
if final_state.get("sentiment_report"):
report_parts.append(f"### Social Sentiment\n{final_state['sentiment_report']}") self.session.commit()
if final_state.get("news_report"):
report_parts.append(f"### News Analysis\n{final_state['news_report']}") except Exception as e:
if final_state.get("fundamentals_report"): raise Exception(f"Analysis execution failed: {str(e)}")
report_parts.append(f"### Fundamentals Analysis\n{final_state['fundamentals_report']}")
async def _process_analysis_chunk(self, analysis_id: str, chunk: dict):
# Research Team Reports """분석 중간 결과를 처리하고 저장하는 메서드"""
if final_state.get("investment_debate_state"): updates = {}
report_parts.append("## Research Team Decision")
debate_state = final_state["investment_debate_state"] # 개별 분석가 보고서 업데이트
if debate_state.get("judge_decision"): if "market_report" in chunk and chunk["market_report"]:
report_parts.append(f"{debate_state['judge_decision']}") updates["market_report"] = chunk["market_report"]
# Trading Team Reports if "sentiment_report" in chunk and chunk["sentiment_report"]:
if final_state.get("trader_investment_plan"): updates["sentiment_report"] = chunk["sentiment_report"]
report_parts.append("## Trading Team Plan")
report_parts.append(f"{final_state['trader_investment_plan']}") if "news_report" in chunk and chunk["news_report"]:
updates["news_report"] = chunk["news_report"]
# Portfolio Management Decision
if final_state.get("risk_debate_state") and final_state["risk_debate_state"].get("judge_decision"): if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
report_parts.append("## Portfolio Management Decision") updates["fundamentals_report"] = chunk["fundamentals_report"]
report_parts.append(f"{final_state['risk_debate_state']['judge_decision']}")
# 팀별 의사결정 과정 업데이트
return "\n\n".join(report_parts) if report_parts else "No analysis results available." if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
updates["investment_debate_state"] = chunk["investment_debate_state"]
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
updates["trader_investment_plan"] = chunk["trader_investment_plan"]
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
updates["risk_debate_state"] = chunk["risk_debate_state"]
# 업데이트가 있는 경우 저장
if updates:
# analysis_id를 포함한 AnalysisVO 객체 생성
updates["id"] = analysis_id
updates_vo = AnalysisVO(**updates)
self.analysis_repo.update(updates_vo)
self.session.commit()
def _generate_final_report(self, final_state: dict) -> str:
"""최종 통합 보고서를 생성하는 메서드"""
report_parts = []
# Analyst Team Reports
if any(final_state.get(section) for section in ["market_report", "sentiment_report", "news_report", "fundamentals_report"]):
report_parts.append("## Analyst Team Reports")
if final_state.get("market_report"):
report_parts.append(f"### Market Analysis\n{final_state['market_report']}")
if final_state.get("sentiment_report"):
report_parts.append(f"### Social Sentiment\n{final_state['sentiment_report']}")
if final_state.get("news_report"):
report_parts.append(f"### News Analysis\n{final_state['news_report']}")
if final_state.get("fundamentals_report"):
report_parts.append(f"### Fundamentals Analysis\n{final_state['fundamentals_report']}")
# Research Team Reports
if final_state.get("investment_debate_state"):
report_parts.append("## Research Team Decision")
debate_state = final_state["investment_debate_state"]
if debate_state.get("judge_decision"):
report_parts.append(f"{debate_state['judge_decision']}")
# Trading Team Reports
if final_state.get("trader_investment_plan"):
report_parts.append("## Trading Team Plan")
report_parts.append(f"{final_state['trader_investment_plan']}")
# Portfolio Management Decision
if final_state.get("risk_debate_state") and final_state["risk_debate_state"].get("judge_decision"):
report_parts.append("## Portfolio Management Decision")
report_parts.append(f"{final_state['risk_debate_state']['judge_decision']}")
return "\n\n".join(report_parts) if report_parts else "No analysis results available."

View File

@ -0,0 +1,65 @@
from typing import Dict, Set
from fastapi import WebSocket
import json
from datetime import datetime
class WebSocketManager:
def __init__(self):
# Store active connections by member_id
self.active_connections: Dict[str, Set[WebSocket]] = {}
# Store analysis_id to member_id mapping
self.analysis_member_map: Dict[str, str] = {}
async def connect(self, websocket: WebSocket, member_id: str):
await websocket.accept()
if member_id not in self.active_connections:
self.active_connections[member_id] = set()
self.active_connections[member_id].add(websocket)
def disconnect(self, websocket: WebSocket, member_id: str):
if member_id in self.active_connections:
self.active_connections[member_id].discard(websocket)
if not self.active_connections[member_id]:
del self.active_connections[member_id]
def register_analysis(self, analysis_id: str, member_id: str):
"""Register which member owns which analysis"""
self.analysis_member_map[analysis_id] = member_id
async def send_analysis_update(self, analysis_id: str, update_type: str, data: dict):
"""Send analysis update to the member who owns the analysis"""
member_id = self.analysis_member_map.get(analysis_id)
if not member_id:
return
message = {
"type": "analysis_update",
"analysis_id": analysis_id,
"update_type": update_type,
"data": data,
"timestamp": datetime.now().isoformat()
}
await self.send_to_member(member_id, message)
async def send_to_member(self, member_id: str, message: dict|str):
"""Send message to all connections of a specific member"""
if member_id not in self.active_connections:
return
dead_connections = set()
for connection in self.active_connections[member_id]:
try:
if isinstance(message, dict):
await connection.send_json(message)
else:
await connection.send_text(message)
except Exception:
dead_connections.add(connection)
# Clean up dead connections
for connection in dead_connections:
self.disconnect(connection, member_id)

View File

@ -1,19 +1,20 @@
from pydantic import BaseModel from pydantic import BaseModel, field_validator
from datetime import datetime from datetime import datetime, date
from typing import List, Dict from typing import List, Dict, Union
from analysis.infra.db_models.analysis import AnalysisStatus
class Analysis(BaseModel): class Analysis(BaseModel):
id: str | None = None id: str
member_id: str member_id: str | None = None
ticker: str ticker: str | None = None
analysis_date: str analysis_date: date | None = None
analysts_selected: List[str] = [] analysts_selected: list[str] = []
research_depth: int = 3 research_depth: int = 3
llm_provider: str = "openai" llm_provider: str = "openai"
backend_url: str = "https://api.openai.com/v1" backend_url: str = "https://api.openai.com/v1"
shallow_thinker: str = "gpt-4o-mini" shallow_thinker: str = "gpt-4o"
deep_thinker: str = "gpt-4o" deep_thinker: str = "o3"
status: str status: AnalysisStatus = AnalysisStatus.PENDING
# 개별 분석가 리포트들 # 개별 분석가 리포트들
market_report: str | None = None market_report: str | None = None
@ -33,5 +34,5 @@ class Analysis(BaseModel):
# 실행 결과 정보 # 실행 결과 정보
error_message: str | None = None error_message: str | None = None
completed_at: datetime | None = None completed_at: datetime | None = None
created_at: datetime created_at: datetime | None = None
updated_at: datetime updated_at: datetime | None = None

View File

@ -1,20 +1,20 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from analysis.domain.analysis import Analysis as AnalysisVO from analysis.domain.analysis import Analysis as AnalysisVO
from analysis.interface.dto import TradingAnalysisRequest from analysis.interface.dto import TradingAnalysisRequest
class IAnalysisRepository(ABC): class IAnalysisRepository(ABC):
@abstractmethod @abstractmethod
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None: def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def find_by_id(self, analysis_id: str) -> AnalysisVO | None: def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def update(self, analysis: AnalysisVO) -> AnalysisVO | None: def update(self, analysis: AnalysisVO) -> AnalysisVO | None:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def save(self, analysis: AnalysisVO) -> AnalysisVO: def save(self, analysis: AnalysisVO) -> AnalysisVO:
raise NotImplementedError() raise NotImplementedError()

View File

@ -1,57 +1,57 @@
from datetime import datetime,date from datetime import datetime,date
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqlmodel import SQLModel, Field, JSON, Relationship from sqlmodel import SQLModel, Field, JSON, Relationship
import enum import enum
from sqlalchemy import Column from sqlalchemy import Column, Text
# TYPE_CHECKING을 사용해서 circular import 방지 # TYPE_CHECKING을 사용해서 circular import 방지
if TYPE_CHECKING: if TYPE_CHECKING:
from member.infra.db_models.member import Member from member.infra.db_models.member import Member
class AnalysisStatus(str, enum.Enum): class AnalysisStatus(str, enum.Enum):
PENDING = "pending" PENDING = "pending"
RUNNING = "running" RUNNING = "running"
COMPLETED = "completed" COMPLETED = "completed"
FAILED = "failed" FAILED = "failed"
CANCELLED = "cancelled" CANCELLED = "cancelled"
class Analysis(SQLModel, table=True): class Analysis(SQLModel, table=True):
__tablename__ = "analyses" __tablename__ = "analyses"
id: str = Field(default=None, max_length=36, primary_key=True) id: str = Field(default=None, max_length=36, primary_key=True)
# 기본 분석 설정 정보 # 기본 분석 설정 정보
ticker: str ticker: str
analysis_date: date analysis_date: date
analysts_selected: list[str] = Field(sa_column=Column(JSON)) analysts_selected: list[str] = Field(sa_column=Column(JSON))
research_depth: int research_depth: int
llm_provider: str llm_provider: str
backend_url: str backend_url: str
shallow_thinker: str shallow_thinker: str
deep_thinker: str deep_thinker: str
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING) status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
# 개별 분석가 리포트들 # 개별 분석가 리포트들
market_report: str | None = Field(default=None, description="Market Analyst 리포트") market_report: str | None = Field(default=None, sa_column=Column(Text), description="Market Analyst 리포트")
sentiment_report: str | None = Field(default=None, description="Social Analyst 리포트") sentiment_report: str | None = Field(default=None, sa_column=Column(Text), description="Social Analyst 리포트")
news_report: str | None = Field(default=None, description="News Analyst 리포트") news_report: str | None = Field(default=None, sa_column=Column(Text), description="News Analyst 리포트")
fundamentals_report: str | None = Field(default=None, description="Fundamentals Analyst 리포트") fundamentals_report: str | None = Field(default=None, sa_column=Column(Text), description="Fundamentals Analyst 리포트")
# 팀별 의사결정 과정 # 팀별 의사결정 과정
investment_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Research Team 토론 과정") investment_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Research Team 토론 과정")
trader_investment_plan: str | None = Field(default=None, description="Trading Team 계획") trader_investment_plan: str | None = Field(default=None, sa_column=Column(Text), description="Trading Team 계획")
risk_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Risk Management Team 토론 과정") risk_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Risk Management Team 토론 과정")
# 최종 결과물 # 최종 결과물
final_trade_decision: str | None = Field(default=None, description="최종 거래 결정") final_trade_decision: str | None = Field(default=None, sa_column=Column(Text), description="최종 거래 결정")
final_report: str | None = Field(default=None, description="전체 통합 리포트") final_report: str | None = Field(default=None, sa_column=Column(Text), description="전체 통합 리포트")
# 실행 결과 정보 # 실행 결과 정보
error_message: str | None = None error_message: str | None = Field(default=None, sa_column=Column(Text))
completed_at: datetime | None = None completed_at: datetime | None = None
created_at : datetime = Field(nullable=False) created_at : datetime = Field(nullable=False)
updated_at : datetime = Field(nullable=False) updated_at : datetime = Field(nullable=False)
# Foreign Key와 Relationship 설정 # Foreign Key와 Relationship 설정
member_id: str = Field(foreign_key="members.id") member_id: str = Field(foreign_key="members.id")
member: "Member" = Relationship(back_populates="analyses") member: "Member" = Relationship(back_populates="analyses")

View File

@ -1,80 +1,55 @@
from analysis.domain.repository.analysis_repo import IAnalysisRepository from analysis.domain.repository.analysis_repo import IAnalysisRepository
from sqlmodel import Session, select from sqlmodel import Session, select
from analysis.domain.analysis import Analysis as AnalysisVO from analysis.domain.analysis import Analysis as AnalysisVO
from analysis.infra.db_models.analysis import Analysis, AnalysisStatus from analysis.infra.db_models.analysis import Analysis, AnalysisStatus
from analysis.interface.dto import TradingAnalysisRequest from analysis.interface.dto import TradingAnalysisRequest
from utils.db_utils import row_to_dict from utils.db_utils import row_to_dict
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from datetime import datetime, date from datetime import datetime, date
class AnalysisRepository(IAnalysisRepository): class AnalysisRepository(IAnalysisRepository):
def __init__(self, session: Session): def __init__(self, session: Session):
self.session = session self.session = session
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None: def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
query = select(Analysis).where(Analysis.member_id == member_id) query = select(Analysis).where(Analysis.member_id == member_id)
analyses = self.session.exec(query).all() analyses = self.session.exec(query).all()
if not analyses: if not analyses:
return None return None
return [AnalysisVO(**row_to_dict(analysis)) for analysis in analyses] return [AnalysisVO(**row_to_dict(analysis)) for analysis in analyses]
def find_by_id(self, analysis_id: str) -> AnalysisVO | None: def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
analysis = self.session.get(Analysis, analysis_id) analysis = self.session.get(Analysis, analysis_id)
if not analysis: if not analysis:
return None return None
return AnalysisVO(**row_to_dict(analysis)) return AnalysisVO(**row_to_dict(analysis))
def save(self, analysis: AnalysisVO) -> AnalysisVO: def save(self, analysis: AnalysisVO) -> AnalysisVO:
new_analysis = Analysis( new_analysis = Analysis(
id=analysis.id, **analysis.model_dump()
member_id=analysis.member_id, )
ticker=analysis.ticker,
analysis_date=date.fromisoformat(analysis.analysis_date), self.session.add(new_analysis)
analysts_selected=analysis.analysts_selected, self.session.flush()
research_depth=analysis.research_depth, self.session.refresh(new_analysis)
llm_provider=analysis.llm_provider,
backend_url=analysis.backend_url, analysis.id = new_analysis.id
shallow_thinker=analysis.shallow_thinker, return analysis
deep_thinker=analysis.deep_thinker,
status=analysis.status, def update(self, analysis_vo: AnalysisVO) -> AnalysisVO | None:
market_report=analysis.market_report, analysis = self.session.get(Analysis, analysis_vo.id)
sentiment_report=analysis.sentiment_report, if not analysis:
news_report=analysis.news_report, return None
fundamentals_report=analysis.fundamentals_report,
investment_debate_state=analysis.investment_debate_state, # AnalysisVO의 데이터를 SQLModel 객체에 업데이트
trader_investment_plan=analysis.trader_investment_plan, analysis_data = analysis_vo.model_dump(exclude_unset=True)
risk_debate_state=analysis.risk_debate_state,
final_trade_decision=analysis.final_trade_decision, analysis.updated_at = datetime.now()
final_report=analysis.final_report, analysis.sqlmodel_update(analysis_data)
error_message=analysis.error_message,
completed_at=analysis.completed_at, self.session.flush()
created_at=analysis.created_at,
updated_at=analysis.updated_at
) return AnalysisVO(**row_to_dict(analysis))
self.session.add(new_analysis)
self.session.flush()
self.session.refresh(new_analysis)
analysis.id = new_analysis.id
return analysis
def update(self, analysis_vo: AnalysisVO) -> AnalysisVO | None:
analysis = self.session.get(Analysis, analysis_vo.id)
if not analysis:
return None
# AnalysisVO의 데이터를 SQLModel 객체에 업데이트
vo_data = analysis_vo.sqlmodel_dump(exclude_unset=True)
for key, value in vo_data.items():
if hasattr(analysis, key) and key != 'id': # id는 변경하지 않음
setattr(analysis, key, value)
analysis.updated_at = datetime.now()
self.session.add(analysis)
self.session.flush()
self.session.refresh(analysis)
return AnalysisVO(**row_to_dict(analysis))

View File

@ -1,108 +1,137 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException, status from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException, status, WebSocket, WebSocketDisconnect
from analysis.interface.dto import ( from analysis.interface.dto import (
AnalysisSessionResponse, AnalysisSessionResponse,
TradingAnalysisRequest, TradingAnalysisRequest,
AnalysisResultResponse AnalysisResultResponse
) )
from utils.auth import get_current_member, CurrentMember from utils.auth import get_current_member, CurrentMember
from dependency_injector.wiring import inject, Provide from dependency_injector.wiring import inject, Provide
from analysis.application.analysis_service import AnalysisService from analysis.application.analysis_service import AnalysisService
from utils.containers import Container from utils.containers import Container
from analysis.application.websocket_manager import WebSocketManager
router = APIRouter(prefix="/analysis", tags=["analysis"])
router = APIRouter(prefix="/analysis", tags=["analysis"])
@router.get("/", response_model=list[AnalysisSessionResponse])
@inject @router.get("/", response_model=list[AnalysisSessionResponse])
def get_analysis_list_for_member( @inject
current_member: Annotated[CurrentMember, Depends(get_current_member)], def get_analysis_list_for_member(
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])] current_member: Annotated[CurrentMember, Depends(get_current_member)],
): analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
""" ):
현재 로그인한 사용자의 모든 분석 세션 목록을 조회합니다. """
""" 현재 로그인한 사용자의 모든 분석 세션 목록을 조회합니다.
analyses = analysis_service.get_analysis_list(current_member.id) """
return [ analyses = analysis_service.get_analysis_list(current_member.id)
AnalysisSessionResponse( return [
id=analysis.id, AnalysisSessionResponse(
ticker=analysis.ticker, id=analysis.id,
status=analysis.status ticker=analysis.ticker,
) for analysis in analyses status=analysis.status
] ) for analysis in analyses
]
@router.post("/start", status_code=201, response_model=AnalysisSessionResponse)
@inject @router.post("/start", status_code=201, response_model=AnalysisSessionResponse)
def start_analysis_session( @inject
request: TradingAnalysisRequest, def start_analysis_session(
current_member: Annotated[CurrentMember, Depends(get_current_member)], request: TradingAnalysisRequest,
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])], current_member: Annotated[CurrentMember, Depends(get_current_member)],
background_tasks: BackgroundTasks analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])],
): background_tasks: BackgroundTasks
""" ):
새로운 분석 세션을 시작합니다. """
""" 새로운 분석 세션을 시작합니다.
try:
new_analysis = analysis_service.create_analysis(current_member.id, request, background_tasks) """
return AnalysisSessionResponse( try:
id=new_analysis.id, new_analysis = analysis_service.create_analysis(current_member.id, request, background_tasks)
ticker=new_analysis.ticker, return AnalysisSessionResponse(
status=new_analysis.status id=new_analysis.id,
) ticker=new_analysis.ticker,
except Exception as e: status=new_analysis.status
raise HTTPException( )
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, except Exception as e:
detail=f"Failed to start analysis: {str(e)}" raise HTTPException(
) status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to start analysis: {str(e)}"
@router.get("/{analysis_id}", response_model=AnalysisResultResponse) )
@inject
def get_analysis_result( @router.get("/{analysis_id}", response_model=AnalysisResultResponse)
analysis_id: str, @inject
current_member: Annotated[CurrentMember, Depends(get_current_member)], def get_analysis_result(
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])] analysis_id: str,
): current_member: Annotated[CurrentMember, Depends(get_current_member)],
""" analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
특정 분석 세션의 결과를 조회합니다. ):
""" """
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id) 특정 분석 세션의 결과를 조회합니다.
"""
return AnalysisResultResponse( analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
id=analysis.id,
ticker=analysis.ticker, return AnalysisResultResponse(
analysis_date=analysis.analysis_date, id=analysis.id,
status=analysis.status, ticker=analysis.ticker,
market_report=analysis.market_report, analysis_date=analysis.analysis_date.isoformat() if hasattr(analysis.analysis_date, 'isoformat') else str(analysis.analysis_date),
sentiment_report=analysis.sentiment_report, status=analysis.status,
news_report=analysis.news_report, market_report=analysis.market_report,
fundamentals_report=analysis.fundamentals_report, sentiment_report=analysis.sentiment_report,
investment_debate_state=analysis.investment_debate_state, news_report=analysis.news_report,
trader_investment_plan=analysis.trader_investment_plan, fundamentals_report=analysis.fundamentals_report,
risk_debate_state=analysis.risk_debate_state, investment_debate_state=analysis.investment_debate_state,
final_trade_decision=analysis.final_trade_decision, trader_investment_plan=analysis.trader_investment_plan,
final_report=analysis.final_report, risk_debate_state=analysis.risk_debate_state,
created_at=analysis.created_at.isoformat(), final_trade_decision=analysis.final_trade_decision,
completed_at=analysis.completed_at.isoformat() if analysis.completed_at else None, final_report=analysis.final_report,
error_message=analysis.error_message created_at=analysis.created_at.isoformat(),
) completed_at=analysis.completed_at.isoformat() if analysis.completed_at else None,
error_message=analysis.error_message
@router.get("/{analysis_id}/status") )
@inject
def get_analysis_status( @router.get("/{analysis_id}/status")
analysis_id: str, @inject
current_member: Annotated[CurrentMember, Depends(get_current_member)], def get_analysis_status(
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])] analysis_id: str,
): current_member: Annotated[CurrentMember, Depends(get_current_member)],
""" analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
분석 진행 상황을 조회합니다. ):
""" """
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id) 분석 진행 상황을 조회합니다.
"""
return { analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
"analysis_id": analysis.id,
"status": analysis.status, return {
"ticker": analysis.ticker, "analysis_id": analysis.id,
"analysis_date": analysis.analysis_date, "status": analysis.status,
"created_at": analysis.created_at.isoformat(), "ticker": analysis.ticker,
"updated_at": analysis.updated_at.isoformat(), "analysis_date": analysis.analysis_date,
"error_message": analysis.error_message "created_at": analysis.created_at.isoformat(),
} "updated_at": analysis.updated_at.isoformat(),
"error_message": analysis.error_message
}
@router.websocket("/ws")
@inject
async def websocket_endpoint(
websocket: WebSocket,
current_member: Annotated[CurrentMember, Depends(get_current_member)],
websocket_manager: Annotated[WebSocketManager, Depends(Provide[Container.websocket_manager])]
):
"""
WebSocket endpoint for real-time analysis updates
"""
try:
# Connect the websocket
await websocket_manager.connect(websocket, current_member.id)
try:
# Keep connection alive
while True:
# Wait for messages from client (like ping/pong)
data = await websocket.receive_text()
# Echo back for heartbeat
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
websocket_manager.disconnect(websocket, current_member.id)
except Exception as e:
await websocket.close(code=1011, reason=str(e))

View File

@ -1,52 +1,52 @@
from pydantic import BaseModel from pydantic import BaseModel
from datetime import date from datetime import date
from typing import List from typing import List
from analysis.infra.db_models.analysis import AnalysisStatus from analysis.infra.db_models.analysis import AnalysisStatus
from enum import Enum from enum import Enum
class AnalystType(str, Enum): class AnalystType(str, Enum):
MARKET = "market" MARKET = "market"
SOCIAL = "social" SOCIAL = "social"
NEWS = "news" NEWS = "news"
FUNDAMENTALS = "fundamentals" FUNDAMENTALS = "fundamentals"
class TradingAnalysisRequest(BaseModel): class TradingAnalysisRequest(BaseModel):
ticker: str ticker: str = "NVDA"
analysis_date: str analysis_date: str = "2025-07-07"
analysts: List[AnalystType] analysts: List[AnalystType] = [AnalystType.MARKET, AnalystType.SOCIAL, AnalystType.NEWS, AnalystType.FUNDAMENTALS]
research_depth: int = 3 research_depth: int = 3
llm_provider: str = "openai" llm_provider: str = "openai"
backend_url: str = "https://api.openai.com/v1" backend_url: str = "https://api.openai.com/v1"
shallow_thinker: str = "gpt-4o-mini" shallow_thinker: str = "gpt-4o-mini"
deep_thinker: str = "gpt-4o" deep_thinker: str = "gpt-4o-mini"
class AnalysisSessionResponse(BaseModel): class AnalysisSessionResponse(BaseModel):
id : str id : str
ticker : str ticker : str
status : AnalysisStatus status : AnalysisStatus
class AnalysisProgressUpdate(BaseModel): class AnalysisProgressUpdate(BaseModel):
analysis_id: str analysis_id: str
current_agent: str current_agent: str
status: str status: str
progress_percentage: float progress_percentage: float
current_report_section: str | None = None current_report_section: str | None = None
message: str | None = None message: str | None = None
class AnalysisResultResponse(BaseModel): class AnalysisResultResponse(BaseModel):
id: str id: str
ticker: str ticker: str
analysis_date: str analysis_date: str
status: AnalysisStatus status: AnalysisStatus
market_report: str | None = None market_report: str | None = None
sentiment_report: str | None = None sentiment_report: str | None = None
news_report: str | None = None news_report: str | None = None
fundamentals_report: str | None = None fundamentals_report: str | None = None
investment_debate_state: dict | None = None investment_debate_state: dict | None = None
trader_investment_plan: str | None = None trader_investment_plan: str | None = None
risk_debate_state: dict | None = None risk_debate_state: dict | None = None
final_trade_decision: str | None = None final_trade_decision: str | None = None
final_report: str | None = None final_report: str | None = None
created_at: str created_at: str
completed_at: str | None = None completed_at: str | None = None
error_message: str | None = None error_message: str | None = None

View File

@ -1,20 +1,20 @@
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=".env", env_file=".env",
env_file_encoding="utf-8", env_file_encoding="utf-8",
) )
# MySQL 데이터베이스 설정 # MySQL 데이터베이스 설정
DB_HOST: str DB_HOST: str
DB_PORT: int DB_PORT: int
DB_USER: str DB_USER: str
DB_PASSWORD: str DB_PASSWORD: str
DB_NAME: str DB_NAME: str
SECRET_KEY: str SECRET_KEY: str
@lru_cache @lru_cache
def get_settings(): def get_settings():
return Settings() return Settings()

View File

@ -1,20 +1,30 @@
from fastapi import FastAPI from fastapi import FastAPI
from utils.database import create_db_and_tables from utils.database import create_db_and_tables
from utils.containers import Container from utils.containers import Container
from analysis.interface.controller.analysis_controller import router as analysis_router from analysis.interface.controller.analysis_controller import router as analysis_router
from member.interface.controller.member_controller import router as member_router from member.interface.controller.member_controller import router as member_router
import logging
# 로깅 설정
app = FastAPI() logging.basicConfig(
app.container = Container() level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
app.include_router(analysis_router) handlers=[
app.include_router(member_router) logging.StreamHandler(), # 콘솔 출력
]
)
@app.on_event("startup")
def startup_db_client():
app = FastAPI()
app.container = Container()
app.include_router(analysis_router)
app.include_router(member_router)
@app.on_event("startup")
def startup_db_client():
create_db_and_tables() create_db_and_tables()

View File

@ -1,97 +1,97 @@
from sqlmodel import Session from sqlmodel import Session
from utils.crypto import Crypto from utils.crypto import Crypto
from member.domain.repository.member_repo import IMemberRepository from member.domain.repository.member_repo import IMemberRepository
from utils.auth import Role from utils.auth import Role
from member.domain.member import Member as MemberVO from member.domain.member import Member as MemberVO
from fastapi import HTTPException, status from fastapi import HTTPException, status
from datetime import datetime from datetime import datetime
from utils.auth import create_access_token from utils.auth import create_access_token
from ulid import ULID from ulid import ULID
from analysis.domain.analysis import Analysis as AnalysisVO from analysis.domain.analysis import Analysis as AnalysisVO
class MemberService: class MemberService:
def __init__( def __init__(
self, self,
member_repo: IMemberRepository, member_repo: IMemberRepository,
crypto: Crypto, crypto: Crypto,
db_session: Session, session: Session,
ulid: ULID ulid: ULID
): ):
self.member_repo = member_repo self.member_repo = member_repo
self.crypto = crypto self.crypto = crypto
self.db_session = db_session self.db_session = session
self.ulid = ulid self.ulid = ulid
def create_member( def create_member(
self, self,
name: str, name: str,
email: str, email: str,
password: str, password: str,
role: Role role: Role
): ):
try: try:
if self.member_repo.find_by_email(email): if self.member_repo.find_by_email(email):
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already exists") raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already exists")
except Exception as e: except Exception as e:
self.db_session.rollback() self.db_session.rollback()
raise e raise e
now = datetime.now() now = datetime.now()
member_vo = MemberVO( member_vo = MemberVO(
id=self.ulid.generate(), id=self.ulid.generate(),
name=name, name=name,
email=email, email=email,
password=self.crypto.encrypt(password), password=self.crypto.encrypt(password),
created_at=now, created_at=now,
updated_at=now, updated_at=now,
role=role role=role
) )
saved_member = self.member_repo.save(member_vo) saved_member = self.member_repo.save(member_vo)
self.db_session.commit() self.db_session.commit()
return saved_member return saved_member
def get_members( def get_members(
self, self,
page: int, page: int,
items_per_page: int items_per_page: int
)->tuple[int, list[MemberVO]] : )->tuple[int, list[MemberVO]] :
return self.member_repo.get_members(page, items_per_page) return self.member_repo.get_members(page, items_per_page)
def get_member( def get_member(
self, self,
id: str id: str
)->MemberVO | None: )->MemberVO | None:
member = self.member_repo.find_by_id(id) member = self.member_repo.find_by_id(id)
if not member: if not member:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
return member return member
def login( def login(
self, self,
email: str, email: str,
password: str password: str
): ):
member = self.member_repo.find_by_email(email) member = self.member_repo.find_by_email(email)
if not member: if not member:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
if not self.crypto.verify(password, member.password): if not self.crypto.verify(password, member.password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
access_token = create_access_token( access_token = create_access_token(
payload={"member_id": member.id, "role": member.role}, payload={"member_id": member.id, "role": member.role},
role=member.role, role=member.role,
) )
return access_token return access_token
def get_analysis_sessions_by_member( def get_analysis_sessions_by_member(
self, self,
member_id: str member_id: str
)->list[AnalysisVO]: )->list[AnalysisVO]:
analysis_sessions = self.member_repo.find_analysis_sessions_by_member(member_id) analysis_sessions = self.member_repo.find_analysis_sessions_by_member(member_id)
return analysis_sessions return analysis_sessions

View File

@ -1,12 +1,12 @@
from pydantic import BaseModel from pydantic import BaseModel
from utils.auth import Role from utils.auth import Role
from datetime import datetime from datetime import datetime
class Member(BaseModel): class Member(BaseModel):
id: str | None = None id: str | None = None
name: str name: str
email: str email: str
password: str password: str
role: Role role: Role
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime

View File

@ -1,4 +1,4 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
class IMemberRepository(ABC): class IMemberRepository(ABC):
pass pass

View File

@ -1,24 +1,24 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from member.domain.member import Member as MemberVO from member.domain.member import Member as MemberVO
from analysis.domain.analysis import Analysis as AnalysisVO from analysis.domain.analysis import Analysis as AnalysisVO
class IMemberRepository(ABC): class IMemberRepository(ABC):
@abstractmethod @abstractmethod
def find_by_email(self, email: str) -> MemberVO | None: def find_by_email(self, email: str) -> MemberVO | None:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def save(self, member: MemberVO) -> MemberVO: def save(self, member: MemberVO) -> MemberVO:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def find_by_id(self, id: str) -> MemberVO | None: def find_by_id(self, id: str) -> MemberVO | None:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]: def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def find_analysis_sessions_by_member(self, member_id: str) -> list[AnalysisVO]: def find_analysis_sessions_by_member(self, member_id: str) -> list[AnalysisVO]:
raise NotImplementedError() raise NotImplementedError()

View File

@ -1,66 +1,66 @@
from member.domain.repository import IMemberRepository from member.domain.repository import IMemberRepository
from sqlmodel import Session, select from sqlmodel import Session, select
from member.domain.member import Member as MemberVO from member.domain.member import Member as MemberVO
from member.infra.db_models.member import Member from member.infra.db_models.member import Member
from utils.db_utils import row_to_dict from utils.db_utils import row_to_dict
from sqlalchemy import func from sqlalchemy import func
class MemberRepository(IMemberRepository): class MemberRepository(IMemberRepository):
def __init__(self, session: Session): def __init__(self, session: Session):
self.session = session self.session = session
def find_by_email(self, email: str) -> MemberVO | None: def find_by_email(self, email: str) -> MemberVO | None:
query = select(Member).where(Member.email == email) query = select(Member).where(Member.email == email)
member = self.session.exec(query).first() member = self.session.exec(query).first()
if not member: if not member:
return None return None
return MemberVO(**row_to_dict(member)) return MemberVO(**row_to_dict(member))
def save(self, member: MemberVO) -> MemberVO: def save(self, member: MemberVO) -> MemberVO:
new_member = Member( new_member = Member(
id=member.id, id=member.id,
email=member.email, email=member.email,
name=member.name, name=member.name,
password=member.password, password=member.password,
role=member.role, role=member.role,
created_at=member.created_at, created_at=member.created_at,
updated_at=member.updated_at updated_at=member.updated_at
) )
self.session.add(new_member) self.session.add(new_member)
self.session.flush() self.session.flush()
self.session.refresh(new_member) self.session.refresh(new_member)
member.id = new_member.id member.id = new_member.id
return member return member
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]: def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
offset = (page - 1) * items_per_page offset = (page - 1) * items_per_page
total_count_query = select(func.count(Member.id)) total_count_query = select(func.count(Member.id))
total_count = self.session.exec(total_count_query).one() total_count = self.session.exec(total_count_query).one()
if total_count == 0: if total_count == 0:
return 0, [] return 0, []
query = ( query = (
select(Member) select(Member)
.order_by(Member.created_at.desc()) .order_by(Member.created_at.desc())
.offset(offset) .offset(offset)
.limit(items_per_page) .limit(items_per_page)
) )
members = self.session.exec(query).all() members = self.session.exec(query).all()
return total_count, [MemberVO(**row_to_dict(member)) for member in members] return total_count, [MemberVO(**row_to_dict(member)) for member in members]
def find_by_id(self, id: str) -> MemberVO | None: def find_by_id(self, id: str) -> MemberVO | None:
query = select(Member).where(Member.id == id) query = select(Member).where(Member.id == id)
member = self.session.exec(query).first() member = self.session.exec(query).first()
if not member: if not member:
return None return None
return MemberVO(**row_to_dict(member)) return MemberVO(**row_to_dict(member))

View File

@ -1,78 +1,71 @@
from fastapi import APIRouter, status, Depends,HTTPException from fastapi import APIRouter, status, Depends,HTTPException
from member.interface.dto import CreateUserBody, MemberResponse from member.interface.dto import CreateUserBody, MemberResponse
from member.application.member_service import MemberService from member.application.member_service import MemberService
from typing import Annotated from typing import Annotated
from utils.containers import Container from utils.containers import Container
from dependency_injector.wiring import inject, Provide from dependency_injector.wiring import inject, Provide
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from utils.auth import get_current_member, CurrentMember, get_admin_member from utils.auth import get_current_member, CurrentMember, get_admin_member
from analysis.interface.dto import AnalysisSessionResponse
router = APIRouter(prefix="/members", tags=["members"]) from analysis.application.analysis_service import AnalysisService
@router.post("", status_code=status.HTTP_201_CREATED, response_model=MemberResponse) router = APIRouter(prefix="/members", tags=["members"])
@inject
async def create_user( @router.post("", status_code=status.HTTP_201_CREATED, response_model=MemberResponse)
member: CreateUserBody, @inject
member_service: MemberService = Depends(Provide[Container.member_service]) async def create_user(
): member: CreateUserBody,
created_member = member_service.create_member( member_service: MemberService = Depends(Provide[Container.member_service])
member.name, ):
member.email, created_member = member_service.create_member(
member.password, member.name,
member.role member.email,
) member.password,
member.role
return created_member )
@router.post("/login") return created_member
@inject
def login( @router.post("/login")
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], @inject
member_service: MemberService = Depends(Provide[Container.member_service]) def login(
): form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
access_token = member_service.login( member_service: MemberService = Depends(Provide[Container.member_service])
email=form_data.username, ):
password=form_data.password access_token = member_service.login(
) email=form_data.username,
password=form_data.password
return { )
"access_token" : access_token,
"token_type" : "Bearer" return {
} "access_token" : access_token,
"token_type" : "Bearer"
@router.get("/me", response_model=dict) }
def get_current_user_info(
current_user: CurrentMember = Depends(get_current_member) @router.get("/me", response_model=dict)
): def get_current_user_info(
""" current_user: CurrentMember = Depends(get_current_member)
현재 로그인한 사용자 정보를 조회합니다. ):
엔드포인트는 JWT 토큰이 필요하며, Swagger UI에서 Authorize 버튼을 활성화합니다. """
""" 현재 로그인한 사용자 정보를 조회합니다.
return { 엔드포인트는 JWT 토큰이 필요하며, Swagger UI에서 Authorize 버튼을 활성화합니다.
"user_id": current_user.id, """
"role": current_user.role, return {
"message": "Successfully authenticated" "user_id": current_user.id,
} "role": current_user.role,
"message": "Successfully authenticated"
@router.get("/{member_id}", response_model=MemberResponse) }
@inject
def get_member( @router.get("/{member_id}", response_model=MemberResponse)
member_id: str, @inject
current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None, def get_member(
member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None member_id: str,
): current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
member = member_service.get_member(member_id) ):
if not member:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found") member = member_service.get_member(member_id)
return member if not member:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
# @router.get("/analysis-sessions", response_model=list[AnalysisSessionResponse]) return member
# @inject
# def get_member_analysis_sessions(
# current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
# member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
# ):
# result = member_service.get_analysis_sessions_by_member(current_member.id)
# return result

View File

@ -1,18 +1,18 @@
from typing import Annotated from typing import Annotated
from pydantic import BaseModel, Field, EmailStr from pydantic import BaseModel, Field, EmailStr
from utils.auth import Role from utils.auth import Role
from datetime import datetime from datetime import datetime
class CreateUserBody(BaseModel): class CreateUserBody(BaseModel):
name : Annotated[str, Field(min_length=1, max_length=32)] name : Annotated[str, Field(min_length=1, max_length=32)]
email : Annotated[EmailStr, Field(max_length=32)] email : Annotated[EmailStr, Field(max_length=32)]
password : Annotated[str, Field(max_length=32)] password : Annotated[str, Field(max_length=32)]
role : Annotated[Role, Field(default=Role.USER)] role : Annotated[Role, Field(default=Role.USER)]
class MemberResponse(BaseModel): class MemberResponse(BaseModel):
id : str id : str
name : str | None = None name : str | None = None
email : str email : str
created_at : datetime created_at : datetime
updated_at : datetime updated_at : datetime
role : Role role : Role

View File

@ -1,69 +1,69 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from fastapi import HTTPException, status, Depends from fastapi import HTTPException, status, Depends
from jose import jwt, JWTError from jose import jwt, JWTError
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from enum import StrEnum from enum import StrEnum
from pydantic import BaseModel from pydantic import BaseModel
from typing import Annotated from typing import Annotated
from config import get_settings from config import get_settings
settings = get_settings() settings = get_settings()
SECRET_KEY = settings.SECRET_KEY SECRET_KEY = settings.SECRET_KEY
ALGORITHM = "HS256" ALGORITHM = "HS256"
class Role(StrEnum): class Role(StrEnum):
ADMIN = "ADMIN" ADMIN = "ADMIN"
USER = "USER" USER = "USER"
class CurrentMember(BaseModel): class CurrentMember(BaseModel):
id : str id : str
role : Role role : Role
def __str__(self): def __str__(self):
return f"{self.id}({self.role})" return f"{self.id}({self.role})"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/members/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/members/login")
def create_access_token( def create_access_token(
payload: dict, payload: dict,
role: Role, role: Role,
expires_delta: timedelta = timedelta(hours=6) expires_delta: timedelta = timedelta(hours=6)
): ):
expire = datetime.utcnow() + expires_delta expire = datetime.utcnow() + expires_delta
payload.update({"exp": expire, "role": role}) payload.update({"exp": expire, "role": role})
encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
def decode_access_token(token: str): def decode_access_token(token: str):
try: try:
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
except JWTError: except JWTError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
# ✅ 수정된 부분: Annotated 올바른 사용법 # ✅ 수정된 부분: Annotated 올바른 사용법
def get_current_member(token: Annotated[str, Depends(oauth2_scheme)]): def get_current_member(token: Annotated[str, Depends(oauth2_scheme)]):
payload = decode_access_token(token) payload = decode_access_token(token)
member_id = payload.get("member_id") member_id = payload.get("member_id")
role = payload.get("role") role = payload.get("role")
if not member_id or not role: if not member_id or not role:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
return CurrentMember(id=member_id, role=Role(role)) return CurrentMember(id=member_id, role=Role(role))
def get_admin_member(token: Annotated[str, Depends(oauth2_scheme)]): def get_admin_member(token: Annotated[str, Depends(oauth2_scheme)]):
payload = decode_access_token(token) payload = decode_access_token(token)
member_id = payload.get("member_id") member_id = payload.get("member_id")
role = payload.get("role") role = payload.get("role")
if not role or role != Role.ADMIN: if not role or role != Role.ADMIN:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
return CurrentMember(id=member_id, role=Role(role)) return CurrentMember(id=member_id, role=Role(role))

View File

@ -1,43 +1,49 @@
from dependency_injector import containers, providers from dependency_injector import containers, providers
from utils.database import get_session from utils.database import get_session
from utils.crypto import Crypto from utils.crypto import Crypto
from member.infra.repository.member_repo import MemberRepository from member.infra.repository.member_repo import MemberRepository
from member.application.member_service import MemberService from member.application.member_service import MemberService
from analysis.application.analysis_service import AnalysisService from analysis.application.analysis_service import AnalysisService
from analysis.infra.repository.analysis_repo import AnalysisRepository from analysis.infra.repository.analysis_repo import AnalysisRepository
from ulid import ULID from analysis.application.websocket_manager import WebSocketManager
from ulid import ULID
class Container(containers.DeclarativeContainer):
wiring_config = containers.WiringConfiguration( class Container(containers.DeclarativeContainer):
packages=["member", "analysis"] wiring_config = containers.WiringConfiguration(
) packages=["member", "analysis"]
)
db_session = providers.Resource(get_session)
crypto = providers.Factory(Crypto) session = providers.Resource(get_session)
ulid = providers.Factory(ULID) crypto = providers.Factory(Crypto)
ulid = providers.Factory(ULID)
member_repo = providers.Factory(
MemberRepository, member_repo = providers.Factory(
session=db_session MemberRepository,
) session=session
)
member_service = providers.Factory(
MemberService, member_service = providers.Factory(
member_repo=member_repo, MemberService,
crypto=crypto, member_repo=member_repo,
db_session=db_session, crypto=crypto,
ulid=ulid session=session,
) ulid=ulid
)
analysis_repo = providers.Factory(
AnalysisRepository, analysis_repo = providers.Factory(
session=db_session AnalysisRepository,
) session=session
)
analysis_service = providers.Factory(
AnalysisService, websocket_manager = providers.Singleton(
analysis_repo=analysis_repo, WebSocketManager
db_session=db_session, )
ulid=ulid
) analysis_service = providers.Factory(
AnalysisService,
analysis_repo=analysis_repo,
session=session,
ulid=ulid,
websocket_manager=websocket_manager
)

View File

@ -1,12 +1,12 @@
from passlib.context import CryptContext from passlib.context import CryptContext
class Crypto: class Crypto:
def __init__(self): def __init__(self):
self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def encrypt(self, secret): def encrypt(self, secret):
return self.pwd_context.hash(secret) return self.pwd_context.hash(secret)
def verify(self, secret, hash): def verify(self, secret, hash):
return self.pwd_context.verify(secret, hash) return self.pwd_context.verify(secret, hash)

View File

@ -1,32 +1,32 @@
import os import os
from pathlib import Path from pathlib import Path
from sqlmodel import SQLModel, create_engine, Session from sqlmodel import SQLModel, create_engine, Session
from config.config import get_settings from config.config import get_settings
from member.infra.db_models.member import Member from member.infra.db_models.member import Member
from analysis.infra.db_models.analysis import Analysis from analysis.infra.db_models.analysis import Analysis
settings = get_settings() settings = get_settings()
BASE_DIR = Path(__file__).resolve().parent.parent BASE_DIR = Path(__file__).resolve().parent.parent
# MySQL 데이터베이스 URL 구성 # MySQL 데이터베이스 URL 구성
DATABASE_URL = f"mysql+pymysql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset=utf8mb4" DATABASE_URL = f"mysql+pymysql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset=utf8mb4"
# MySQL 엔진 생성 # MySQL 엔진 생성
engine = create_engine( engine = create_engine(
DATABASE_URL, DATABASE_URL,
echo=True echo=True
) )
def get_session(): def get_session():
with Session(engine) as session: with Session(engine) as session:
yield session yield session
def create_db_and_tables(): def create_db_and_tables():
# 테이블 생성 # 테이블 생성
# SQLModel.metadata.drop_all(engine) # SQLModel.metadata.drop_all(engine)
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)
if __name__ == "__main__": if __name__ == "__main__":
create_db_and_tables() create_db_and_tables()
print(DATABASE_URL) print(DATABASE_URL)

View File

@ -1,4 +1,4 @@
from sqlalchemy import inspect from sqlalchemy import inspect
def row_to_dict(row)->dict: def row_to_dict(row)->dict:
return {key : getattr(row, key) for key in inspect(row).attrs.keys()} return {key : getattr(row, key) for key in inspect(row).attrs.keys()}

View File

@ -1,59 +1,59 @@
version: '3.8' version: '3.8'
services: services:
mysql: mysql:
image: mysql:8.0 image: mysql:8.0
container_name: tradingagents_mysql container_name: tradingagents_mysql
restart: unless-stopped restart: unless-stopped
environment: environment:
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-password} MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-password}
MYSQL_DATABASE: ${DB_NAME:-tradingagents_db} MYSQL_DATABASE: ${DB_NAME:-tradingagents_db}
MYSQL_USER: ${DB_USER:-tradinguser} MYSQL_USER: ${DB_USER:-tradinguser}
MYSQL_PASSWORD: ${DB_PASSWORD:-password} MYSQL_PASSWORD: ${DB_PASSWORD:-password}
ports: ports:
- "3306:3306" - "3306:3306"
volumes: volumes:
- /home/hskim/mysql_data:/var/lib/mysql - /home/hskim/mysql_data:/var/lib/mysql
- /home/hskim/docker/mysql/init:/docker-entrypoint-initdb.d - /home/hskim/docker/mysql/init:/docker-entrypoint-initdb.d
command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
networks: networks:
- tradingagents_network - tradingagents_network
redis: redis:
image: redis:7-alpine image: redis:7-alpine
container_name: tradingagents_redis container_name: tradingagents_redis
restart: unless-stopped restart: unless-stopped
ports: ports:
- "6379:6379" - "6379:6379"
volumes: volumes:
- redis_data:/data - redis_data:/data
command: redis-server --appendonly yes command: redis-server --appendonly yes
networks: networks:
- tradingagents_network - tradingagents_network
# 개발용 phpMyAdmin (선택사항) # 개발용 phpMyAdmin (선택사항)
# phpmyadmin: # phpmyadmin:
# image: phpmyadmin/phpmyadmin # image: phpmyadmin/phpmyadmin
# container_name: tradingagents_phpmyadmin # container_name: tradingagents_phpmyadmin
# restart: unless-stopped # restart: unless-stopped
# environment: # environment:
# PMA_HOST: mysql # PMA_HOST: mysql
# PMA_PORT: 3306 # PMA_PORT: 3306
# PMA_USER: root # PMA_USER: root
# PMA_PASSWORD: ${DB_PASSWORD:-password} # PMA_PASSWORD: ${DB_PASSWORD:-password}
# ports: # ports:
# - "8080:80" # - "8080:80"
# depends_on: # depends_on:
# - mysql # - mysql
# networks: # networks:
# - tradingagents_network # - tradingagents_network
volumes: volumes:
mysql_data: mysql_data:
driver: local driver: local
redis_data: redis_data:
driver: local driver: local
networks: networks:
tradingagents_network: tradingagents_network:
driver: bridge driver: bridge

53212
frontend/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -1,48 +1,48 @@
{ {
"name": "tradingagents-web-frontend", "name": "tradingagents-web-frontend",
"version": "0.1.0", "version": "0.1.0",
"private": true, "private": true,
"dependencies": { "dependencies": {
"@ant-design/icons": "^5.2.6", "@ant-design/icons": "^5.2.6",
"@testing-library/jest-dom": "^5.16.4", "@testing-library/jest-dom": "^5.16.4",
"@testing-library/react": "^13.3.0", "@testing-library/react": "^13.3.0",
"@testing-library/user-event": "^13.5.0", "@testing-library/user-event": "^13.5.0",
"antd": "^5.10.0", "antd": "^5.10.0",
"axios": "^1.5.0", "axios": "^1.5.0",
"dayjs": "^1.11.9", "dayjs": "^1.11.9",
"react": "^18.2.0", "react": "^18.2.0",
"react-dom": "^18.2.0", "react-dom": "^18.2.0",
"react-markdown": "^8.0.7", "react-markdown": "^8.0.7",
"react-router-dom": "^6.4.0", "react-router-dom": "^6.4.0",
"react-scripts": "5.0.1", "react-scripts": "5.0.1",
"recharts": "^2.8.0", "recharts": "^2.8.0",
"remark-gfm": "^4.0.1", "remark-gfm": "^4.0.1",
"styled-components": "^6.0.8", "styled-components": "^6.0.8",
"websocket": "^1.0.34" "websocket": "^1.0.34"
}, },
"scripts": { "scripts": {
"start": "react-scripts start", "start": "react-scripts start",
"build": "react-scripts build", "build": "react-scripts build",
"test": "react-scripts test", "test": "react-scripts test",
"eject": "react-scripts eject" "eject": "react-scripts eject"
}, },
"eslintConfig": { "eslintConfig": {
"extends": [ "extends": [
"react-app", "react-app",
"react-app/jest" "react-app/jest"
] ]
}, },
"browserslist": { "browserslist": {
"production": [ "production": [
">0.2%", ">0.2%",
"not dead", "not dead",
"not op_mini all" "not op_mini all"
], ],
"development": [ "development": [
"last 1 chrome version", "last 1 chrome version",
"last 1 firefox version", "last 1 firefox version",
"last 1 safari version" "last 1 safari version"
] ]
}, },
"proxy": "http://localhost:8000" "proxy": "http://localhost:8000"
} }

View File

@ -1,20 +1,20 @@
<!DOCTYPE html> <!DOCTYPE html>
<html lang="ko"> <html lang="ko">
<head> <head>
<meta charset="utf-8" /> <meta charset="utf-8" />
<link rel="icon" href="%PUBLIC_URL%/favicon.ico" /> <link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1" /> <meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="theme-color" content="#000000" /> <meta name="theme-color" content="#000000" />
<meta <meta
name="description" name="description"
content="TradingAgents - Multi-Agents LLM Financial Trading Framework" content="TradingAgents - Multi-Agents LLM Financial Trading Framework"
/> />
<link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" /> <link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
<link rel="manifest" href="%PUBLIC_URL%/manifest.json" /> <link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
<title>TradingAgents - AI 거래 분석 플랫폼</title> <title>TradingAgents - AI 거래 분석 플랫폼</title>
</head> </head>
<body> <body>
<noscript>JavaScript를 활성화해야 이 앱을 실행할 수 있습니다.</noscript> <noscript>JavaScript를 활성화해야 이 앱을 실행할 수 있습니다.</noscript>
<div id="root"></div> <div id="root"></div>
</body> </body>
</html> </html>

42
main.py
View File

@ -1,21 +1,21 @@
from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config # Create a custom config
config = DEFAULT_CONFIG.copy() config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google" # Use a different model config["llm_provider"] = "google" # Use a different model
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
config["deep_think_llm"] = "gemini-2.5-pro" # Use a different model config["deep_think_llm"] = "gemini-2.5-pro" # Use a different model
config["quick_think_llm"] = "gemini-2.5-flash-lite-preview-06-17" # Use a different model config["quick_think_llm"] = "gemini-2.5-flash-lite-preview-06-17" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds config["max_debate_rounds"] = 1 # Increase debate rounds
config["online_tools"] = True # Increase debate rounds config["online_tools"] = True # Increase debate rounds
# Initialize with custom config # Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config) ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate # forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10") _, decision = ta.propagate("NVDA", "2024-05-10")
print(decision) print(decision)
# Memorize mistakes and reflect # Memorize mistakes and reflect
# ta.reflect_and_remember(1000) # parameter is the position returns # ta.reflect_and_remember(1000) # parameter is the position returns

View File

@ -1,34 +1,34 @@
[project] [project]
name = "tradingagents" name = "tradingagents"
version = "0.1.0" version = "0.1.0"
description = "Add your description here" description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = [ dependencies = [
"akshare>=1.16.98", "akshare>=1.16.98",
"backtrader>=1.9.78.123", "backtrader>=1.9.78.123",
"chainlit>=2.5.5", "chainlit>=2.5.5",
"chromadb>=1.0.12", "chromadb>=1.0.12",
"eodhd>=1.0.32", "eodhd>=1.0.32",
"feedparser>=6.0.11", "feedparser>=6.0.11",
"finnhub-python>=2.4.23", "finnhub-python>=2.4.23",
"langchain-anthropic>=0.3.15", "langchain-anthropic>=0.3.15",
"langchain-experimental>=0.3.4", "langchain-experimental>=0.3.4",
"langchain-google-genai>=2.1.5", "langchain-google-genai>=2.1.5",
"langchain-openai>=0.3.23", "langchain-openai>=0.3.23",
"langgraph>=0.4.8", "langgraph>=0.4.8",
"pandas>=2.3.0", "pandas>=2.3.0",
"parsel>=1.10.0", "parsel>=1.10.0",
"praw>=7.8.1", "praw>=7.8.1",
"pytz>=2025.2", "pytz>=2025.2",
"questionary>=2.1.0", "questionary>=2.1.0",
"redis>=6.2.0", "redis>=6.2.0",
"requests>=2.32.4", "requests>=2.32.4",
"rich>=14.0.0", "rich>=14.0.0",
"setuptools>=80.9.0", "setuptools>=80.9.0",
"stockstats>=0.6.5", "stockstats>=0.6.5",
"tqdm>=4.67.1", "tqdm>=4.67.1",
"tushare>=1.4.21", "tushare>=1.4.21",
"typing-extensions>=4.14.0", "typing-extensions>=4.14.0",
"yfinance>=0.2.63", "yfinance>=0.2.63",
] ]

View File

@ -1,60 +1,60 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time import time
import json import json
def create_news_analyst(llm, toolkit): def create_news_analyst(llm, toolkit):
def news_analyst_node(state): def news_analyst_node(state):
current_date = state["trade_date"] current_date = state["trade_date"]
ticker = state["company_of_interest"] ticker = state["company_of_interest"]
if toolkit.config["online_tools"]: if toolkit.config["online_tools"]:
tools = [toolkit.get_global_news, toolkit.get_google_news] tools = [toolkit.get_global_news, toolkit.get_google_news]
else: else:
tools = [ tools = [
toolkit.get_finnhub_news, toolkit.get_finnhub_news,
toolkit.get_reddit_news, toolkit.get_reddit_news,
toolkit.get_google_news, toolkit.get_google_news,
] ]
system_message = ( system_message = (
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." "**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""" + """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
[ [
( (
"system", "system",
"You are a helpful AI assistant, collaborating with other assistants." "You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question." " Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools" " If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress." " will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" You have access to the following tools: {tool_names}.\n{system_message}" " You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. We are looking at the company {ticker}", "For your reference, the current date is {current_date}. We are looking at the company {ticker}",
), ),
MessagesPlaceholder(variable_name="messages"), MessagesPlaceholder(variable_name="messages"),
] ]
) )
prompt = prompt.partial(system_message=system_message) prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker) prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools) chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"]) result = chain.invoke(state["messages"])
report = "" report = ""
if len(result.tool_calls) == 0: if len(result.tool_calls) == 0:
report = result.content report = result.content
return { return {
"messages": [result], "messages": [result],
"news_report": report, "news_report": report,
} }
return news_analyst_node return news_analyst_node

View File

@ -1,60 +1,60 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time import time
import json import json
def create_social_media_analyst(llm, toolkit): def create_social_media_analyst(llm, toolkit):
def social_media_analyst_node(state): def social_media_analyst_node(state):
current_date = state["trade_date"] current_date = state["trade_date"]
ticker = state["company_of_interest"] ticker = state["company_of_interest"]
company_name = state["company_of_interest"] company_name = state["company_of_interest"]
if toolkit.config["online_tools"]: if toolkit.config["online_tools"]:
tools = [toolkit.get_stock_news] tools = [toolkit.get_stock_news]
else: else:
tools = [ tools = [
toolkit.get_reddit_stock_info, toolkit.get_reddit_stock_info,
] ]
system_message = ( system_message = (
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." "**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""", + """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
[ [
( (
"system", "system",
"You are a helpful AI assistant, collaborating with other assistants." "You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question." " Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools" " If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress." " will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" You have access to the following tools: {tool_names}.\n{system_message}" " You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}", "For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
), ),
MessagesPlaceholder(variable_name="messages"), MessagesPlaceholder(variable_name="messages"),
] ]
) )
prompt = prompt.partial(system_message=system_message) prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker) prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools) chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"]) result = chain.invoke(state["messages"])
report = "" report = ""
if len(result.tool_calls) == 0: if len(result.tool_calls) == 0:
report = result.content report = result.content
return { return {
"messages": [result], "messages": [result],
"sentiment_report": report, "sentiment_report": report,
} }
return social_media_analyst_node return social_media_analyst_node

View File

@ -1,57 +1,57 @@
import time import time
import json import json
def create_research_manager(llm, memory): def create_research_manager(llm, memory):
def research_manager_node(state) -> dict: def research_manager_node(state) -> dict:
history = state["investment_debate_state"].get("history", "") history = state["investment_debate_state"].get("history", "")
market_research_report = state["market_report"] market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"] fundamentals_report = state["fundamentals_report"]
investment_debate_state = state["investment_debate_state"] investment_debate_state = state["investment_debate_state"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = "" past_memory_str = ""
for i, rec in enumerate(past_memories, 1): for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n" past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요) prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented. As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendationBuy, Sell, or Holdmust be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments. Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendationBuy, Sell, or Holdmust be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
Additionally, develop a detailed investment plan for the trader. This should include: Additionally, develop a detailed investment plan for the trader. This should include:
Your Recommendation: A decisive stance supported by the most convincing arguments. Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion. Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation. Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting. Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes: Here are your past reflections on mistakes:
\"{past_memory_str}\" \"{past_memory_str}\"
Here is the debate: Here is the debate:
Debate History: Debate History:
{history}""" {history}"""
response = llm.invoke(prompt) response = llm.invoke(prompt)
new_investment_debate_state = { new_investment_debate_state = {
"judge_decision": response.content, "judge_decision": response.content,
"history": investment_debate_state.get("history", ""), "history": investment_debate_state.get("history", ""),
"bear_history": investment_debate_state.get("bear_history", ""), "bear_history": investment_debate_state.get("bear_history", ""),
"bull_history": investment_debate_state.get("bull_history", ""), "bull_history": investment_debate_state.get("bull_history", ""),
"current_response": response.content, "current_response": response.content,
"count": investment_debate_state["count"], "count": investment_debate_state["count"],
} }
return { return {
"investment_debate_state": new_investment_debate_state, "investment_debate_state": new_investment_debate_state,
"investment_plan": response.content, "investment_plan": response.content,
} }
return research_manager_node return research_manager_node

View File

@ -1,68 +1,68 @@
import time import time
import json import json
def create_risk_manager(llm, memory): def create_risk_manager(llm, memory):
def risk_manager_node(state) -> dict: def risk_manager_node(state) -> dict:
company_name = state["company_of_interest"] company_name = state["company_of_interest"]
history = state["risk_debate_state"]["history"] history = state["risk_debate_state"]["history"]
risk_debate_state = state["risk_debate_state"] risk_debate_state = state["risk_debate_state"]
market_research_report = state["market_report"] market_research_report = state["market_report"]
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["news_report"] fundamentals_report = state["news_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
trader_plan = state["investment_plan"] trader_plan = state["investment_plan"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = "" past_memory_str = ""
for i, rec in enumerate(past_memories, 1): for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n" past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요) prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analystsRisky, Neutral, and Safe/Conservativeand determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness. As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analystsRisky, Neutral, and Safe/Conservativeand determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
Guidelines for Decision-Making: Guidelines for Decision-Making:
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context. 1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate. 2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights. 3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money. 4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
Deliverables: Deliverables:
- A clear and actionable recommendation: Buy, Sell, or Hold. - A clear and actionable recommendation: Buy, Sell, or Hold.
- Detailed reasoning anchored in the debate and past reflections. - Detailed reasoning anchored in the debate and past reflections.
--- ---
**Analysts Debate History:** **Analysts Debate History:**
{history} {history}
--- ---
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes.""" Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
response = llm.invoke(prompt) response = llm.invoke(prompt)
new_risk_debate_state = { new_risk_debate_state = {
"judge_decision": response.content, "judge_decision": response.content,
"history": risk_debate_state["history"], "history": risk_debate_state["history"],
"risky_history": risk_debate_state["risky_history"], "risky_history": risk_debate_state["risky_history"],
"safe_history": risk_debate_state["safe_history"], "safe_history": risk_debate_state["safe_history"],
"neutral_history": risk_debate_state["neutral_history"], "neutral_history": risk_debate_state["neutral_history"],
"latest_speaker": "Judge", "latest_speaker": "Judge",
"current_risky_response": risk_debate_state["current_risky_response"], "current_risky_response": risk_debate_state["current_risky_response"],
"current_safe_response": risk_debate_state["current_safe_response"], "current_safe_response": risk_debate_state["current_safe_response"],
"current_neutral_response": risk_debate_state["current_neutral_response"], "current_neutral_response": risk_debate_state["current_neutral_response"],
"count": risk_debate_state["count"], "count": risk_debate_state["count"],
} }
return { return {
"risk_debate_state": new_risk_debate_state, "risk_debate_state": new_risk_debate_state,
"final_trade_decision": response.content, "final_trade_decision": response.content,
} }
return risk_manager_node return risk_manager_node

View File

@ -1,63 +1,63 @@
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
import time import time
import json import json
def create_bear_researcher(llm, memory): def create_bear_researcher(llm, memory):
def bear_node(state) -> dict: def bear_node(state) -> dict:
investment_debate_state = state["investment_debate_state"] investment_debate_state = state["investment_debate_state"]
history = investment_debate_state.get("history", "") history = investment_debate_state.get("history", "")
bear_history = investment_debate_state.get("bear_history", "") bear_history = investment_debate_state.get("bear_history", "")
current_response = investment_debate_state.get("current_response", "") current_response = investment_debate_state.get("current_response", "")
market_research_report = state["market_report"] market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"] fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = "" past_memory_str = ""
for i, rec in enumerate(past_memories, 1): for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n" past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요) prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively. You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
Key points to focus on: Key points to focus on:
- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance. - Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance.
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors. - Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors.
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position. - Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions. - Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions.
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts. - Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts.
Resources available: Resources available:
Market research report: {market_research_report} Market research report: {market_research_report}
Social media sentiment report: {sentiment_report} Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report} Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report} Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history} Conversation history of the debate: {history}
Last bull argument: {current_response} Last bull argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str} Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past. Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
""" """
response = llm.invoke(prompt) response = llm.invoke(prompt)
argument = f"Bear Analyst: {response.content}" argument = f"Bear Analyst: {response.content}"
new_investment_debate_state = { new_investment_debate_state = {
"history": history + "\n" + argument, "history": history + "\n" + argument,
"bear_history": bear_history + "\n" + argument, "bear_history": bear_history + "\n" + argument,
"bull_history": investment_debate_state.get("bull_history", ""), "bull_history": investment_debate_state.get("bull_history", ""),
"current_response": argument, "current_response": argument,
"count": investment_debate_state["count"] + 1, "count": investment_debate_state["count"] + 1,
} }
return {"investment_debate_state": new_investment_debate_state} return {"investment_debate_state": new_investment_debate_state}
return bear_node return bear_node

View File

@ -1,61 +1,61 @@
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
import time import time
import json import json
def create_bull_researcher(llm, memory): def create_bull_researcher(llm, memory):
def bull_node(state) -> dict: def bull_node(state) -> dict:
investment_debate_state = state["investment_debate_state"] investment_debate_state = state["investment_debate_state"]
history = investment_debate_state.get("history", "") history = investment_debate_state.get("history", "")
bull_history = investment_debate_state.get("bull_history", "") bull_history = investment_debate_state.get("bull_history", "")
current_response = investment_debate_state.get("current_response", "") current_response = investment_debate_state.get("current_response", "")
market_research_report = state["market_report"] market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"] fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = "" past_memory_str = ""
for i, rec in enumerate(past_memories, 1): for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n" past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요) prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively. You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
Key points to focus on: Key points to focus on:
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability. - Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning. - Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence. - Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit. - Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit.
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data. - Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
Resources available: Resources available:
Market research report: {market_research_report} Market research report: {market_research_report}
Social media sentiment report: {sentiment_report} Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report} Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report} Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history} Conversation history of the debate: {history}
Last bear argument: {current_response} Last bear argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str} Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past. Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
""" """
response = llm.invoke(prompt) response = llm.invoke(prompt)
argument = f"Bull Analyst: {response.content}" argument = f"Bull Analyst: {response.content}"
new_investment_debate_state = { new_investment_debate_state = {
"history": history + "\n" + argument, "history": history + "\n" + argument,
"bull_history": bull_history + "\n" + argument, "bull_history": bull_history + "\n" + argument,
"bear_history": investment_debate_state.get("bear_history", ""), "bear_history": investment_debate_state.get("bear_history", ""),
"current_response": argument, "current_response": argument,
"count": investment_debate_state["count"] + 1, "count": investment_debate_state["count"] + 1,
} }
return {"investment_debate_state": new_investment_debate_state} return {"investment_debate_state": new_investment_debate_state}
return bull_node return bull_node

View File

@ -1,57 +1,57 @@
import time import time
import json import json
def create_risky_debator(llm): def create_risky_debator(llm):
def risky_node(state) -> dict: def risky_node(state) -> dict:
risk_debate_state = state["risk_debate_state"] risk_debate_state = state["risk_debate_state"]
history = risk_debate_state.get("history", "") history = risk_debate_state.get("history", "")
risky_history = risk_debate_state.get("risky_history", "") risky_history = risk_debate_state.get("risky_history", "")
current_safe_response = risk_debate_state.get("current_safe_response", "") current_safe_response = risk_debate_state.get("current_safe_response", "")
current_neutral_response = risk_debate_state.get("current_neutral_response", "") current_neutral_response = risk_debate_state.get("current_neutral_response", "")
market_research_report = state["market_report"] market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"] fundamentals_report = state["fundamentals_report"]
trader_decision = state["trader_investment_plan"] trader_decision = state["trader_investment_plan"]
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요) prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision: As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
{trader_decision} {trader_decision}
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments: Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
Market Research Report: {market_research_report} Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report} Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report} Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report} Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point. Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting.""" Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
response = llm.invoke(prompt) response = llm.invoke(prompt)
argument = f"Risky Analyst: {response.content}" argument = f"Risky Analyst: {response.content}"
new_risk_debate_state = { new_risk_debate_state = {
"history": history + "\n" + argument, "history": history + "\n" + argument,
"risky_history": risky_history + "\n" + argument, "risky_history": risky_history + "\n" + argument,
"safe_history": risk_debate_state.get("safe_history", ""), "safe_history": risk_debate_state.get("safe_history", ""),
"neutral_history": risk_debate_state.get("neutral_history", ""), "neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Risky", "latest_speaker": "Risky",
"current_risky_response": argument, "current_risky_response": argument,
"current_safe_response": risk_debate_state.get("current_safe_response", ""), "current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": risk_debate_state.get( "current_neutral_response": risk_debate_state.get(
"current_neutral_response", "" "current_neutral_response", ""
), ),
"count": risk_debate_state["count"] + 1, "count": risk_debate_state["count"] + 1,
} }
return {"risk_debate_state": new_risk_debate_state} return {"risk_debate_state": new_risk_debate_state}
return risky_node return risky_node

View File

@ -1,60 +1,60 @@
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
import time import time
import json import json
def create_safe_debator(llm): def create_safe_debator(llm):
def safe_node(state) -> dict: def safe_node(state) -> dict:
risk_debate_state = state["risk_debate_state"] risk_debate_state = state["risk_debate_state"]
history = risk_debate_state.get("history", "") history = risk_debate_state.get("history", "")
safe_history = risk_debate_state.get("safe_history", "") safe_history = risk_debate_state.get("safe_history", "")
current_risky_response = risk_debate_state.get("current_risky_response", "") current_risky_response = risk_debate_state.get("current_risky_response", "")
current_neutral_response = risk_debate_state.get("current_neutral_response", "") current_neutral_response = risk_debate_state.get("current_neutral_response", "")
market_research_report = state["market_report"] market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"] fundamentals_report = state["fundamentals_report"]
trader_decision = state["trader_investment_plan"] trader_decision = state["trader_investment_plan"]
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요) prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision: As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
{trader_decision} {trader_decision}
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision: Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
Market Research Report: {market_research_report} Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report} Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report} Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report} Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point. Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting.""" Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
response = llm.invoke(prompt) response = llm.invoke(prompt)
argument = f"Safe Analyst: {response.content}" argument = f"Safe Analyst: {response.content}"
new_risk_debate_state = { new_risk_debate_state = {
"history": history + "\n" + argument, "history": history + "\n" + argument,
"risky_history": risk_debate_state.get("risky_history", ""), "risky_history": risk_debate_state.get("risky_history", ""),
"safe_history": safe_history + "\n" + argument, "safe_history": safe_history + "\n" + argument,
"neutral_history": risk_debate_state.get("neutral_history", ""), "neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Safe", "latest_speaker": "Safe",
"current_risky_response": risk_debate_state.get( "current_risky_response": risk_debate_state.get(
"current_risky_response", "" "current_risky_response", ""
), ),
"current_safe_response": argument, "current_safe_response": argument,
"current_neutral_response": risk_debate_state.get( "current_neutral_response": risk_debate_state.get(
"current_neutral_response", "" "current_neutral_response", ""
), ),
"count": risk_debate_state["count"] + 1, "count": risk_debate_state["count"] + 1,
} }
return {"risk_debate_state": new_risk_debate_state} return {"risk_debate_state": new_risk_debate_state}
return safe_node return safe_node

View File

@ -1,57 +1,57 @@
import time import time
import json import json
def create_neutral_debator(llm): def create_neutral_debator(llm):
def neutral_node(state) -> dict: def neutral_node(state) -> dict:
risk_debate_state = state["risk_debate_state"] risk_debate_state = state["risk_debate_state"]
history = risk_debate_state.get("history", "") history = risk_debate_state.get("history", "")
neutral_history = risk_debate_state.get("neutral_history", "") neutral_history = risk_debate_state.get("neutral_history", "")
current_risky_response = risk_debate_state.get("current_risky_response", "") current_risky_response = risk_debate_state.get("current_risky_response", "")
current_safe_response = risk_debate_state.get("current_safe_response", "") current_safe_response = risk_debate_state.get("current_safe_response", "")
market_research_report = state["market_report"] market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"] fundamentals_report = state["fundamentals_report"]
trader_decision = state["trader_investment_plan"] trader_decision = state["trader_investment_plan"]
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요) prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision: As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
{trader_decision} {trader_decision}
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision: Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
Market Research Report: {market_research_report} Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report} Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report} Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report} Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point. Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting.""" Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
response = llm.invoke(prompt) response = llm.invoke(prompt)
argument = f"Neutral Analyst: {response.content}" argument = f"Neutral Analyst: {response.content}"
new_risk_debate_state = { new_risk_debate_state = {
"history": history + "\n" + argument, "history": history + "\n" + argument,
"risky_history": risk_debate_state.get("risky_history", ""), "risky_history": risk_debate_state.get("risky_history", ""),
"safe_history": risk_debate_state.get("safe_history", ""), "safe_history": risk_debate_state.get("safe_history", ""),
"neutral_history": neutral_history + "\n" + argument, "neutral_history": neutral_history + "\n" + argument,
"latest_speaker": "Neutral", "latest_speaker": "Neutral",
"current_risky_response": risk_debate_state.get( "current_risky_response": risk_debate_state.get(
"current_risky_response", "" "current_risky_response", ""
), ),
"current_safe_response": risk_debate_state.get("current_safe_response", ""), "current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": argument, "current_neutral_response": argument,
"count": risk_debate_state["count"] + 1, "count": risk_debate_state["count"] + 1,
} }
return {"risk_debate_state": new_risk_debate_state} return {"risk_debate_state": new_risk_debate_state}
return neutral_node return neutral_node

View File

@ -1,45 +1,45 @@
import functools import functools
import time import time
import json import json
def create_trader(llm, memory): def create_trader(llm, memory):
def trader_node(state, name): def trader_node(state, name):
company_name = state["company_of_interest"] company_name = state["company_of_interest"]
investment_plan = state["investment_plan"] investment_plan = state["investment_plan"]
market_research_report = state["market_report"] market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"] fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = "" past_memory_str = ""
for i, rec in enumerate(past_memories, 1): for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n" past_memory_str += rec["recommendation"] + "\n\n"
context = { context = {
"role": "user", "role": "user",
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.", "content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
} }
messages = [ messages = [
{ {
"role": "system", "role": "system",
"content": f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요) "content": f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""", You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""",
}, },
context, context,
] ]
result = llm.invoke(messages) result = llm.invoke(messages)
return { return {
"messages": [result], "messages": [result],
"trader_investment_plan": result.content, "trader_investment_plan": result.content,
"sender": name, "sender": name,
} }
return functools.partial(trader_node, name="Trader") return functools.partial(trader_node, name="Trader")

View File

@ -1,20 +1,20 @@
from .embedding_providers import ( from .embedding_providers import (
EmbeddingProvider, EmbeddingProvider,
OpenAIEmbeddingProvider, OpenAIEmbeddingProvider,
GeminiEmbeddingProvider, GeminiEmbeddingProvider,
OllamaEmbeddingProvider OllamaEmbeddingProvider
) )
from typing import Any from typing import Any
class EmbeddingProviderFactory: class EmbeddingProviderFactory:
@staticmethod @staticmethod
def create_provider(config : dict[str, Any])->EmbeddingProvider: def create_provider(config : dict[str, Any])->EmbeddingProvider:
backend_url = config["backend_url"] backend_url = config["backend_url"]
if "generativelanguage.googleapis.com" in backend_url: if "generativelanguage.googleapis.com" in backend_url:
return GeminiEmbeddingProvider(backend_url) return GeminiEmbeddingProvider(backend_url)
elif "localhost:11434" in backend_url: elif "localhost:11434" in backend_url:
return OllamaEmbeddingProvider(backend_url) return OllamaEmbeddingProvider(backend_url)
else: else:
return OpenAIEmbeddingProvider(backend_url) return OpenAIEmbeddingProvider(backend_url)

View File

@ -1,66 +1,66 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from openai import OpenAI from openai import OpenAI
from google import genai from google import genai
class EmbeddingProvider(ABC): class EmbeddingProvider(ABC):
@abstractmethod @abstractmethod
def get_embedding(self, text: str)->list[float]: def get_embedding(self, text: str)->list[float]:
pass pass
@property @property
@abstractmethod @abstractmethod
def model_name(self)->str: def model_name(self)->str:
pass pass
class OpenAIEmbeddingProvider(EmbeddingProvider): class OpenAIEmbeddingProvider(EmbeddingProvider):
def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"): def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
self.client = OpenAI(base_url=backend_url) self.client = OpenAI(base_url=backend_url)
self._embedding_model = embedding_model self._embedding_model = embedding_model
def get_embedding(self, text: str)->list[float]: def get_embedding(self, text: str)->list[float]:
response = self.client.embeddings.create( response = self.client.embeddings.create(
model=self._embedding_model, model=self._embedding_model,
input=text input=text
) )
return response.data[0].embedding return response.data[0].embedding
@property @property
def model_name(self)->str: def model_name(self)->str:
return self._embedding_model return self._embedding_model
class GeminiEmbeddingProvider(EmbeddingProvider): class GeminiEmbeddingProvider(EmbeddingProvider):
def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"): def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
self.client = genai.Client() self.client = genai.Client()
self._embedding_model = embedding_model self._embedding_model = embedding_model
def get_embedding(self, text: str)->list[float]: def get_embedding(self, text: str)->list[float]:
response = self.client.models.embed_content( response = self.client.models.embed_content(
model=self._embedding_model, model=self._embedding_model,
contents=text contents=text
) )
return response.embeddings[0].values return response.embeddings[0].values
@property @property
def model_name(self)->str: def model_name(self)->str:
return self._embedding_model return self._embedding_model
class OllamaEmbeddingProvider(EmbeddingProvider): class OllamaEmbeddingProvider(EmbeddingProvider):
def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"): def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
self.client = OpenAI(base_url=backend_url) self.client = OpenAI(base_url=backend_url)
self._embedding_model = embedding_model self._embedding_model = embedding_model
def get_embedding(self, text: str)->list[float]: def get_embedding(self, text: str)->list[float]:
response = self.client.embeddings.create( response = self.client.embeddings.create(
model=self._embedding_model, model=self._embedding_model,
input=text input=text
) )
return response.data[0].embedding return response.data[0].embedding
@property @property
def model_name(self)->str: def model_name(self)->str:
return self._embedding_model return self._embedding_model

View File

@ -1,112 +1,112 @@
import chromadb import chromadb
from chromadb.config import Settings from chromadb.config import Settings
from openai import OpenAI from openai import OpenAI
import os import os
from .embedding_provider_factory import EmbeddingProviderFactory from .embedding_provider_factory import EmbeddingProviderFactory
from google import genai from google import genai
class FinancialSituationMemory: class FinancialSituationMemory:
def __init__(self, name, config): def __init__(self, name, config):
self.config = config self.config = config
self.backend_url = config["backend_url"] self.backend_url = config["backend_url"]
self.embedding_provider = EmbeddingProviderFactory.create_provider(config) self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name) self.situation_collection = self.chroma_client.create_collection(name=name)
def get_embedding(self, text): def get_embedding(self, text):
"""Get embedding for a text using the appropriate API""" """Get embedding for a text using the appropriate API"""
return self.embedding_provider.get_embedding(text) return self.embedding_provider.get_embedding(text)
def add_situations(self, situations_and_advice): def add_situations(self, situations_and_advice):
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)""" """Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
situations = [] situations = []
advice = [] advice = []
ids = [] ids = []
embeddings = [] embeddings = []
offset = self.situation_collection.count() offset = self.situation_collection.count()
for i, (situation, recommendation) in enumerate(situations_and_advice): for i, (situation, recommendation) in enumerate(situations_and_advice):
situations.append(situation) situations.append(situation)
advice.append(recommendation) advice.append(recommendation)
ids.append(str(offset + i)) ids.append(str(offset + i))
embeddings.append(self.get_embedding(situation)) embeddings.append(self.get_embedding(situation))
self.situation_collection.add( self.situation_collection.add(
documents=situations, documents=situations,
metadatas=[{"recommendation": rec} for rec in advice], metadatas=[{"recommendation": rec} for rec in advice],
embeddings=embeddings, embeddings=embeddings,
ids=ids, ids=ids,
) )
def get_memories(self, current_situation, n_matches=1): def get_memories(self, current_situation, n_matches=1):
"""Find matching recommendations using embeddings""" """Find matching recommendations using embeddings"""
query_embedding = self.get_embedding(current_situation) query_embedding = self.get_embedding(current_situation)
results = self.situation_collection.query( results = self.situation_collection.query(
query_embeddings=[query_embedding], query_embeddings=[query_embedding],
n_results=n_matches, n_results=n_matches,
include=["metadatas", "documents", "distances"], include=["metadatas", "documents", "distances"],
) )
matched_results = [] matched_results = []
for i in range(len(results["documents"][0])): for i in range(len(results["documents"][0])):
matched_results.append( matched_results.append(
{ {
"matched_situation": results["documents"][0][i], "matched_situation": results["documents"][0][i],
"recommendation": results["metadatas"][0][i]["recommendation"], "recommendation": results["metadatas"][0][i]["recommendation"],
"similarity_score": 1 - results["distances"][0][i], "similarity_score": 1 - results["distances"][0][i],
} }
) )
return matched_results return matched_results
if __name__ == "__main__": if __name__ == "__main__":
# Example usage # Example usage
matcher = FinancialSituationMemory() matcher = FinancialSituationMemory()
# Example data # Example data
example_data = [ example_data = [
( (
"High inflation rate with rising interest rates and declining consumer spending", "High inflation rate with rising interest rates and declining consumer spending",
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.", "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
), ),
( (
"Tech sector showing high volatility with increasing institutional selling pressure", "Tech sector showing high volatility with increasing institutional selling pressure",
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.", "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
), ),
( (
"Strong dollar affecting emerging markets with increasing forex volatility", "Strong dollar affecting emerging markets with increasing forex volatility",
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.", "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
), ),
( (
"Market showing signs of sector rotation with rising yields", "Market showing signs of sector rotation with rising yields",
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.", "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
), ),
] ]
# Add the example situations and recommendations # Add the example situations and recommendations
matcher.add_situations(example_data) matcher.add_situations(example_data)
# Example query # Example query
current_situation = """ current_situation = """
Market showing increased volatility in tech sector, with institutional investors Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations reducing positions and rising interest rates affecting growth stock valuations
""" """
try: try:
recommendations = matcher.get_memories(current_situation, n_matches=2) recommendations = matcher.get_memories(current_situation, n_matches=2)
for i, rec in enumerate(recommendations, 1): for i, rec in enumerate(recommendations, 1):
print(f"\nMatch {i}:") print(f"\nMatch {i}:")
print(f"Similarity Score: {rec['similarity_score']:.2f}") print(f"Similarity Score: {rec['similarity_score']:.2f}")
print(f"Matched Situation: {rec['matched_situation']}") print(f"Matched Situation: {rec['matched_situation']}")
print(f"Recommendation: {rec['recommendation']}") print(f"Recommendation: {rec['recommendation']}")
except Exception as e: except Exception as e:
print(f"Error during recommendation: {str(e)}") print(f"Error during recommendation: {str(e)}")

View File

@ -1,76 +1,76 @@
from google import genai from google import genai
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
from openai import OpenAI from openai import OpenAI
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
class SearchProvider(ABC): class SearchProvider(ABC):
@abstractmethod @abstractmethod
def search(self, query: str, ticker: str, curr_date: str) -> str: def search(self, query: str, ticker: str, curr_date: str) -> str:
pass pass
class GoogleSearchProvider(SearchProvider): class GoogleSearchProvider(SearchProvider):
def __init__(self, model: str): def __init__(self, model: str):
self.client = genai.Client() self.client = genai.Client()
self.model = model self.model = model
def search(self, query: str) -> str: def search(self, query: str) -> str:
google_search_tool = Tool( google_search_tool = Tool(
google_search=GoogleSearch() google_search=GoogleSearch()
) )
response = self.client.models.generate_content( response = self.client.models.generate_content(
model=self.model, model=self.model,
contents=query, contents=query,
config=GenerateContentConfig( config=GenerateContentConfig(
tools=[google_search_tool], tools=[google_search_tool],
response_modalities=["TEXT"] response_modalities=["TEXT"]
) )
) )
result_text = "" result_text = ""
for part in response.candidates[0].content.parts: for part in response.candidates[0].content.parts:
if hasattr(part, 'text'): if hasattr(part, 'text'):
result_text += part.text result_text += part.text
return result_text return result_text
class OpenAISearchProvider(SearchProvider): class OpenAISearchProvider(SearchProvider):
def __init__(self, model: str, backend_url: str): def __init__(self, model: str, backend_url: str):
self.client = OpenAI(base_url=backend_url) self.client = OpenAI(base_url=backend_url)
self.model = model self.model = model
def search(self, query: str) -> str: def search(self, query: str) -> str:
response = self.client.responses.create( response = self.client.responses.create(
model=self.model, model=self.model,
input=[ input=[
{ {
"role": "system", "role": "system",
"content": [ "content": [
{ {
"type": "input_text", "type": "input_text",
"text": query "text": query
} }
], ],
} }
], ],
text={"format": {"type": "text"}}, text={"format": {"type": "text"}},
reasoning={}, reasoning={},
tools=[ tools=[
{ {
"type": "web_search_preview", "type": "web_search_preview",
"user_location": {"type": "approximate"}, "user_location": {"type": "approximate"},
"search_context_size": "low", "search_context_size": "low",
} }
], ],
temperature=1, temperature=1,
max_output_tokens=4096, max_output_tokens=4096,
top_p=1, top_p=1,
store=True, store=True,
) )
return response.output[1].content[0].text return response.output[1].content[0].text

View File

@ -1,47 +1,47 @@
from .search_provider import ( from .search_provider import (
SearchProvider, SearchProvider,
GoogleSearchProvider, GoogleSearchProvider,
OpenAISearchProvider OpenAISearchProvider
) )
import hashlib import hashlib
import json import json
class SearchProviderFactory: class SearchProviderFactory:
_cache = {} # 클래스 레벨 캐시 _cache = {} # 클래스 레벨 캐시
@staticmethod @staticmethod
def create_provider(config: dict[str, any]) -> SearchProvider: def create_provider(config: dict[str, any]) -> SearchProvider:
""" """
Create a SearchProvider with caching to avoid creating new instances. Create a SearchProvider with caching to avoid creating new instances.
Uses config hash as cache key for efficient reuse. Uses config hash as cache key for efficient reuse.
""" """
# Create cache key from relevant config values # Create cache key from relevant config values
cache_key_data = { cache_key_data = {
"backend_url": config["backend_url"], "backend_url": config["backend_url"],
"model": config["quick_think_llm"] "model": config["quick_think_llm"]
} }
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest() cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
# Return cached instance if exists # Return cached instance if exists
if cache_key in SearchProviderFactory._cache: if cache_key in SearchProviderFactory._cache:
return SearchProviderFactory._cache[cache_key] return SearchProviderFactory._cache[cache_key]
# Create new instance # Create new instance
backend_url = config["backend_url"] backend_url = config["backend_url"]
model = config["quick_think_llm"] model = config["quick_think_llm"]
if "generativelanguage.googleapis.com" in backend_url: if "generativelanguage.googleapis.com" in backend_url:
provider = GoogleSearchProvider(model) provider = GoogleSearchProvider(model)
else: else:
provider = OpenAISearchProvider(model, backend_url) provider = OpenAISearchProvider(model, backend_url)
# Cache and return # Cache and return
SearchProviderFactory._cache[cache_key] = provider SearchProviderFactory._cache[cache_key] = provider
return provider return provider
@staticmethod @staticmethod
def clear_cache(): def clear_cache():
"""Clear the provider cache (useful for testing or config changes).""" """Clear the provider cache (useful for testing or config changes)."""
SearchProviderFactory._cache.clear() SearchProviderFactory._cache.clear()

View File

@ -1,31 +1,31 @@
# TradingAgents/graph/signal_processing.py # TradingAgents/graph/signal_processing.py
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
class SignalProcessor: class SignalProcessor:
"""Processes trading signals to extract actionable decisions.""" """Processes trading signals to extract actionable decisions."""
def __init__(self, quick_thinking_llm: ChatOpenAI): def __init__(self, quick_thinking_llm: ChatOpenAI):
"""Initialize with an LLM for processing.""" """Initialize with an LLM for processing."""
self.quick_thinking_llm = quick_thinking_llm self.quick_thinking_llm = quick_thinking_llm
def process_signal(self, full_signal: str) -> str: def process_signal(self, full_signal: str) -> str:
""" """
Process a full trading signal to extract the core decision. Process a full trading signal to extract the core decision.
Args: Args:
full_signal: Complete trading signal text full_signal: Complete trading signal text
Returns: Returns:
Extracted decision (BUY, SELL, or HOLD) Extracted decision (BUY, SELL, or HOLD)
""" """
messages = [ messages = [
( (
"system", "system",
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.", "**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.",
), ),
("human", full_signal), ("human", full_signal),
] ]
return self.quick_thinking_llm.invoke(messages).content return self.quick_thinking_llm.invoke(messages).content

10810
uv.lock

File diff suppressed because it is too large Load Diff