This commit is contained in:
kimheesu 2025-07-07 14:22:27 +09:00
parent fbd96e9c18
commit ab1b0120c2
73 changed files with 35353 additions and 44998 deletions

21
.gitignore vendored
View File

@ -1,10 +1,11 @@
env/
__pycache__/
.DS_Store
*.csv
src/
eval_results/
eval_data/
*.egg-info/
results/
.env
env/
__pycache__/
.DS_Store
*.csv
src/
eval_results/
eval_data/
*.egg-info/
results/
.env
tradingagents/dataflows/data_cache/

View File

@ -1 +1 @@
3.10
3.10

28
.vscode/launch.json vendored
View File

@ -1,15 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: main.py",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/main.py",
"console": "integratedTerminal"
}
]
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: main.py",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/main.py",
"console": "integratedTerminal"
}
]
}

430
README.md
View File

@ -1,215 +1,215 @@
<p align="center">
<img src="assets/TauricResearch.png" style="width: 60%; height: auto;">
</p>
<div align="center" style="line-height: 1;">
<a href="https://arxiv.org/abs/2412.20138" target="_blank"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2412.20138-B31B1B?logo=arxiv"/></a>
<a href="https://discord.com/invite/hk9PGKShPK" target="_blank"><img alt="Discord" src="https://img.shields.io/badge/Discord-TradingResearch-7289da?logo=discord&logoColor=white&color=7289da"/></a>
<a href="./assets/wechat.png" target="_blank"><img alt="WeChat" src="https://img.shields.io/badge/WeChat-TauricResearch-brightgreen?logo=wechat&logoColor=white"/></a>
<a href="https://x.com/TauricResearch" target="_blank"><img alt="X Follow" src="https://img.shields.io/badge/X-TauricResearch-white?logo=x&logoColor=white"/></a>
<br>
<a href="https://github.com/TauricResearch/" target="_blank"><img alt="Community" src="https://img.shields.io/badge/Join_GitHub_Community-TauricResearch-14C290?logo=discourse"/></a>
</div>
<div align="center">
<!-- Keep these links. Translations will automatically update with the README. -->
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=de">Deutsch</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=es">Español</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=fr">français</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ja">日本語</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ko">한국어</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=pt">Português</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ru">Русский</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=zh">中文</a>
</div>
---
# TradingAgents: Multi-Agents LLM Financial Trading Framework
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
>
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
<div align="center">
<a href="https://www.star-history.com/#TauricResearch/TradingAgents&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" />
<img alt="TradingAgents Star History" src="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" style="width: 80%; height: auto;" />
</picture>
</a>
</div>
<div align="center">
🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
</div>
## TradingAgents Framework
TradingAgents is a multi-agent trading framework that mirrors the dynamics of real-world trading firms. By deploying specialized LLM-powered agents: from fundamental analysts, sentiment experts, and technical analysts, to trader, risk management team, the platform collaboratively evaluates market conditions and informs trading decisions. Moreover, these agents engage in dynamic discussions to pinpoint the optimal strategy.
<p align="center">
<img src="assets/schema.png" style="width: 100%; height: auto;">
</p>
> TradingAgents framework is designed for research purposes. Trading performance may vary based on many factors, including the chosen backbone language models, model temperature, trading periods, the quality of data, and other non-deterministic factors. [It is not intended as financial, investment, or trading advice.](https://tauric.ai/disclaimer/)
Our framework decomposes complex trading tasks into specialized roles. This ensures the system achieves a robust, scalable approach to market analysis and decision-making.
### Analyst Team
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
- Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood.
- News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions.
- Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements.
<p align="center">
<img src="assets/analyst.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p>
### Researcher Team
- Comprises both bullish and bearish researchers who critically assess the insights provided by the Analyst Team. Through structured debates, they balance potential gains against inherent risks.
<p align="center">
<img src="assets/researcher.png" width="70%" style="display: inline-block; margin: 0 2%;">
</p>
### Trader Agent
- Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights.
<p align="center">
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
</p>
### Risk Management and Portfolio Manager
- Continuously evaluates portfolio risk by assessing market volatility, liquidity, and other risk factors. The risk management team evaluates and adjusts trading strategies, providing assessment reports to the Portfolio Manager for final decision.
- The Portfolio Manager approves/rejects the transaction proposal. If approved, the order will be sent to the simulated exchange and executed.
<p align="center">
<img src="assets/risk.png" width="70%" style="display: inline-block; margin: 0 2%;">
</p>
## Installation and CLI
### Installation
Clone TradingAgents:
```bash
git clone https://github.com/TauricResearch/TradingAgents.git
cd TradingAgents
```
Create a virtual environment in any of your favorite environment managers:
```bash
conda create -n tradingagents python=3.13
conda activate tradingagents
```
Install dependencies:
```bash
pip install -r requirements.txt
```
### Required APIs
You will need the FinnHub API for financial data. All of our code is implemented with the free tier.
```bash
export FINNHUB_API_KEY=$YOUR_FINNHUB_API_KEY
```
You will need the OpenAI API or GEMINI API for all the agents.
```bash
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
export GEMINI_API_KEY=$YOUR_GEMINI_API_KEY
export GOOGLE_API_KEY=$YOUR_GEMINI_API_KEY
```
### CLI Usage
You can also try out the CLI directly by running:
```bash
python -m cli.main
```
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
<p align="center">
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p>
An interface will appear showing results as they load, letting you track the agent's progress as it runs.
<p align="center">
<img src="assets/cli/cli_news.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p>
<p align="center">
<img src="assets/cli/cli_transaction.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p>
## TradingAgents Package
### Implementation Details
We built TradingAgents with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
### Python Usage
To use TradingAgents inside your code, you can import the `tradingagents` module and initialize a `TradingAgentsGraph()` object. The `.propagate()` function will return a decision. You can run `main.py`, here's also a quick example:
```python
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy())
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
print(decision)
```
You can also adjust the default configuration to set your own choice of LLMs, debate rounds, etc.
```python
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
config["max_debate_rounds"] = 1  # Adjust the number of debate rounds
config["online_tools"] = True # Use online tools or cached data
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
print(decision)
```
> For `online_tools`, we recommend enabling them for experimentation, as they provide access to real-time data. The agents' offline tools rely on cached data from our **Tauric TradingDB**, a curated dataset we use for backtesting. We're currently in the process of refining this dataset, and we plan to release it soon alongside our upcoming projects. Stay tuned!
You can view the full list of configurations in `tradingagents/default_config.py`.
## Contributing
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
## Citation
Please reference our work if you find *TradingAgents* provides you with some help :)
```
@misc{xiao2025tradingagentsmultiagentsllmfinancial,
title={TradingAgents: Multi-Agents LLM Financial Trading Framework},
author={Yijia Xiao and Edward Sun and Di Luo and Wei Wang},
year={2025},
eprint={2412.20138},
archivePrefix={arXiv},
primaryClass={q-fin.TR},
url={https://arxiv.org/abs/2412.20138},
}
```
<p align="center">
<img src="assets/TauricResearch.png" style="width: 60%; height: auto;">
</p>
<div align="center" style="line-height: 1;">
<a href="https://arxiv.org/abs/2412.20138" target="_blank"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2412.20138-B31B1B?logo=arxiv"/></a>
<a href="https://discord.com/invite/hk9PGKShPK" target="_blank"><img alt="Discord" src="https://img.shields.io/badge/Discord-TradingResearch-7289da?logo=discord&logoColor=white&color=7289da"/></a>
<a href="./assets/wechat.png" target="_blank"><img alt="WeChat" src="https://img.shields.io/badge/WeChat-TauricResearch-brightgreen?logo=wechat&logoColor=white"/></a>
<a href="https://x.com/TauricResearch" target="_blank"><img alt="X Follow" src="https://img.shields.io/badge/X-TauricResearch-white?logo=x&logoColor=white"/></a>
<br>
<a href="https://github.com/TauricResearch/" target="_blank"><img alt="Community" src="https://img.shields.io/badge/Join_GitHub_Community-TauricResearch-14C290?logo=discourse"/></a>
</div>
<div align="center">
<!-- Keep these links. Translations will automatically update with the README. -->
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=de">Deutsch</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=es">Español</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=fr">français</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ja">日本語</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ko">한국어</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=pt">Português</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ru">Русский</a> |
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=zh">中文</a>
</div>
---
# TradingAgents: Multi-Agents LLM Financial Trading Framework
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
>
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
<div align="center">
<a href="https://www.star-history.com/#TauricResearch/TradingAgents&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" />
<img alt="TradingAgents Star History" src="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" style="width: 80%; height: auto;" />
</picture>
</a>
</div>
<div align="center">
🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
</div>
## TradingAgents Framework
TradingAgents is a multi-agent trading framework that mirrors the dynamics of real-world trading firms. By deploying specialized LLM-powered agents: from fundamental analysts, sentiment experts, and technical analysts, to trader, risk management team, the platform collaboratively evaluates market conditions and informs trading decisions. Moreover, these agents engage in dynamic discussions to pinpoint the optimal strategy.
<p align="center">
<img src="assets/schema.png" style="width: 100%; height: auto;">
</p>
> TradingAgents framework is designed for research purposes. Trading performance may vary based on many factors, including the chosen backbone language models, model temperature, trading periods, the quality of data, and other non-deterministic factors. [It is not intended as financial, investment, or trading advice.](https://tauric.ai/disclaimer/)
Our framework decomposes complex trading tasks into specialized roles. This ensures the system achieves a robust, scalable approach to market analysis and decision-making.
### Analyst Team
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
- Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood.
- News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions.
- Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements.
<p align="center">
<img src="assets/analyst.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p>
### Researcher Team
- Comprises both bullish and bearish researchers who critically assess the insights provided by the Analyst Team. Through structured debates, they balance potential gains against inherent risks.
<p align="center">
<img src="assets/researcher.png" width="70%" style="display: inline-block; margin: 0 2%;">
</p>
### Trader Agent
- Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights.
<p align="center">
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
</p>
### Risk Management and Portfolio Manager
- Continuously evaluates portfolio risk by assessing market volatility, liquidity, and other risk factors. The risk management team evaluates and adjusts trading strategies, providing assessment reports to the Portfolio Manager for final decision.
- The Portfolio Manager approves/rejects the transaction proposal. If approved, the order will be sent to the simulated exchange and executed.
<p align="center">
<img src="assets/risk.png" width="70%" style="display: inline-block; margin: 0 2%;">
</p>
## Installation and CLI
### Installation
Clone TradingAgents:
```bash
git clone https://github.com/TauricResearch/TradingAgents.git
cd TradingAgents
```
Create a virtual environment in any of your favorite environment managers:
```bash
conda create -n tradingagents python=3.13
conda activate tradingagents
```
Install dependencies:
```bash
pip install -r requirements.txt
```
### Required APIs
You will need the FinnHub API for financial data. All of our code is implemented with the free tier.
```bash
export FINNHUB_API_KEY=$YOUR_FINNHUB_API_KEY
```
You will need the OpenAI API or GEMINI API for all the agents.
```bash
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
export GEMINI_API_KEY=$YOUR_GEMINI_API_KEY
export GOOGLE_API_KEY=$YOUR_GEMINI_API_KEY
```
### CLI Usage
You can also try out the CLI directly by running:
```bash
python -m cli.main
```
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
<p align="center">
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p>
An interface will appear showing results as they load, letting you track the agent's progress as it runs.
<p align="center">
<img src="assets/cli/cli_news.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p>
<p align="center">
<img src="assets/cli/cli_transaction.png" width="100%" style="display: inline-block; margin: 0 2%;">
</p>
## TradingAgents Package
### Implementation Details
We built TradingAgents with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
### Python Usage
To use TradingAgents inside your code, you can import the `tradingagents` module and initialize a `TradingAgentsGraph()` object. The `.propagate()` function will return a decision. You can run `main.py`, here's also a quick example:
```python
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy())
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
print(decision)
```
You can also adjust the default configuration to set your own choice of LLMs, debate rounds, etc.
```python
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
config["max_debate_rounds"] = 1  # Adjust the number of debate rounds
config["online_tools"] = True # Use online tools or cached data
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
print(decision)
```
> For `online_tools`, we recommend enabling them for experimentation, as they provide access to real-time data. The agents' offline tools rely on cached data from our **Tauric TradingDB**, a curated dataset we use for backtesting. We're currently in the process of refining this dataset, and we plan to release it soon alongside our upcoming projects. Stay tuned!
You can view the full list of configurations in `tradingagents/default_config.py`.
## Contributing
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
## Citation
Please reference our work if you find *TradingAgents* provides you with some help :)
```
@misc{xiao2025tradingagentsmultiagentsllmfinancial,
title={TradingAgents: Multi-Agents LLM Financial Trading Framework},
author={Yijia Xiao and Edward Sun and Di Luo and Wei Wang},
year={2025},
eprint={2412.20138},
archivePrefix={arXiv},
primaryClass={q-fin.TR},
url={https://arxiv.org/abs/2412.20138},
}
```

View File

@ -1,67 +1,67 @@
from typing import Generator, Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import jwt, JWTError
from pydantic import BaseModel
from sqlmodel import Session
from app.core.config import settings
from app.infrastructure.database import get_db
from app.domain.models import User
from app.infrastructure.repositories.user import UserRepository
from app.core.services.trading_analysis import TradingAnalysisService
# OAuth2 "password" flow scheme; clients obtain bearer tokens from the login endpoint below.
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
class TokenData(BaseModel):
    """Decoded JWT payload; ``username`` holds the token's "sub" claim (used as the user's email)."""

    username: Optional[str] = None
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
    """FastAPI dependency: a ``UserRepository`` bound to the request's DB session."""
    repository = UserRepository(db)
    return repository
def get_user_from_token(token: str, db: Session) -> Optional[User]:
    """Resolve ``token`` to a ``User``, or return ``None`` if it is invalid.

    The JWT's "sub" claim is interpreted as the user's email address and looked
    up through ``UserRepository``. A token that fails signature/decoding checks,
    or that carries no "sub" claim, yields ``None`` instead of raising.
    """
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except JWTError:
        return None
    username = payload.get("sub")
    if username is None:
        # A syntactically valid token without a subject cannot identify a
        # user; bail out early instead of querying the repository with None.
        return None
    user_repo = UserRepository(db)
    return user_repo.get_by_email(email=username)
def get_current_user(
    db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
) -> User:
    """FastAPI dependency: return the authenticated user, or raise 401."""
    current = get_user_from_token(token=token, db=db)
    if current:
        return current
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """FastAPI dependency: reject deactivated accounts with a 400."""
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
def get_current_active_superuser(
    current_user: User = Depends(get_current_active_user),
) -> User:
    """FastAPI dependency: require superuser privileges, else 403."""
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=403, detail="The user doesn't have enough privileges"
    )
def get_analysis_service(
    db: Session = Depends(get_db),
    user: User = Depends(get_current_active_user)
) -> TradingAnalysisService:
    """FastAPI dependency: a ``TradingAnalysisService`` scoped to the active user."""
    service = TradingAnalysisService(user=user, db=db)
    return service
from typing import Generator, Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import jwt, JWTError
from pydantic import BaseModel
from sqlmodel import Session
from app.core.config import settings
from app.infrastructure.database import get_db
from app.domain.models import User
from app.infrastructure.repositories.user import UserRepository
from app.core.services.trading_analysis import TradingAnalysisService
# OAuth2 "password" flow scheme; clients obtain bearer tokens from the login endpoint below.
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
class TokenData(BaseModel):
    """Decoded JWT payload; ``username`` holds the token's "sub" claim (used as the user's email)."""

    username: Optional[str] = None
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
    """FastAPI dependency: a ``UserRepository`` bound to the request's DB session."""
    repository = UserRepository(db)
    return repository
def get_user_from_token(token: str, db: Session) -> Optional[User]:
    """Resolve ``token`` to a ``User``, or return ``None`` if it is invalid.

    The JWT's "sub" claim is interpreted as the user's email address and looked
    up through ``UserRepository``. A token that fails signature/decoding checks,
    or that carries no "sub" claim, yields ``None`` instead of raising.
    """
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except JWTError:
        return None
    username = payload.get("sub")
    if username is None:
        # A syntactically valid token without a subject cannot identify a
        # user; bail out early instead of querying the repository with None.
        return None
    user_repo = UserRepository(db)
    return user_repo.get_by_email(email=username)
def get_current_user(
    db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
) -> User:
    """FastAPI dependency: return the authenticated user, or raise 401."""
    current = get_user_from_token(token=token, db=db)
    if current:
        return current
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """FastAPI dependency: reject deactivated accounts with a 400."""
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
def get_current_active_superuser(
    current_user: User = Depends(get_current_active_user),
) -> User:
    """FastAPI dependency: require superuser privileges, else 403."""
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=403, detail="The user doesn't have enough privileges"
    )
def get_analysis_service(
    db: Session = Depends(get_db),
    user: User = Depends(get_current_active_user)
) -> TradingAnalysisService:
    """FastAPI dependency: a ``TradingAnalysisService`` scoped to the active user."""
    service = TradingAnalysisService(user=user, db=db)
    return service

View File

@ -1,94 +1,94 @@
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
from app.api import deps
from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate
from app.domain.models import User as UserModel
from app.core.services.trading_analysis import TradingAnalysisService
from app.core.websocket_manager import WebSocketManager
from sqlmodel import Session
from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS
router = APIRouter()
# Module-level singleton: every websocket connection registers in one shared manager.
manager = WebSocketManager()
@router.post("/start", response_model=AnalysisSession)
def start_analysis(
    *,
    analysis_in: AnalysisSessionCreate,
    background_tasks: BackgroundTasks,
    service: TradingAnalysisService = Depends(deps.get_analysis_service),
) -> Any:
    """
    Start a new analysis session.
    """
    # Persist the session first so the client receives its id immediately;
    # the analysis itself runs in the background after the response is sent.
    new_session = service.create_session(analysis_in=analysis_in)
    background_tasks.add_task(service.run_analysis, session_id=new_session.id)
    return new_session
@router.get("/history", response_model=List[AnalysisSession])
def get_analysis_history(
    service: TradingAnalysisService = Depends(deps.get_analysis_service),
    skip: int = 0,
    limit: int = 100,
) -> Any:
    """
    Get analysis history for the current user.
    """
    sessions = service.get_user_sessions(skip=skip, limit=limit)
    return sessions
@router.get("/options")
def get_analysis_options():
    """
    Get available options for analysis.
    """
    analyst_choices = [
        {'value': 'market', 'label': 'Market Analyst'},
        {'value': 'social', 'label': 'Social Analyst'},
        {'value': 'news', 'label': 'News Analyst'},
        {'value': 'fundamentals', 'label': 'Fundamentals Analyst'},
    ]
    depth_choices = [
        {'value': 1, 'label': 'Shallow'},
        {'value': 3, 'label': 'Medium'},
        {'value': 5, 'label': 'Deep'},
    ]
    # BASE_URLS entries are (name, url) pairs from the CLI configuration.
    provider_choices = [{'name': entry[0], 'url': entry[1]} for entry in BASE_URLS]
    return {
        'analysts': analyst_choices,
        'research_depths': depth_choices,
        'llm_providers': provider_choices,
        'shallow_thinkers': SHALLOW_AGENT_OPTIONS,
        'deep_thinkers': DEEP_AGENT_OPTIONS,
    }
@router.get("/{session_id}", response_model=AnalysisSession)
def get_analysis_session(
    session_id: int,
    service: TradingAnalysisService = Depends(deps.get_analysis_service),
) -> Any:
    """
    Get a specific analysis session by ID.
    """
    found = service.get_session(session_id=session_id)
    if not found:
        raise HTTPException(status_code=404, detail="Analysis session not found")
    return found
@router.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
token: str,
db: Session = Depends(deps.get_db)
):
"""
WebSocket endpoint for real-time analysis updates.
"""
user = deps.get_user_from_token(token=token, db=db)
if not user or not user.is_active:
await websocket.close(code=1008)
return
await manager.connect(user.id, websocket)
try:
while True:
# Keep the connection alive
await websocket.receive_text()
except WebSocketDisconnect:
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
from app.api import deps
from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate
from app.domain.models import User as UserModel
from app.core.services.trading_analysis import TradingAnalysisService
from app.core.websocket_manager import WebSocketManager
from sqlmodel import Session
from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS
router = APIRouter()
# Module-level singleton: every websocket connection registers in one shared manager.
manager = WebSocketManager()
@router.post("/start", response_model=AnalysisSession)
def start_analysis(
    *,
    analysis_in: AnalysisSessionCreate,
    background_tasks: BackgroundTasks,
    service: TradingAnalysisService = Depends(deps.get_analysis_service),
) -> Any:
    """
    Start a new analysis session.
    """
    # Persist the session first so the client receives its id immediately;
    # the analysis itself runs in the background after the response is sent.
    new_session = service.create_session(analysis_in=analysis_in)
    background_tasks.add_task(service.run_analysis, session_id=new_session.id)
    return new_session
@router.get("/history", response_model=List[AnalysisSession])
def get_analysis_history(
    service: TradingAnalysisService = Depends(deps.get_analysis_service),
    skip: int = 0,
    limit: int = 100,
) -> Any:
    """
    Get analysis history for the current user.
    """
    sessions = service.get_user_sessions(skip=skip, limit=limit)
    return sessions
@router.get("/options")
def get_analysis_options():
    """
    Get available options for analysis.
    """
    analyst_choices = [
        {'value': 'market', 'label': 'Market Analyst'},
        {'value': 'social', 'label': 'Social Analyst'},
        {'value': 'news', 'label': 'News Analyst'},
        {'value': 'fundamentals', 'label': 'Fundamentals Analyst'},
    ]
    depth_choices = [
        {'value': 1, 'label': 'Shallow'},
        {'value': 3, 'label': 'Medium'},
        {'value': 5, 'label': 'Deep'},
    ]
    # BASE_URLS entries are (name, url) pairs from the CLI configuration.
    provider_choices = [{'name': entry[0], 'url': entry[1]} for entry in BASE_URLS]
    return {
        'analysts': analyst_choices,
        'research_depths': depth_choices,
        'llm_providers': provider_choices,
        'shallow_thinkers': SHALLOW_AGENT_OPTIONS,
        'deep_thinkers': DEEP_AGENT_OPTIONS,
    }
@router.get("/{session_id}", response_model=AnalysisSession)
def get_analysis_session(
    session_id: int,
    service: TradingAnalysisService = Depends(deps.get_analysis_service),
) -> Any:
    """
    Get a specific analysis session by ID.
    """
    found = service.get_session(session_id=session_id)
    if not found:
        raise HTTPException(status_code=404, detail="Analysis session not found")
    return found
@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    token: str,
    db: Session = Depends(deps.get_db)
):
    """
    WebSocket endpoint for real-time analysis updates.

    Authenticates via the ``token`` query parameter; closes with policy
    violation (1008) if the token does not resolve to an active user.
    """
    user = deps.get_user_from_token(token=token, db=db)
    if not user or not user.is_active:
        await websocket.close(code=1008)
        return
    await manager.connect(user.id, websocket)
    try:
        while True:
            # Keep the connection alive; inbound messages are ignored.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        # Always unregister the socket. Previously only WebSocketDisconnect
        # triggered cleanup, so any other exception left a dead connection
        # registered in the manager.
        manager.disconnect(user.id, websocket)

View File

@ -1,35 +1,35 @@
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
from app.api import deps
from app.core.config import settings
from app.core.schemas.token import Token
from app.core import security
from app.infrastructure.repositories.user import UserRepository
# All login/token routes in this module attach to this router.
router = APIRouter()
@router.post("/login/access-token", response_model=Token)
def login_access_token(
    db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()
):
    """
    OAuth2 compatible token login, get an access token for future requests
    """
    # The OAuth2 form's "username" field carries the account email.
    account = UserRepository(db).get_by_email(email=form_data.username)
    credentials_ok = bool(account) and security.verify_password(
        form_data.password, account.hashed_password
    )
    if not credentials_ok:
        raise HTTPException(status_code=400, detail="Incorrect email or password")
    if not account.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")
    lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    token_value = security.create_access_token(account.email, expires_delta=lifetime)
    return {"access_token": token_value, "token_type": "bearer"}
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
from app.api import deps
from app.core.config import settings
from app.core.schemas.token import Token
from app.core import security
from app.infrastructure.repositories.user import UserRepository
router = APIRouter()
@router.post("/login/access-token", response_model=Token)
def login_access_token(
    db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()
):
    """
    OAuth2 compatible token login, get an access token for future requests
    """
    user = UserRepository(db).get_by_email(email=form_data.username)
    # Short-circuit keeps the same single error for unknown email / bad password.
    credentials_ok = user is not None and security.verify_password(
        form_data.password, user.hashed_password
    )
    if not credentials_ok:
        raise HTTPException(status_code=400, detail="Incorrect email or password")
    if not user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")
    lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    token = security.create_access_token(user.email, expires_delta=lifetime)
    return {"access_token": token, "token_type": "bearer"}

View File

@ -1,89 +1,89 @@
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from app.api import deps
from app.core.schemas.user import User, UserCreate, UserUpdate
from app.domain.models import User as UserModel
from app.domain.repositories import IUserRepository
router = APIRouter()
@router.get("/", response_model=List[User])
def read_users(
    repo: IUserRepository = Depends(deps.get_user_repository),
    skip: int = 0,
    limit: int = 100,
    current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any:
    """
    Retrieve users. Superuser-only (enforced by the dependency).
    """
    users = repo.get_multi(skip=skip, limit=limit)
    return users
@router.post("/", response_model=User)
def create_user(
    *,
    repo: IUserRepository = Depends(deps.get_user_repository),
    user_in: UserCreate,
    current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any:
    """
    Create new user. Superuser-only; 400 if the email is already registered.
    """
    user = repo.get_by_email(email=user_in.email)
    if user:
        raise HTTPException(
            status_code=400,
            detail="The user with this username already exists in the system.",
        )
    user = repo.create(obj_in=user_in)
    return user
@router.get("/me", response_model=User)
def read_user_me(
    current_user: UserModel = Depends(deps.get_current_active_user),
) -> Any:
    """
    Get current user.
    """
    return current_user
@router.get("/{user_id}", response_model=User)
def read_user_by_id(
    user_id: int,
    repo: IUserRepository = Depends(deps.get_user_repository),
    current_user: UserModel = Depends(deps.get_current_active_user),
) -> Any:
    """
    Get a specific user by id.

    Users may fetch themselves; anyone else requires superuser rights.
    """
    user = repo.get(id=user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    # NOTE(review): relies on model equality to mean "same user"; comparing
    # user.id == current_user.id would be unambiguous — confirm intent.
    if user == current_user:
        return user
    if not repo.is_superuser(user=current_user):
        raise HTTPException(
            status_code=403, detail="The user doesn't have enough privileges"
        )
    return user
@router.put("/{user_id}", response_model=User)
def update_user(
    *,
    repo: IUserRepository = Depends(deps.get_user_repository),
    user_id: int,
    user_in: UserUpdate,
    current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any:
    """
    Update a user. Superuser-only; 404 when the target does not exist.
    """
    user = repo.get(id=user_id)
    if not user:
        raise HTTPException(
            status_code=404,
            detail="The user with this username does not exist in the system",
        )
    user = repo.update(db_obj=user, obj_in=user_in)
    return user
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from app.api import deps
from app.core.schemas.user import User, UserCreate, UserUpdate
from app.domain.models import User as UserModel
from app.domain.repositories import IUserRepository
router = APIRouter()
@router.get("/", response_model=List[User])
def read_users(
    repo: IUserRepository = Depends(deps.get_user_repository),
    skip: int = 0,
    limit: int = 100,
    current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any:
    """
    Retrieve users. Superuser-only (enforced by the dependency).
    """
    users = repo.get_multi(skip=skip, limit=limit)
    return users
@router.post("/", response_model=User)
def create_user(
    *,
    repo: IUserRepository = Depends(deps.get_user_repository),
    user_in: UserCreate,
    current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any:
    """
    Create new user. Superuser-only; 400 if the email is already registered.
    """
    user = repo.get_by_email(email=user_in.email)
    if user:
        raise HTTPException(
            status_code=400,
            detail="The user with this username already exists in the system.",
        )
    user = repo.create(obj_in=user_in)
    return user
@router.get("/me", response_model=User)
def read_user_me(
    current_user: UserModel = Depends(deps.get_current_active_user),
) -> Any:
    """
    Get current user.
    """
    return current_user
@router.get("/{user_id}", response_model=User)
def read_user_by_id(
    user_id: int,
    repo: IUserRepository = Depends(deps.get_user_repository),
    current_user: UserModel = Depends(deps.get_current_active_user),
) -> Any:
    """
    Get a specific user by id.

    Users may fetch themselves; anyone else requires superuser rights.
    """
    user = repo.get(id=user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    # NOTE(review): relies on model equality to mean "same user"; comparing
    # user.id == current_user.id would be unambiguous — confirm intent.
    if user == current_user:
        return user
    if not repo.is_superuser(user=current_user):
        raise HTTPException(
            status_code=403, detail="The user doesn't have enough privileges"
        )
    return user
@router.put("/{user_id}", response_model=User)
def update_user(
    *,
    repo: IUserRepository = Depends(deps.get_user_repository),
    user_id: int,
    user_in: UserUpdate,
    current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any:
    """
    Update a user. Superuser-only; 404 when the target does not exist.
    """
    user = repo.get(id=user_id)
    if not user:
        raise HTTPException(
            status_code=404,
            detail="The user with this username does not exist in the system",
        )
    user = repo.update(db_obj=user, obj_in=user_in)
    return user

View File

@ -1,7 +1,7 @@
from fastapi import APIRouter
from app.api.endpoints import login, users, analysis
# Aggregate router mounted by the FastAPI app; each endpoint module
# registers under its own prefix and OpenAPI tag.
api_router = APIRouter()
api_router.include_router(login.router, tags=["login"])
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"])
from fastapi import APIRouter
from app.api.endpoints import login, users, analysis
# Aggregate router mounted by the FastAPI app; each endpoint module
# registers under its own prefix and OpenAPI tag.
api_router = APIRouter()
api_router.include_router(login.router, tags=["login"])
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"])

View File

@ -1,26 +1,26 @@
import os
from pydantic import BaseSettings
from typing import List, Optional
class Settings(BaseSettings):
    """Application configuration, read from environment variables."""
    PROJECT_NAME: str = "TradingAgents Backend"
    API_V1_STR: str = "/api/v1"
    # Security
    # NOTE(review): the fallback SECRET_KEY is a published constant — any
    # deployment that omits the env var signs JWTs with a known key.
    SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8  # 8 days
    ALGORITHM: str = "HS256"
    # Database
    DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db")
    # OpenAI
    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
    # CORS: comma-separated list of allowed origins.
    CORS_ALLOWED_ORIGINS: List[str] = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(',')
    class Config:
        case_sensitive = True
import os
from pydantic import BaseSettings
from typing import List, Optional
class Settings(BaseSettings):
    """Application configuration, read from environment variables."""
    PROJECT_NAME: str = "TradingAgents Backend"
    API_V1_STR: str = "/api/v1"
    # Security
    # NOTE(review): the fallback SECRET_KEY is a published constant — any
    # deployment that omits the env var signs JWTs with a known key.
    SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8  # 8 days
    ALGORITHM: str = "HS256"
    # Database
    DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db")
    # OpenAI
    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
    # CORS: comma-separated list of allowed origins.
    CORS_ALLOWED_ORIGINS: List[str] = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(',')
    class Config:
        case_sensitive = True
# Process-wide settings instance imported throughout the app.
settings = Settings()

View File

@ -1,4 +1,4 @@
from .user import User, UserCreate, UserUpdate
from .token import Token, TokenPayload
from .profile import Profile, ProfileCreate, ProfileUpdate
from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate
from .user import User, UserCreate, UserUpdate
from .token import Token, TokenPayload
from .profile import Profile, ProfileCreate, ProfileUpdate
from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate

View File

@ -1,38 +1,38 @@
from pydantic import BaseModel
from typing import List, Optional
from datetime import date, datetime
from app.domain.models import AnalysisStatus
class AnalysisSessionBase(BaseModel):
    """Request parameters shared by create payloads and stored sessions."""
    ticker: str
    analysts_selected: List[str]
    research_depth: int
    llm_provider: str
    backend_url: str
    shallow_thinker: str
    deep_thinker: str
class AnalysisSessionCreate(AnalysisSessionBase):
    """Payload for creating a new analysis session."""
    pass
class AnalysisSessionUpdate(BaseModel):
    """Partial update: only lifecycle/result fields change after creation."""
    status: Optional[AnalysisStatus] = None
    final_report: Optional[str] = None
    error_message: Optional[str] = None
class AnalysisSessionInDBBase(AnalysisSessionBase):
    """Server-populated fields as stored in the database."""
    id: int
    user_id: int
    analysis_date: date
    status: AnalysisStatus
    final_report: Optional[str] = None
    error_message: Optional[str] = None
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    class Config:
        orm_mode = True
class AnalysisSession(AnalysisSessionInDBBase):
from pydantic import BaseModel
from typing import List, Optional
from datetime import date, datetime
from app.domain.models import AnalysisStatus
class AnalysisSessionBase(BaseModel):
    """Request parameters shared by create payloads and stored sessions."""
    ticker: str
    analysts_selected: List[str]
    research_depth: int
    llm_provider: str
    backend_url: str
    shallow_thinker: str
    deep_thinker: str
class AnalysisSessionCreate(AnalysisSessionBase):
    """Payload for creating a new analysis session."""
    pass
class AnalysisSessionUpdate(BaseModel):
    """Partial update: only lifecycle/result fields change after creation."""
    status: Optional[AnalysisStatus] = None
    final_report: Optional[str] = None
    error_message: Optional[str] = None
class AnalysisSessionInDBBase(AnalysisSessionBase):
    """Server-populated fields as stored in the database."""
    id: int
    user_id: int
    analysis_date: date
    status: AnalysisStatus
    final_report: Optional[str] = None
    error_message: Optional[str] = None
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    class Config:
        orm_mode = True
class AnalysisSession(AnalysisSessionInDBBase):
    """API response shape for an analysis session."""
    pass

View File

@ -1,20 +1,20 @@
from pydantic import BaseModel
from typing import Optional
class ProfileBase(BaseModel):
    """User preferences applied as defaults for new analyses."""
    default_ticker: str = "SPY"
    preferred_research_depth: int = 3
    preferred_shallow_thinker: str = "gpt-4o-mini"
    preferred_deep_thinker: str = "gpt-4o"
class ProfileCreate(ProfileBase):
    """Payload for creating a profile."""
    pass
class ProfileUpdate(ProfileBase):
    """Update payload; may carry a plaintext API key to be stored."""
    openai_api_key: Optional[str] = None
class Profile(ProfileBase):
    """Response shape: reports only whether a key exists, never the key."""
    has_openai_api_key: bool
    class Config:
        orm_mode = True
from pydantic import BaseModel
from typing import Optional
class ProfileBase(BaseModel):
    """User preferences applied as defaults for new analyses."""
    default_ticker: str = "SPY"
    preferred_research_depth: int = 3
    preferred_shallow_thinker: str = "gpt-4o-mini"
    preferred_deep_thinker: str = "gpt-4o"
class ProfileCreate(ProfileBase):
    """Payload for creating a profile."""
    pass
class ProfileUpdate(ProfileBase):
    """Update payload; may carry a plaintext API key to be stored."""
    openai_api_key: Optional[str] = None
class Profile(ProfileBase):
    """Response shape: reports only whether a key exists, never the key."""
    has_openai_api_key: bool
    class Config:
        orm_mode = True

View File

@ -1,9 +1,9 @@
from pydantic import BaseModel
from typing import Optional
class Token(BaseModel):
    """Bearer-token response returned by the login endpoint."""
    access_token: str
    token_type: str
class TokenPayload(BaseModel):
    """Decoded JWT claims.

    NOTE(review): ``sub`` is typed int here, but security.create_access_token
    is called with the user's email (a string) — confirm which is intended.
    """
    sub: Optional[int] = None
from pydantic import BaseModel
from typing import Optional
class Token(BaseModel):
    """Bearer-token response returned by the login endpoint."""
    access_token: str
    token_type: str
class TokenPayload(BaseModel):
    """Decoded JWT claims.

    NOTE(review): ``sub`` is typed int here, but security.create_access_token
    is called with the user's email (a string) — confirm which is intended.
    """
    sub: Optional[int] = None

View File

@ -1,28 +1,28 @@
from pydantic import BaseModel, EmailStr
from typing import Optional
class UserBase(BaseModel):
    """Fields shared by all user representations."""
    email: EmailStr
    username: str
    first_name: Optional[str] = None
    last_name: Optional[str] = None
class UserCreate(UserBase):
    """Registration payload; the plaintext password is hashed before storage."""
    password: str
class UserUpdate(UserBase):
    """Update payload (no password change via this schema)."""
    pass
class UserInDBBase(UserBase):
    """Server-populated fields as stored in the database."""
    id: int
    is_active: bool
    is_superuser: bool
    class Config:
        orm_mode = True
class User(UserInDBBase):
    """API response shape — never includes the password hash."""
    pass
class UserInDB(UserInDBBase):
    """Internal shape that carries the stored hash; never returned to clients."""
    hashed_password: str
from pydantic import BaseModel, EmailStr
from typing import Optional
class UserBase(BaseModel):
    """Fields shared by all user representations."""
    email: EmailStr
    username: str
    first_name: Optional[str] = None
    last_name: Optional[str] = None
class UserCreate(UserBase):
    """Registration payload; the plaintext password is hashed before storage."""
    password: str
class UserUpdate(UserBase):
    """Update payload (no password change via this schema)."""
    pass
class UserInDBBase(UserBase):
    """Server-populated fields as stored in the database."""
    id: int
    is_active: bool
    is_superuser: bool
    class Config:
        orm_mode = True
class User(UserInDBBase):
    """API response shape — never includes the password hash."""
    pass
class UserInDB(UserInDBBase):
    """Internal shape that carries the stored hash; never returned to clients."""
    hashed_password: str

View File

@ -1,23 +1,23 @@
from datetime import datetime, timedelta, timezone
from typing import Any, Optional, Union

from jose import jwt
from passlib.context import CryptContext

from app.core.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT whose ``sub`` claim is ``str(subject)``.

    Uses *expires_delta* when given, otherwise the configured default
    lifetime (ACCESS_TOKEN_EXPIRE_MINUTES).
    """
    # Use timezone-aware UTC: datetime.utcnow() is deprecated and yields a
    # naive datetime, which can skew the "exp" claim when the host timezone
    # differs from UTC.
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = {"exp": expire, "sub": str(subject)}
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
    return encoded_jwt
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
    """Hash *password* with bcrypt for storage."""
    return pwd_context.hash(password)
from datetime import datetime, timedelta
from typing import Any, Union, Optional
from jose import jwt
from passlib.context import CryptContext
from app.core.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT whose ``sub`` claim is ``str(subject)``.

    Uses *expires_delta* when given, otherwise the configured default
    lifetime (ACCESS_TOKEN_EXPIRE_MINUTES).
    """
    # NOTE(review): datetime.utcnow() is deprecated and returns a naive
    # datetime — consider datetime.now(timezone.utc) for the "exp" claim.
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = {"exp": expire, "sub": str(subject)}
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
    return encoded_jwt
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
    """Hash *password* with bcrypt for storage."""
    return pwd_context.hash(password)

View File

@ -1,128 +1,128 @@
import asyncio
import datetime
import json
from typing import Dict, List, Optional
from sqlmodel import Session, select
from app.domain.models import User, AnalysisSession, AnalysisStatus
from app.core.schemas.analysis import AnalysisSessionCreate
from app.core.config import settings
from cli.models import AnalystType
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from app.api.deps import get_db
from app.core.websocket_manager import WebSocketManager
class TradingAnalysisService:
    """Runs trading analyses for a single user and persists session state.

    All session lookups are scoped to ``self.user``; lifecycle events are
    pushed to that user's WebSocket connections.
    """
    def __init__(self, user: User, db: Session):
        self.user = user
        self.db = db
        # NOTE(review): a fresh WebSocketManager per service instance only
        # reaches live sockets if the manager state is process-wide — confirm
        # against the manager used by the websocket endpoint.
        self.websocket_manager = WebSocketManager()
    async def run_analysis(self, session_id: int):
        """Execute the analysis for *session_id*, updating status in the DB
        and streaming started/completed/failed events to the user."""
        session = self.get_session(session_id=session_id)
        if not session:
            # Unknown id, or a session owned by another user: silently no-op.
            return
        try:
            session.status = AnalysisStatus.RUNNING
            session.started_at = datetime.datetime.utcnow()
            self.db.add(session)
            self.db.commit()
            self.db.refresh(session)
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_started',
                    'session_id': session.id,
                    'message': '분석을 시작합니다...'
                }
            )
            # Prepare config for TradingAgentsGraph
            config = DEFAULT_CONFIG.copy()
            config.update({
                'openai_api_key': settings.OPENAI_API_KEY,
                'llm_provider': session.llm_provider,
                'backend_url': session.backend_url,
                'shallow_thinking_model': session.shallow_thinker,
                'deep_thinking_model': session.deep_thinker,
            })
            # Progress callback for websocket
            # NOTE(review): progress_callback is defined but never handed to
            # TradingAgentsGraph below, so no progress events are emitted.
            async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0):
                # Cap at 99 so only the completion event reports 100%.
                progress_percent = int((step / total) * 99) if total > 0 else 0
                await self.websocket_manager.send_to_user(self.user.id, {
                    'type': 'analysis_progress',
                    'session_id': session.id,
                    'message_type': message_type,
                    'content': content,
                    'agent': agent,
                    'progress': progress_percent,
                })
            trading_graph = TradingAgentsGraph(
                config=config,
                selected_analysts=session.analysts_selected,
            )
            input_data = {
                'company_of_interest': session.ticker,
                'trade_date': session.analysis_date.strftime('%Y-%m-%d'),
            }
            # propagate() is blocking; run it off the event loop thread.
            final_state, result = await asyncio.to_thread(
                trading_graph.propagate,
                input_data['company_of_interest'],
                input_data['trade_date']
            )
            session.status = AnalysisStatus.COMPLETED
            session.completed_at = datetime.datetime.utcnow()
            session.final_report = json.dumps(final_state) # Store full state as JSON
            self.db.add(session)
            self.db.commit()
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_completed',
                    'session_id': session.id,
                    'message': '분석이 완료되었습니다.',
                    'result': result
                }
            )
        except Exception as e:
            # Record the failure on the session and notify the user; the
            # exception is swallowed rather than re-raised.
            session.status = AnalysisStatus.FAILED
            session.error_message = str(e)
            self.db.add(session)
            self.db.commit()
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_failed',
                    'session_id': session.id,
                    'message': f'분석 중 오류가 발생했습니다: {str(e)}'
                }
            )
    def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
        """Create a PENDING session for today, owned by the current user."""
        session = AnalysisSession(
            **analysis_in.dict(),
            user_id=self.user.id,
            analysis_date=datetime.date.today()
        )
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session
    def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
        """Return the session only when it belongs to the current user."""
        statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id)
        return self.db.exec(statement).first()
    def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
        """Return the user's sessions, newest first, paginated."""
        statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit)
        return self.db.exec(statement).all()
import asyncio
import datetime
import json
from typing import Dict, List, Optional
from sqlmodel import Session, select
from app.domain.models import User, AnalysisSession, AnalysisStatus
from app.core.schemas.analysis import AnalysisSessionCreate
from app.core.config import settings
from cli.models import AnalystType
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from app.api.deps import get_db
from app.core.websocket_manager import WebSocketManager
class TradingAnalysisService:
    """Runs trading analyses for a single user and persists session state.

    All session lookups are scoped to ``self.user``; lifecycle events are
    pushed to that user's WebSocket connections.
    """
    def __init__(self, user: User, db: Session):
        self.user = user
        self.db = db
        # NOTE(review): a fresh WebSocketManager per service instance only
        # reaches live sockets if the manager state is process-wide — confirm
        # against the manager used by the websocket endpoint.
        self.websocket_manager = WebSocketManager()
    async def run_analysis(self, session_id: int):
        """Execute the analysis for *session_id*, updating status in the DB
        and streaming started/completed/failed events to the user."""
        session = self.get_session(session_id=session_id)
        if not session:
            # Unknown id, or a session owned by another user: silently no-op.
            return
        try:
            session.status = AnalysisStatus.RUNNING
            session.started_at = datetime.datetime.utcnow()
            self.db.add(session)
            self.db.commit()
            self.db.refresh(session)
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_started',
                    'session_id': session.id,
                    'message': '분석을 시작합니다...'
                }
            )
            # Prepare config for TradingAgentsGraph
            config = DEFAULT_CONFIG.copy()
            config.update({
                'openai_api_key': settings.OPENAI_API_KEY,
                'llm_provider': session.llm_provider,
                'backend_url': session.backend_url,
                'shallow_thinking_model': session.shallow_thinker,
                'deep_thinking_model': session.deep_thinker,
            })
            # Progress callback for websocket
            # NOTE(review): progress_callback is defined but never handed to
            # TradingAgentsGraph below, so no progress events are emitted.
            async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0):
                # Cap at 99 so only the completion event reports 100%.
                progress_percent = int((step / total) * 99) if total > 0 else 0
                await self.websocket_manager.send_to_user(self.user.id, {
                    'type': 'analysis_progress',
                    'session_id': session.id,
                    'message_type': message_type,
                    'content': content,
                    'agent': agent,
                    'progress': progress_percent,
                })
            trading_graph = TradingAgentsGraph(
                config=config,
                selected_analysts=session.analysts_selected,
            )
            input_data = {
                'company_of_interest': session.ticker,
                'trade_date': session.analysis_date.strftime('%Y-%m-%d'),
            }
            # propagate() is blocking; run it off the event loop thread.
            final_state, result = await asyncio.to_thread(
                trading_graph.propagate,
                input_data['company_of_interest'],
                input_data['trade_date']
            )
            session.status = AnalysisStatus.COMPLETED
            session.completed_at = datetime.datetime.utcnow()
            session.final_report = json.dumps(final_state) # Store full state as JSON
            self.db.add(session)
            self.db.commit()
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_completed',
                    'session_id': session.id,
                    'message': '분석이 완료되었습니다.',
                    'result': result
                }
            )
        except Exception as e:
            # Record the failure on the session and notify the user; the
            # exception is swallowed rather than re-raised.
            session.status = AnalysisStatus.FAILED
            session.error_message = str(e)
            self.db.add(session)
            self.db.commit()
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_failed',
                    'session_id': session.id,
                    'message': f'분석 중 오류가 발생했습니다: {str(e)}'
                }
            )
    def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
        """Create a PENDING session for today, owned by the current user."""
        session = AnalysisSession(
            **analysis_in.dict(),
            user_id=self.user.id,
            analysis_date=datetime.date.today()
        )
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session
    def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
        """Return the session only when it belongs to the current user."""
        statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id)
        return self.db.exec(statement).first()
    def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
        """Return the user's sessions, newest first, paginated."""
        statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit)
        return self.db.exec(statement).all()

View File

@ -1,23 +1,23 @@
from typing import Dict, List
from fastapi import WebSocket
class WebSocketManager:
    """Tracks open WebSocket connections per user and fans out messages."""
    def __init__(self):
        # user_id -> list of that user's open sockets (multiple tabs/devices).
        self.active_connections: Dict[int, List[WebSocket]] = {}
    async def connect(self, user_id: int, websocket: WebSocket):
        """Accept the handshake and register the socket under *user_id*."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, []).append(websocket)
    def disconnect(self, user_id: int, websocket: WebSocket):
        """Unregister the socket; drop the user entry once it is empty.

        Guarded so a double-disconnect no longer raises ValueError from
        ``list.remove`` on a socket that was already removed.
        """
        connections = self.active_connections.get(user_id)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            del self.active_connections[user_id]
    async def send_to_user(self, user_id: int, message: dict):
        """Send *message* as JSON to every socket registered for *user_id*."""
        # Iterate over a snapshot: a disconnect during the await would
        # otherwise mutate the list we are iterating.
        for connection in list(self.active_connections.get(user_id, [])):
            await connection.send_json(message)
from typing import Dict, List
from fastapi import WebSocket
class WebSocketManager:
    """Tracks open WebSocket connections per user and fans out messages."""
    def __init__(self):
        # user_id -> list of that user's open sockets (multiple tabs/devices).
        self.active_connections: Dict[int, List[WebSocket]] = {}
    async def connect(self, user_id: int, websocket: WebSocket):
        """Accept the handshake and register the socket under *user_id*."""
        await websocket.accept()
        if user_id not in self.active_connections:
            self.active_connections[user_id] = []
        self.active_connections[user_id].append(websocket)
    def disconnect(self, user_id: int, websocket: WebSocket):
        """Unregister the socket; drop the user entry once it is empty.

        NOTE(review): ``list.remove`` raises ValueError if the socket was
        already removed (double-disconnect) — consider guarding.
        """
        if user_id in self.active_connections:
            self.active_connections[user_id].remove(websocket)
            if not self.active_connections[user_id]:
                del self.active_connections[user_id]
    async def send_to_user(self, user_id: int, message: dict):
        """Send *message* as JSON to every socket registered for *user_id*."""
        if user_id in self.active_connections:
            for connection in self.active_connections[user_id]:
                await connection.send_json(message)

View File

@ -1,56 +1,56 @@
from datetime import date, datetime
from typing import List, Optional
from sqlmodel import Field, SQLModel, JSON, Column
import enum
class User(SQLModel, table=True):
    """Account record; passwords are stored only as bcrypt hashes."""
    id: Optional[int] = Field(default=None, primary_key=True)
    email: str = Field(unique=True, index=True)
    username: str = Field(unique=True, index=True)
    hashed_password: str
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    is_active: bool = Field(default=True)
    is_superuser: bool = Field(default=False)
    # NOTE(review): datetime.utcnow is deprecated and naive — consider
    # timezone-aware defaults. onupdate refreshes updated_at on each UPDATE.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
class UserProfile(SQLModel, table=True):
    """Per-user preferences; one row per user (unique user_id)."""
    id: Optional[int] = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="user.id", unique=True)
    encrypted_openai_api_key: Optional[str] = None
    default_ticker: str = Field(default="SPY")
    preferred_research_depth: int = Field(default=3)
    preferred_shallow_thinker: str = Field(default="gpt-4o-mini")
    preferred_deep_thinker: str = Field(default="gpt-4o")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
class AnalysisStatus(str, enum.Enum):
    """Lifecycle states of an analysis session."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class AnalysisSession(SQLModel, table=True):
    """One analysis run: request parameters plus status and result fields."""
    id: Optional[int] = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="user.id")
    ticker: str
    analysis_date: date
    # Stored as a JSON column: list of analyst identifiers for this run.
    analysts_selected: List[str] = Field(sa_column=Column(JSON))
    research_depth: int
    llm_provider: str
    backend_url: str
    shallow_thinker: str
    deep_thinker: str
    status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
    final_report: Optional[str] = None
    error_message: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
from datetime import date, datetime
from typing import List, Optional
from sqlmodel import Field, SQLModel, JSON, Column
import enum
class User(SQLModel, table=True):
    """Account record; passwords are stored only as bcrypt hashes."""
    id: Optional[int] = Field(default=None, primary_key=True)
    email: str = Field(unique=True, index=True)
    username: str = Field(unique=True, index=True)
    hashed_password: str
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    is_active: bool = Field(default=True)
    is_superuser: bool = Field(default=False)
    # NOTE(review): datetime.utcnow is deprecated and naive — consider
    # timezone-aware defaults. onupdate refreshes updated_at on each UPDATE.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
class UserProfile(SQLModel, table=True):
    """Per-user preferences; one row per user (unique user_id)."""
    id: Optional[int] = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="user.id", unique=True)
    encrypted_openai_api_key: Optional[str] = None
    default_ticker: str = Field(default="SPY")
    preferred_research_depth: int = Field(default=3)
    preferred_shallow_thinker: str = Field(default="gpt-4o-mini")
    preferred_deep_thinker: str = Field(default="gpt-4o")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
class AnalysisStatus(str, enum.Enum):
    """Lifecycle states of an analysis session."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class AnalysisSession(SQLModel, table=True):
    """One analysis run: request parameters plus status and result fields."""
    id: Optional[int] = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="user.id")
    ticker: str
    analysis_date: date
    # Stored as a JSON column: list of analyst identifiers for this run.
    analysts_selected: List[str] = Field(sa_column=Column(JSON))
    research_depth: int
    llm_provider: str
    backend_url: str
    shallow_thinker: str
    deep_thinker: str
    status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
    final_report: Optional[str] = None
    error_message: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

View File

@ -1,48 +1,48 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Optional, List
from sqlmodel import SQLModel
from app.core.schemas.user import UserCreate, UserUpdate
from app.domain.models import User
# Generic type parameters for the repository interfaces below.
ModelType = TypeVar("ModelType", bound=SQLModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel)
class IRepository(Generic[ModelType], ABC):
    """Abstract CRUD interface over a single model type."""
    @abstractmethod
    def get(self, id: int) -> Optional[ModelType]:
        """Return the object with primary key *id*, or None."""
        pass
    @abstractmethod
    def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
        """Return a page of objects."""
        pass
    @abstractmethod
    def create(self, *, obj_in: CreateSchemaType) -> ModelType:
        """Persist and return a new object built from *obj_in*."""
        pass
    @abstractmethod
    def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
        """Apply *obj_in* onto *db_obj* and return it."""
        pass
    @abstractmethod
    def remove(self, *, id: int) -> ModelType:
        """Delete and return the object with *id*."""
        pass
class IUserRepository(IRepository[User], ABC):
    """User-specific repository contract (narrows create/update schemas)."""
    @abstractmethod
    def get_by_email(self, *, email: str) -> Optional[User]:
        """Return the user with *email*, or None."""
        pass
    @abstractmethod
    def create(self, *, obj_in: UserCreate) -> User:
        pass
    @abstractmethod
    def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
        pass
    @abstractmethod
    def is_superuser(self, *, user: User) -> bool:
        """True when *user* has superuser rights."""
        pass
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Optional, List
from sqlmodel import SQLModel
from app.core.schemas.user import UserCreate, UserUpdate
from app.domain.models import User
# Generic type parameters for the repository interfaces below.
ModelType = TypeVar("ModelType", bound=SQLModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel)
class IRepository(Generic[ModelType], ABC):
    """Abstract CRUD interface over a single model type."""
    @abstractmethod
    def get(self, id: int) -> Optional[ModelType]:
        """Return the object with primary key *id*, or None."""
        pass
    @abstractmethod
    def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
        """Return a page of objects."""
        pass
    @abstractmethod
    def create(self, *, obj_in: CreateSchemaType) -> ModelType:
        """Persist and return a new object built from *obj_in*."""
        pass
    @abstractmethod
    def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
        """Apply *obj_in* onto *db_obj* and return it."""
        pass
    @abstractmethod
    def remove(self, *, id: int) -> ModelType:
        """Delete and return the object with *id*."""
        pass
class IUserRepository(IRepository[User], ABC):
    """User-specific repository contract (narrows create/update schemas)."""
    @abstractmethod
    def get_by_email(self, *, email: str) -> Optional[User]:
        """Return the user with *email*, or None."""
        pass
    @abstractmethod
    def create(self, *, obj_in: UserCreate) -> User:
        pass
    @abstractmethod
    def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
        pass
    @abstractmethod
    def is_superuser(self, *, user: User) -> bool:
        """True when *user* has superuser rights."""
        pass

View File

@ -1,9 +1,9 @@
from sqlmodel import create_engine, Session
from app.core.config import settings
# ``check_same_thread`` is a SQLite-specific argument: FastAPI may use the
# session from a different thread than the one that created the connection.
# Other backends (PostgreSQL, MySQL, ...) reject it, so only pass it for
# sqlite URLs instead of unconditionally.
_connect_args = (
    {"check_same_thread": False}
    if settings.DATABASE_URL.startswith("sqlite")
    else {}
)
engine = create_engine(settings.DATABASE_URL, echo=True, connect_args=_connect_args)
def get_db():
    """FastAPI dependency: yield a session that is closed after the request."""
    with Session(engine) as session:
        yield session
from sqlmodel import create_engine, Session
from app.core.config import settings
# NOTE(review): ``check_same_thread`` is SQLite-only — passing it
# unconditionally fails for other DATABASE_URLs. ``echo=True`` logs every
# SQL statement; consider disabling outside development.
engine = create_engine(settings.DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
def get_db():
    """FastAPI dependency: yield a session that is closed after the request."""
    with Session(engine) as session:
        yield session

View File

@ -1,53 +1,53 @@
from typing import Optional
from sqlmodel import Session, select
from app.domain.models import User
from app.core.schemas.user import UserCreate, UserUpdate
from app.domain.repositories import IUserRepository
from app.core.security import get_password_hash
class UserRepository(IUserRepository):
    """SQLModel-backed implementation of IUserRepository."""
    def __init__(self, db: Session):
        self.db = db
    def get(self, id: int) -> Optional[User]:
        """Return the user with primary key *id*, or None."""
        return self.db.get(User, id)
    def get_by_email(self, *, email: str) -> Optional[User]:
        """Return the user with *email*, or None."""
        statement = select(User).where(User.email == email)
        return self.db.exec(statement).first()
    def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]:
        """Return a page of users."""
        statement = select(User).offset(skip).limit(limit)
        return self.db.exec(statement).all()
    def create(self, *, obj_in: UserCreate) -> User:
        """Persist a new user, hashing the plaintext password."""
        db_obj = User(
            email=obj_in.email,
            username=obj_in.username,
            hashed_password=get_password_hash(obj_in.password),
            first_name=obj_in.first_name,
            last_name=obj_in.last_name,
        )
        self.db.add(db_obj)
        self.db.commit()
        self.db.refresh(db_obj)
        return db_obj
    def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
        """Apply only the fields explicitly set on *obj_in* to *db_obj*."""
        update_data = obj_in.dict(exclude_unset=True)
        for field, value in update_data.items():
            setattr(db_obj, field, value)
        self.db.add(db_obj)
        self.db.commit()
        self.db.refresh(db_obj)
        return db_obj
    def remove(self, *, id: int) -> User:
        """Delete and return the user with *id*.

        Raises:
            ValueError: when no user with *id* exists. Previously a missing
                row fell through to ``delete(None)`` and an opaque
                SQLAlchemy error.
        """
        db_obj = self.db.get(User, id)
        if db_obj is None:
            raise ValueError(f"User {id} not found")
        self.db.delete(db_obj)
        self.db.commit()
        return db_obj
    def is_superuser(self, *, user: User) -> bool:
        """True when *user* has the superuser flag set."""
        return user.is_superuser
from typing import Optional
from sqlmodel import Session, select
from app.domain.models import User
from app.core.schemas.user import UserCreate, UserUpdate
from app.domain.repositories import IUserRepository
from app.core.security import get_password_hash
class UserRepository(IUserRepository):
    """SQLModel-backed implementation of :class:`IUserRepository`.

    All methods use the session injected at construction time; each mutating
    method commits immediately.
    """

    def __init__(self, db: Session):
        self.db = db

    def get(self, id: int) -> Optional[User]:
        """Return the user with primary key *id*, or ``None`` if absent."""
        return self.db.get(User, id)

    def get_by_email(self, *, email: str) -> Optional[User]:
        """Return the first user with an exactly matching email, or ``None``."""
        statement = select(User).where(User.email == email)
        return self.db.exec(statement).first()

    def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]:
        """Return one page of users using ``skip``/``limit`` pagination."""
        statement = select(User).offset(skip).limit(limit)
        return self.db.exec(statement).all()

    def create(self, *, obj_in: UserCreate) -> User:
        """Persist a new user; the plaintext password is hashed before storage."""
        db_obj = User(
            email=obj_in.email,
            username=obj_in.username,
            hashed_password=get_password_hash(obj_in.password),
            first_name=obj_in.first_name,
            last_name=obj_in.last_name,
        )
        self.db.add(db_obj)
        self.db.commit()
        self.db.refresh(db_obj)
        return db_obj

    def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
        """Apply the fields explicitly set on *obj_in* to *db_obj* and persist.

        Uses ``model_dump`` (Pydantic v2) instead of the deprecated ``.dict()``,
        consistent with the rest of the codebase.
        """
        update_data = obj_in.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(db_obj, field, value)
        self.db.add(db_obj)
        self.db.commit()
        self.db.refresh(db_obj)
        return db_obj

    def remove(self, *, id: int) -> User:
        """Delete and return the user with primary key *id*.

        NOTE(review): if *id* does not exist, ``self.db.get`` returns ``None``
        and ``delete(None)`` raises — confirm callers guarantee existence.
        """
        db_obj = self.db.get(User, id)
        self.db.delete(db_obj)
        self.db.commit()
        return db_obj

    def is_superuser(self, *, user: User) -> bool:
        """Return whether *user* has the superuser flag set."""
        return user.is_superuser

View File

@ -1,36 +1,36 @@
import sys
import os
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
# Add project root to path to allow importing tradingagents
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
from app.api.router import api_router
from app.core.config import settings
from app.infrastructure.database import engine
from sqlmodel import SQLModel
def create_tables():
SQLModel.metadata.create_all(engine)
app = FastAPI(
title=settings.PROJECT_NAME,
openapi_url=f"{settings.API_V1_STR}/openapi.json"
)
@app.on_event("startup")
def on_startup():
create_tables()
# Set all CORS enabled origins
if settings.CORS_ALLOWED_ORIGINS:
app.add_middleware(
CORSMiddleware,
allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
"""FastAPI application entry point: app construction, CORS setup, startup hooks."""
import sys
import os
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
# Add project root to path to allow importing tradingagents
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
from app.api.router import api_router
from app.core.config import settings
from app.infrastructure.database import engine
from sqlmodel import SQLModel
def create_tables():
    """Create all SQLModel-declared tables on the configured engine (idempotent)."""
    SQLModel.metadata.create_all(engine)
app = FastAPI(
    title=settings.PROJECT_NAME,
    openapi_url=f"{settings.API_V1_STR}/openapi.json"
)
# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
# favour of lifespan handlers — consider migrating when upgrading.
@app.on_event("startup")
def on_startup():
    """Ensure database tables exist before the app starts serving requests."""
    create_tables()
# Set all CORS enabled origins
if settings.CORS_ALLOWED_ORIGINS:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
# Mount all API routes under the configured version prefix.
app.include_router(api_router, prefix=settings.API_V1_STR)

2
backend/.gitignore vendored
View File

@ -1,2 +1,2 @@
.env
.env
wallet/

View File

@ -1,247 +1,299 @@
from sqlmodel import Session
from analysis.domain.repository.analysis_repo import IAnalysisRepository
from ulid import ULID
from analysis.domain.analysis import Analysis as AnalysisVO
from analysis.interface.dto import TradingAnalysisRequest, AnalysisProgressUpdate
from fastapi import HTTPException, status, BackgroundTasks
import asyncio
from datetime import datetime
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
class AnalysisService:
def __init__(
self,
analysis_repo: IAnalysisRepository,
session: Session,
ulid: ULID
):
self.analysis_repo = analysis_repo
self.session = session
self.ulid = ulid
def get_analysis_list(
self,
member_id: str
) -> list[AnalysisVO]:
analyses = self.analysis_repo.find_by_member_id(member_id)
if not analyses:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
return analyses
def get_analysis_by_id(
self,
analysis_id: str,
member_id: str
) -> AnalysisVO:
analysis = self.analysis_repo.find_by_id(analysis_id)
if not analysis:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
if analysis.member_id != member_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
return analysis
def create_analysis(
self,
member_id: str,
request: TradingAnalysisRequest,
background_tasks: BackgroundTasks
) -> AnalysisVO:
# 분석 요청 생성
analysis_id = self.ulid.generate()
now = datetime.now()
analysis_vo = AnalysisVO(
id=analysis_id,
member_id=member_id,
ticker=request.ticker,
analysis_date=request.analysis_date,
analysts_selected=[analyst.value for analyst in request.analysts],
research_depth=request.research_depth,
llm_provider=request.llm_provider,
backend_url=request.backend_url,
shallow_thinker=request.shallow_thinker,
deep_thinker=request.deep_thinker,
status="pending",
created_at=now,
updated_at=now
)
saved_analysis = self.analysis_repo.save(analysis_vo)
self.session.commit()
# 백그라운드에서 분석 실행
background_tasks.add_task(self._run_analysis, saved_analysis.id)
return saved_analysis
async def _run_analysis(self, analysis_id: str):
"""백그라운드에서 실제 분석을 실행하는 메서드"""
try:
# 분석 상태를 RUNNING으로 변경
analysis = self.analysis_repo.find_by_id(analysis_id)
if analysis:
analysis.status = "running"
analysis.updated_at = datetime.now()
self.analysis_repo.update(analysis)
self.session.commit()
# 분석 정보 조회
analysis = self.analysis_repo.find_by_id(analysis_id)
if not analysis:
return
# TradingAgentsGraph 설정 및 실행
config = self._create_config(analysis)
# 분석 실행 (실제 구현)
await self._execute_trading_analysis(analysis_id, analysis, config)
# 분석 완료 상태로 변경
analysis = self.analysis_repo.find_by_id(analysis_id)
if analysis:
analysis.status = "completed"
analysis.completed_at = datetime.now()
analysis.updated_at = datetime.now()
self.analysis_repo.update(analysis)
self.session.commit()
except Exception as e:
# 에러 발생 시 실패 상태로 변경
analysis = self.analysis_repo.find_by_id(analysis_id)
if analysis:
analysis.status = "failed"
analysis.error_message = str(e)
analysis.completed_at = datetime.now()
analysis.updated_at = datetime.now()
self.analysis_repo.update(analysis)
self.session.commit()
def _create_config(self, analysis: AnalysisVO) -> dict:
"""분석 설정을 생성하는 메서드"""
config = DEFAULT_CONFIG.copy() if DEFAULT_CONFIG else {}
config.update({
"max_debate_rounds": analysis.research_depth,
"max_risk_discuss_rounds": analysis.research_depth,
"quick_think_llm": analysis.shallow_thinker,
"deep_think_llm": analysis.deep_thinker,
"backend_url": analysis.backend_url,
"llm_provider": analysis.llm_provider.lower(),
})
return config
async def _execute_trading_analysis(self, analysis_id: str, analysis: AnalysisVO, config: dict):
"""실제 TradingAgentsGraph를 실행하는 메서드"""
try:
# TradingAgentsGraph 초기화
graph = TradingAgentsGraph(
analysis.analysts_selected,
config=config,
debug=True
)
# 초기 상태 생성
init_agent_state = graph.propagator.create_initial_state(
analysis.ticker,
analysis.analysis_date
)
args = graph.propagator.get_graph_args()
# 분석 실행 및 결과 처리
trace = []
async for chunk in graph.graph.astream(init_agent_state, **args):
trace.append(chunk)
# 실시간으로 분석 결과 업데이트
await self._process_analysis_chunk(analysis_id, chunk)
# 최종 결과 처리
if trace:
final_state = trace[-1]
final_decision = graph.process_signal(final_state.get("final_trade_decision", ""))
# 최종 보고서 생성
final_report = self._generate_final_report(final_state)
# 최종 결과 저장
self.analysis_repo.update(analysis_id, {
"final_trade_decision": final_decision,
"final_report": final_report
})
self.session.commit()
except Exception as e:
raise Exception(f"Analysis execution failed: {str(e)}")
async def _process_analysis_chunk(self, analysis_id: str, chunk: dict):
"""분석 중간 결과를 처리하고 저장하는 메서드"""
updates = {}
# 개별 분석가 보고서 업데이트
if "market_report" in chunk and chunk["market_report"]:
updates["market_report"] = chunk["market_report"]
if "sentiment_report" in chunk and chunk["sentiment_report"]:
updates["sentiment_report"] = chunk["sentiment_report"]
if "news_report" in chunk and chunk["news_report"]:
updates["news_report"] = chunk["news_report"]
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
updates["fundamentals_report"] = chunk["fundamentals_report"]
# 팀별 의사결정 과정 업데이트
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
updates["investment_debate_state"] = chunk["investment_debate_state"]
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
updates["trader_investment_plan"] = chunk["trader_investment_plan"]
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
updates["risk_debate_state"] = chunk["risk_debate_state"]
# 업데이트가 있는 경우 저장
if updates:
self.analysis_repo.update(analysis_id, updates)
self.session.commit()
def _generate_final_report(self, final_state: dict) -> str:
"""최종 통합 보고서를 생성하는 메서드"""
report_parts = []
# Analyst Team Reports
if any(final_state.get(section) for section in ["market_report", "sentiment_report", "news_report", "fundamentals_report"]):
report_parts.append("## Analyst Team Reports")
if final_state.get("market_report"):
report_parts.append(f"### Market Analysis\n{final_state['market_report']}")
if final_state.get("sentiment_report"):
report_parts.append(f"### Social Sentiment\n{final_state['sentiment_report']}")
if final_state.get("news_report"):
report_parts.append(f"### News Analysis\n{final_state['news_report']}")
if final_state.get("fundamentals_report"):
report_parts.append(f"### Fundamentals Analysis\n{final_state['fundamentals_report']}")
# Research Team Reports
if final_state.get("investment_debate_state"):
report_parts.append("## Research Team Decision")
debate_state = final_state["investment_debate_state"]
if debate_state.get("judge_decision"):
report_parts.append(f"{debate_state['judge_decision']}")
# Trading Team Reports
if final_state.get("trader_investment_plan"):
report_parts.append("## Trading Team Plan")
report_parts.append(f"{final_state['trader_investment_plan']}")
# Portfolio Management Decision
if final_state.get("risk_debate_state") and final_state["risk_debate_state"].get("judge_decision"):
report_parts.append("## Portfolio Management Decision")
report_parts.append(f"{final_state['risk_debate_state']['judge_decision']}")
return "\n\n".join(report_parts) if report_parts else "No analysis results available."
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))
import logging
from sqlmodel import Session
from analysis.domain.repository.analysis_repo import IAnalysisRepository
from ulid import ULID
from analysis.domain.analysis import Analysis as AnalysisVO
from analysis.interface.dto import TradingAnalysisRequest, AnalysisProgressUpdate
from fastapi import HTTPException, status, BackgroundTasks
import asyncio
from datetime import datetime
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from analysis.application.websocket_manager import WebSocketManager
from analysis.infra.db_models.analysis import AnalysisStatus
logger = logging.getLogger(__name__)
class AnalysisService:
    """Application service orchestrating trading-analysis runs.

    Provides owner-scoped read access to analyses, and launches the
    TradingAgentsGraph pipeline as a FastAPI background task, persisting
    intermediate and final results and pushing progress over websockets.
    """

    def __init__(
        self,
        analysis_repo: IAnalysisRepository,
        session: Session,
        ulid: ULID,
        websocket_manager: WebSocketManager
    ):
        self.analysis_repo = analysis_repo
        self.session = session
        self.ulid = ulid
        self.websocket_manager = websocket_manager

    def get_analysis_list(
        self,
        member_id: str
    ) -> list[AnalysisVO]:
        """Return every analysis owned by *member_id*; raise 404 when none exist."""
        analyses = self.analysis_repo.find_by_member_id(member_id)
        if not analyses:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
        return analyses

    def get_analysis_by_id(
        self,
        analysis_id: str,
        member_id: str
    ) -> AnalysisVO:
        """Return a single analysis, enforcing ownership (404 unknown, 403 foreign)."""
        analysis = self.analysis_repo.find_by_id(analysis_id)
        if not analysis:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
        if analysis.member_id != member_id:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
        return analysis

    def get_analysis_sessions_by_member(
        self,
        member_id: str
    ) -> list[AnalysisVO]:
        """Alias of :meth:`get_analysis_list`, kept for interface compatibility."""
        return self.get_analysis_list(member_id)

    def create_analysis(
        self,
        member_id: str,
        request: TradingAnalysisRequest,
        background_tasks: BackgroundTasks
    ) -> AnalysisVO:
        """Persist a new pending analysis and schedule its execution.

        Returns the saved analysis immediately; the pipeline itself runs in a
        background task (:meth:`_run_analysis`).
        """
        analysis_id = self.ulid.generate()
        now = datetime.now()
        analysis_vo = AnalysisVO(
            id=analysis_id,
            member_id=member_id,
            ticker=request.ticker,
            analysis_date=request.analysis_date,
            analysts_selected=[analyst.value for analyst in request.analysts],
            research_depth=request.research_depth,
            llm_provider=request.llm_provider,
            backend_url=request.backend_url,
            shallow_thinker=request.shallow_thinker,
            deep_thinker=request.deep_thinker,
            status=AnalysisStatus.PENDING,
            created_at=now,
            updated_at=now
        )
        saved_analysis = self.analysis_repo.save(analysis_vo)
        if not saved_analysis:
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to save analysis")
        self.session.commit()
        # Route subsequent progress updates for this analysis to its owner.
        self.websocket_manager.register_analysis(saved_analysis.id, member_id)
        # Execute the analysis in the background.
        background_tasks.add_task(self._run_analysis, saved_analysis.id)
        return saved_analysis

    async def _run_analysis(self, analysis_id: str):
        """Background task: drive the analysis through running -> completed/failed."""
        try:
            # Flip the stored status to RUNNING before doing any work.
            analysis = AnalysisVO(
                id=analysis_id,
                status=AnalysisStatus.RUNNING,
                updated_at=datetime.now()
            )
            analysis = self.analysis_repo.update(analysis)
            if not analysis:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
            await self.websocket_manager.send_analysis_update(
                analysis_id=analysis_id,
                update_type="status_changed",
                data={"status": "running", "message": "Analysis started"}
            )
            # Build the graph config from the stored settings and execute.
            if analysis:
                config = self._create_config(analysis)
                await self._execute_trading_analysis(analysis_id, analysis, config)
            # Record successful completion.
            completed_analysis = AnalysisVO(
                id=analysis_id,
                status=AnalysisStatus.COMPLETED,
                completed_at=datetime.now(),
                updated_at=datetime.now()
            )
            self.analysis_repo.update(completed_analysis)
            self.session.commit()
        except Exception as e:
            logger.exception("Analysis %s failed", analysis_id)
            now = datetime.now()
            # Bug fix: the failure VO must carry the analysis id — AnalysisVO
            # requires `id` and the repository looks the row up by it.  Without
            # it, constructing the VO raised inside this handler and the FAILED
            # status (and error message) was never persisted.
            updates = AnalysisVO(
                id=analysis_id,
                status=AnalysisStatus.FAILED,
                error_message=str(e),
                completed_at=now,
                updated_at=now
            )
            self.analysis_repo.update(updates)
            self.session.commit()

    def _create_config(self, analysis: AnalysisVO) -> dict:
        """Build the TradingAgentsGraph configuration for this analysis.

        Starts from DEFAULT_CONFIG (imported above but previously unused) so
        graph options not covered by the request keep their library defaults;
        the per-analysis fields below override them.
        """
        config = DEFAULT_CONFIG.copy() if DEFAULT_CONFIG else {}
        config.update({
            "max_debate_rounds": analysis.research_depth,
            "max_risk_discuss_rounds": analysis.research_depth,
            "quick_think_llm": analysis.shallow_thinker,
            "deep_think_llm": analysis.deep_thinker,
            "backend_url": analysis.backend_url,
            "llm_provider": analysis.llm_provider.lower(),
        })
        return config

    async def _execute_trading_analysis(self, analysis_id: str, analysis: AnalysisVO, config: dict):
        """Run TradingAgentsGraph, streaming each intermediate chunk into storage."""
        try:
            logger.info(f"Starting trading analysis for {analysis_id} with ticker {analysis.ticker}")
            logger.info(f"Analysts selected: {analysis.analysts_selected}")
            logger.info(f"Config: {config}")
            # Initialize the agent graph for the selected analysts.
            graph = TradingAgentsGraph(
                analysis.analysts_selected,
                config=config,
                debug=True
            )
            logger.info("TradingAgentsGraph initialized successfully")
            # Seed the graph with the ticker/date under analysis.
            init_agent_state = graph.propagator.create_initial_state(
                analysis.ticker,
                analysis.analysis_date
            )
            args = graph.propagator.get_graph_args()
            logger.info("Starting graph execution...")
            trace = []
            chunk_count = 0
            async for chunk in graph.graph.astream(init_agent_state, **args):
                chunk_count += 1
                logger.info(f"Processing chunk {chunk_count}: {list(chunk.keys()) if chunk else 'Empty chunk'}")
                trace.append(chunk)
                # Persist whatever intermediate results this chunk carries.
                await self._process_analysis_chunk(analysis_id, chunk)
            # Derive and store the final decision and the aggregated report.
            if trace:
                final_state = trace[-1]
                final_decision = graph.process_signal(final_state.get("final_trade_decision", ""))
                final_report = self._generate_final_report(final_state)
                updates = AnalysisVO(
                    id=analysis_id,
                    final_trade_decision=final_decision,
                    final_report=final_report
                )
                self.analysis_repo.update(updates)
                self.session.commit()
        except Exception as e:
            # Chain the cause so the original traceback survives re-wrapping.
            raise Exception(f"Analysis execution failed: {str(e)}") from e

    async def _process_analysis_chunk(self, analysis_id: str, chunk: dict):
        """Persist any known report/decision fields that arrived in *chunk*."""
        # Analyst reports and team decision state the graph can emit.
        tracked_fields = (
            "market_report",
            "sentiment_report",
            "news_report",
            "fundamentals_report",
            "investment_debate_state",
            "trader_investment_plan",
            "risk_debate_state",
        )
        # Keep only fields that are present and non-empty in this chunk.
        updates = {field: chunk[field] for field in tracked_fields if chunk.get(field)}
        if updates:
            updates["id"] = analysis_id
            self.analysis_repo.update(AnalysisVO(**updates))
            self.session.commit()

    def _generate_final_report(self, final_state: dict) -> str:
        """Assemble the final markdown report from each team's outputs."""
        report_parts = []
        # Analyst team reports
        if any(final_state.get(section) for section in ["market_report", "sentiment_report", "news_report", "fundamentals_report"]):
            report_parts.append("## Analyst Team Reports")
            if final_state.get("market_report"):
                report_parts.append(f"### Market Analysis\n{final_state['market_report']}")
            if final_state.get("sentiment_report"):
                report_parts.append(f"### Social Sentiment\n{final_state['sentiment_report']}")
            if final_state.get("news_report"):
                report_parts.append(f"### News Analysis\n{final_state['news_report']}")
            if final_state.get("fundamentals_report"):
                report_parts.append(f"### Fundamentals Analysis\n{final_state['fundamentals_report']}")
        # Research team decision
        if final_state.get("investment_debate_state"):
            report_parts.append("## Research Team Decision")
            debate_state = final_state["investment_debate_state"]
            if debate_state.get("judge_decision"):
                report_parts.append(f"{debate_state['judge_decision']}")
        # Trading team plan
        if final_state.get("trader_investment_plan"):
            report_parts.append("## Trading Team Plan")
            report_parts.append(f"{final_state['trader_investment_plan']}")
        # Portfolio management decision
        if final_state.get("risk_debate_state") and final_state["risk_debate_state"].get("judge_decision"):
            report_parts.append("## Portfolio Management Decision")
            report_parts.append(f"{final_state['risk_debate_state']['judge_decision']}")
        return "\n\n".join(report_parts) if report_parts else "No analysis results available."

View File

@ -0,0 +1,65 @@
from typing import Dict, Set
from fastapi import WebSocket
import json
from datetime import datetime
class WebSocketManager:
def __init__(self):
# Store active connections by member_id
self.active_connections: Dict[str, Set[WebSocket]] = {}
# Store analysis_id to member_id mapping
self.analysis_member_map: Dict[str, str] = {}
async def connect(self, websocket: WebSocket, member_id: str):
await websocket.accept()
if member_id not in self.active_connections:
self.active_connections[member_id] = set()
self.active_connections[member_id].add(websocket)
def disconnect(self, websocket: WebSocket, member_id: str):
if member_id in self.active_connections:
self.active_connections[member_id].discard(websocket)
if not self.active_connections[member_id]:
del self.active_connections[member_id]
def register_analysis(self, analysis_id: str, member_id: str):
"""Register which member owns which analysis"""
self.analysis_member_map[analysis_id] = member_id
async def send_analysis_update(self, analysis_id: str, update_type: str, data: dict):
"""Send analysis update to the member who owns the analysis"""
member_id = self.analysis_member_map.get(analysis_id)
if not member_id:
return
message = {
"type": "analysis_update",
"analysis_id": analysis_id,
"update_type": update_type,
"data": data,
"timestamp": datetime.now().isoformat()
}
await self.send_to_member(member_id, message)
async def send_to_member(self, member_id: str, message: dict|str):
"""Send message to all connections of a specific member"""
if member_id not in self.active_connections:
return
dead_connections = set()
for connection in self.active_connections[member_id]:
try:
if isinstance(message, dict):
await connection.send_json(message)
else:
await connection.send_text(message)
except Exception:
dead_connections.add(connection)
# Clean up dead connections
for connection in dead_connections:
self.disconnect(connection, member_id)

View File

@ -1,19 +1,20 @@
from pydantic import BaseModel
from datetime import datetime
from typing import List, Dict
from pydantic import BaseModel, field_validator
from datetime import datetime, date
from typing import List, Dict, Union
from analysis.infra.db_models.analysis import AnalysisStatus
class Analysis(BaseModel):
id: str | None = None
member_id: str
ticker: str
analysis_date: str
analysts_selected: List[str] = []
id: str
member_id: str | None = None
ticker: str | None = None
analysis_date: date | None = None
analysts_selected: list[str] = []
research_depth: int = 3
llm_provider: str = "openai"
backend_url: str = "https://api.openai.com/v1"
shallow_thinker: str = "gpt-4o-mini"
deep_thinker: str = "gpt-4o"
status: str
shallow_thinker: str = "gpt-4o"
deep_thinker: str = "o3"
status: AnalysisStatus = AnalysisStatus.PENDING
# 개별 분석가 리포트들
market_report: str | None = None
@ -33,5 +34,5 @@ class Analysis(BaseModel):
# 실행 결과 정보
error_message: str | None = None
completed_at: datetime | None = None
created_at: datetime
updated_at: datetime
created_at: datetime | None = None
updated_at: datetime | None = None

View File

@ -1,20 +1,20 @@
from abc import ABC, abstractmethod
from analysis.domain.analysis import Analysis as AnalysisVO
from analysis.interface.dto import TradingAnalysisRequest
class IAnalysisRepository(ABC):
@abstractmethod
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
raise NotImplementedError()
@abstractmethod
def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
raise NotImplementedError()
@abstractmethod
def update(self, analysis: AnalysisVO) -> AnalysisVO | None:
raise NotImplementedError()
@abstractmethod
def save(self, analysis: AnalysisVO) -> AnalysisVO:
raise NotImplementedError()
from abc import ABC, abstractmethod
from analysis.domain.analysis import Analysis as AnalysisVO
from analysis.interface.dto import TradingAnalysisRequest
class IAnalysisRepository(ABC):
    """Persistence-layer contract for Analysis aggregates (implemented in infra)."""
    @abstractmethod
    def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
        """Return all analyses owned by *member_id*, or ``None`` when there are none."""
        raise NotImplementedError()
    @abstractmethod
    def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
        """Return the analysis with *analysis_id*, or ``None`` if it does not exist."""
        raise NotImplementedError()
    @abstractmethod
    def update(self, analysis: AnalysisVO) -> AnalysisVO | None:
        """Apply the set fields of *analysis* to the stored row; ``None`` if its id is unknown."""
        raise NotImplementedError()
    @abstractmethod
    def save(self, analysis: AnalysisVO) -> AnalysisVO:
        """Insert *analysis* as a new row and return it."""
        raise NotImplementedError()

View File

@ -1,57 +1,57 @@
from datetime import datetime,date
from typing import TYPE_CHECKING
from sqlmodel import SQLModel, Field, JSON, Relationship
import enum
from sqlalchemy import Column
# TYPE_CHECKING을 사용해서 circular import 방지
if TYPE_CHECKING:
from member.infra.db_models.member import Member
class AnalysisStatus(str, enum.Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class Analysis(SQLModel, table=True):
__tablename__ = "analyses"
id: str = Field(default=None, max_length=36, primary_key=True)
# 기본 분석 설정 정보
ticker: str
analysis_date: date
analysts_selected: list[str] = Field(sa_column=Column(JSON))
research_depth: int
llm_provider: str
backend_url: str
shallow_thinker: str
deep_thinker: str
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
# 개별 분석가 리포트들
market_report: str | None = Field(default=None, description="Market Analyst 리포트")
sentiment_report: str | None = Field(default=None, description="Social Analyst 리포트")
news_report: str | None = Field(default=None, description="News Analyst 리포트")
fundamentals_report: str | None = Field(default=None, description="Fundamentals Analyst 리포트")
# 팀별 의사결정 과정
investment_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Research Team 토론 과정")
trader_investment_plan: str | None = Field(default=None, description="Trading Team 계획")
risk_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Risk Management Team 토론 과정")
# 최종 결과물
final_trade_decision: str | None = Field(default=None, description="최종 거래 결정")
final_report: str | None = Field(default=None, description="전체 통합 리포트")
# 실행 결과 정보
error_message: str | None = None
completed_at: datetime | None = None
created_at : datetime = Field(nullable=False)
updated_at : datetime = Field(nullable=False)
# Foreign Key와 Relationship 설정
member_id: str = Field(foreign_key="members.id")
from datetime import datetime,date
from typing import TYPE_CHECKING
from sqlmodel import SQLModel, Field, JSON, Relationship
import enum
from sqlalchemy import Column, Text
# TYPE_CHECKING을 사용해서 circular import 방지
if TYPE_CHECKING:
from member.infra.db_models.member import Member
class AnalysisStatus(str, enum.Enum):
    """Lifecycle states of an analysis run; str-valued so it serializes cleanly."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class Analysis(SQLModel, table=True):
    """Persistent record of a single trading-analysis run.

    Stores the request configuration, per-analyst intermediate reports, team
    debate state, the final decision/report, and execution metadata.  Long
    free-text fields use TEXT columns; structured debate state is stored as JSON.
    """
    __tablename__ = "analyses"
    # Primary key: externally generated id string (36 chars max, e.g. a ULID).
    id: str = Field(default=None, max_length=36, primary_key=True)
    # Basic analysis configuration
    ticker: str
    analysis_date: date
    analysts_selected: list[str] = Field(sa_column=Column(JSON))
    research_depth: int
    llm_provider: str
    backend_url: str
    shallow_thinker: str
    deep_thinker: str
    status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
    # Individual analyst reports
    market_report: str | None = Field(default=None, sa_column=Column(Text), description="Market Analyst 리포트")
    sentiment_report: str | None = Field(default=None, sa_column=Column(Text), description="Social Analyst 리포트")
    news_report: str | None = Field(default=None, sa_column=Column(Text), description="News Analyst 리포트")
    fundamentals_report: str | None = Field(default=None, sa_column=Column(Text), description="Fundamentals Analyst 리포트")
    # Team decision-making process
    investment_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Research Team 토론 과정")
    trader_investment_plan: str | None = Field(default=None, sa_column=Column(Text), description="Trading Team 계획")
    risk_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Risk Management Team 토론 과정")
    # Final outputs
    final_trade_decision: str | None = Field(default=None, sa_column=Column(Text), description="최종 거래 결정")
    final_report: str | None = Field(default=None, sa_column=Column(Text), description="전체 통합 리포트")
    # Execution result information
    error_message: str | None = Field(default=None, sa_column=Column(Text))
    completed_at: datetime | None = None
    created_at : datetime = Field(nullable=False)
    updated_at : datetime = Field(nullable=False)
    # Foreign key and relationship configuration
    member_id: str = Field(foreign_key="members.id")
    member: "Member" = Relationship(back_populates="analyses")

View File

@ -1,80 +1,55 @@
from analysis.domain.repository.analysis_repo import IAnalysisRepository
from sqlmodel import Session, select
from analysis.domain.analysis import Analysis as AnalysisVO
from analysis.infra.db_models.analysis import Analysis, AnalysisStatus
from analysis.interface.dto import TradingAnalysisRequest
from utils.db_utils import row_to_dict
from sqlalchemy.orm import selectinload
from datetime import datetime, date
class AnalysisRepository(IAnalysisRepository):
def __init__(self, session: Session):
self.session = session
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
query = select(Analysis).where(Analysis.member_id == member_id)
analyses = self.session.exec(query).all()
if not analyses:
return None
return [AnalysisVO(**row_to_dict(analysis)) for analysis in analyses]
def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
analysis = self.session.get(Analysis, analysis_id)
if not analysis:
return None
return AnalysisVO(**row_to_dict(analysis))
def save(self, analysis: AnalysisVO) -> AnalysisVO:
new_analysis = Analysis(
id=analysis.id,
member_id=analysis.member_id,
ticker=analysis.ticker,
analysis_date=date.fromisoformat(analysis.analysis_date),
analysts_selected=analysis.analysts_selected,
research_depth=analysis.research_depth,
llm_provider=analysis.llm_provider,
backend_url=analysis.backend_url,
shallow_thinker=analysis.shallow_thinker,
deep_thinker=analysis.deep_thinker,
status=analysis.status,
market_report=analysis.market_report,
sentiment_report=analysis.sentiment_report,
news_report=analysis.news_report,
fundamentals_report=analysis.fundamentals_report,
investment_debate_state=analysis.investment_debate_state,
trader_investment_plan=analysis.trader_investment_plan,
risk_debate_state=analysis.risk_debate_state,
final_trade_decision=analysis.final_trade_decision,
final_report=analysis.final_report,
error_message=analysis.error_message,
completed_at=analysis.completed_at,
created_at=analysis.created_at,
updated_at=analysis.updated_at
)
self.session.add(new_analysis)
self.session.flush()
self.session.refresh(new_analysis)
analysis.id = new_analysis.id
return analysis
def update(self, analysis_vo: AnalysisVO) -> AnalysisVO | None:
analysis = self.session.get(Analysis, analysis_vo.id)
if not analysis:
return None
# AnalysisVO의 데이터를 SQLModel 객체에 업데이트
vo_data = analysis_vo.sqlmodel_dump(exclude_unset=True)
for key, value in vo_data.items():
if hasattr(analysis, key) and key != 'id': # id는 변경하지 않음
setattr(analysis, key, value)
analysis.updated_at = datetime.now()
self.session.add(analysis)
self.session.flush()
self.session.refresh(analysis)
return AnalysisVO(**row_to_dict(analysis))
from analysis.domain.repository.analysis_repo import IAnalysisRepository
from sqlmodel import Session, select
from analysis.domain.analysis import Analysis as AnalysisVO
from analysis.infra.db_models.analysis import Analysis, AnalysisStatus
from analysis.interface.dto import TradingAnalysisRequest
from utils.db_utils import row_to_dict
from sqlalchemy.orm import selectinload
from datetime import datetime, date
class AnalysisRepository(IAnalysisRepository):
    """SQLModel implementation of :class:`IAnalysisRepository`.

    Writes only ``flush`` (no ``commit``): transaction boundaries are owned by
    the calling service, which commits the shared session.
    """

    def __init__(self, session: Session):
        self.session = session

    def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
        """Return all analyses for *member_id* as VOs, or ``None`` when none exist."""
        query = select(Analysis).where(Analysis.member_id == member_id)
        analyses = self.session.exec(query).all()
        if not analyses:
            return None
        return [AnalysisVO(**row_to_dict(analysis)) for analysis in analyses]

    def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
        """Return the analysis with *analysis_id* as a VO, or ``None``."""
        analysis = self.session.get(Analysis, analysis_id)
        if not analysis:
            return None
        return AnalysisVO(**row_to_dict(analysis))

    def save(self, analysis: AnalysisVO) -> AnalysisVO:
        """Insert *analysis* as a new row; flushes so its id is usable immediately."""
        new_analysis = Analysis(
            **analysis.model_dump()
        )
        self.session.add(new_analysis)
        self.session.flush()
        self.session.refresh(new_analysis)
        analysis.id = new_analysis.id
        return analysis

    def update(self, analysis_vo: AnalysisVO) -> AnalysisVO | None:
        """Apply the explicitly-set fields of *analysis_vo* onto the stored row."""
        analysis = self.session.get(Analysis, analysis_vo.id)
        if not analysis:
            return None
        # Only fields the caller actually set are applied; the primary key is
        # excluded so a partial VO can never feed the row's id back through
        # sqlmodel_update.
        analysis_data = analysis_vo.model_dump(exclude_unset=True, exclude={"id"})
        # NOTE(review): callers that set updated_at on the VO will override this
        # stamp via sqlmodel_update below — confirm that is intended.
        analysis.updated_at = datetime.now()
        analysis.sqlmodel_update(analysis_data)
        self.session.flush()
        return AnalysisVO(**row_to_dict(analysis))

View File

@ -1,108 +1,137 @@
from typing import Annotated
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException, status
from analysis.interface.dto import (
AnalysisSessionResponse,
TradingAnalysisRequest,
AnalysisResultResponse
)
from utils.auth import get_current_member, CurrentMember
from dependency_injector.wiring import inject, Provide
from analysis.application.analysis_service import AnalysisService
from utils.containers import Container
router = APIRouter(prefix="/analysis", tags=["analysis"])
@router.get("/", response_model=list[AnalysisSessionResponse])
@inject
def get_analysis_list_for_member(
current_member: Annotated[CurrentMember, Depends(get_current_member)],
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
):
"""
현재 로그인한 사용자의 모든 분석 세션 목록을 조회합니다.
"""
analyses = analysis_service.get_analysis_list(current_member.id)
return [
AnalysisSessionResponse(
id=analysis.id,
ticker=analysis.ticker,
status=analysis.status
) for analysis in analyses
]
@router.post("/start", status_code=201, response_model=AnalysisSessionResponse)
@inject
def start_analysis_session(
request: TradingAnalysisRequest,
current_member: Annotated[CurrentMember, Depends(get_current_member)],
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])],
background_tasks: BackgroundTasks
):
"""
새로운 분석 세션을 시작합니다.
"""
try:
new_analysis = analysis_service.create_analysis(current_member.id, request, background_tasks)
return AnalysisSessionResponse(
id=new_analysis.id,
ticker=new_analysis.ticker,
status=new_analysis.status
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to start analysis: {str(e)}"
)
@router.get("/{analysis_id}", response_model=AnalysisResultResponse)
@inject
def get_analysis_result(
analysis_id: str,
current_member: Annotated[CurrentMember, Depends(get_current_member)],
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
):
"""
특정 분석 세션의 결과를 조회합니다.
"""
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
return AnalysisResultResponse(
id=analysis.id,
ticker=analysis.ticker,
analysis_date=analysis.analysis_date,
status=analysis.status,
market_report=analysis.market_report,
sentiment_report=analysis.sentiment_report,
news_report=analysis.news_report,
fundamentals_report=analysis.fundamentals_report,
investment_debate_state=analysis.investment_debate_state,
trader_investment_plan=analysis.trader_investment_plan,
risk_debate_state=analysis.risk_debate_state,
final_trade_decision=analysis.final_trade_decision,
final_report=analysis.final_report,
created_at=analysis.created_at.isoformat(),
completed_at=analysis.completed_at.isoformat() if analysis.completed_at else None,
error_message=analysis.error_message
)
@router.get("/{analysis_id}/status")
@inject
def get_analysis_status(
analysis_id: str,
current_member: Annotated[CurrentMember, Depends(get_current_member)],
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
):
"""
분석 진행 상황을 조회합니다.
"""
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
return {
"analysis_id": analysis.id,
"status": analysis.status,
"ticker": analysis.ticker,
"analysis_date": analysis.analysis_date,
"created_at": analysis.created_at.isoformat(),
"updated_at": analysis.updated_at.isoformat(),
"error_message": analysis.error_message
}
from typing import Annotated
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException, status, WebSocket, WebSocketDisconnect
from analysis.interface.dto import (
AnalysisSessionResponse,
TradingAnalysisRequest,
AnalysisResultResponse
)
from utils.auth import get_current_member, CurrentMember
from dependency_injector.wiring import inject, Provide
from analysis.application.analysis_service import AnalysisService
from utils.containers import Container
from analysis.application.websocket_manager import WebSocketManager
router = APIRouter(prefix="/analysis", tags=["analysis"])
@router.get("/", response_model=list[AnalysisSessionResponse])
@inject
def get_analysis_list_for_member(
current_member: Annotated[CurrentMember, Depends(get_current_member)],
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
):
"""
현재 로그인한 사용자의 모든 분석 세션 목록을 조회합니다.
"""
analyses = analysis_service.get_analysis_list(current_member.id)
return [
AnalysisSessionResponse(
id=analysis.id,
ticker=analysis.ticker,
status=analysis.status
) for analysis in analyses
]
@router.post("/start", status_code=201, response_model=AnalysisSessionResponse)
@inject
def start_analysis_session(
request: TradingAnalysisRequest,
current_member: Annotated[CurrentMember, Depends(get_current_member)],
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])],
background_tasks: BackgroundTasks
):
"""
새로운 분석 세션을 시작합니다.
"""
try:
new_analysis = analysis_service.create_analysis(current_member.id, request, background_tasks)
return AnalysisSessionResponse(
id=new_analysis.id,
ticker=new_analysis.ticker,
status=new_analysis.status
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to start analysis: {str(e)}"
)
@router.get("/{analysis_id}", response_model=AnalysisResultResponse)
@inject
def get_analysis_result(
analysis_id: str,
current_member: Annotated[CurrentMember, Depends(get_current_member)],
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
):
"""
특정 분석 세션의 결과를 조회합니다.
"""
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
return AnalysisResultResponse(
id=analysis.id,
ticker=analysis.ticker,
analysis_date=analysis.analysis_date.isoformat() if hasattr(analysis.analysis_date, 'isoformat') else str(analysis.analysis_date),
status=analysis.status,
market_report=analysis.market_report,
sentiment_report=analysis.sentiment_report,
news_report=analysis.news_report,
fundamentals_report=analysis.fundamentals_report,
investment_debate_state=analysis.investment_debate_state,
trader_investment_plan=analysis.trader_investment_plan,
risk_debate_state=analysis.risk_debate_state,
final_trade_decision=analysis.final_trade_decision,
final_report=analysis.final_report,
created_at=analysis.created_at.isoformat(),
completed_at=analysis.completed_at.isoformat() if analysis.completed_at else None,
error_message=analysis.error_message
)
@router.get("/{analysis_id}/status")
@inject
def get_analysis_status(
analysis_id: str,
current_member: Annotated[CurrentMember, Depends(get_current_member)],
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
):
"""
분석 진행 상황을 조회합니다.
"""
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
return {
"analysis_id": analysis.id,
"status": analysis.status,
"ticker": analysis.ticker,
"analysis_date": analysis.analysis_date,
"created_at": analysis.created_at.isoformat(),
"updated_at": analysis.updated_at.isoformat(),
"error_message": analysis.error_message
}
@router.websocket("/ws")
@inject
async def websocket_endpoint(
websocket: WebSocket,
current_member: Annotated[CurrentMember, Depends(get_current_member)],
websocket_manager: Annotated[WebSocketManager, Depends(Provide[Container.websocket_manager])]
):
"""
WebSocket endpoint for real-time analysis updates
"""
try:
# Connect the websocket
await websocket_manager.connect(websocket, current_member.id)
try:
# Keep connection alive
while True:
# Wait for messages from client (like ping/pong)
data = await websocket.receive_text()
# Echo back for heartbeat
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
websocket_manager.disconnect(websocket, current_member.id)
except Exception as e:
await websocket.close(code=1011, reason=str(e))

View File

@ -1,52 +1,52 @@
from pydantic import BaseModel
from datetime import date
from typing import List
from analysis.infra.db_models.analysis import AnalysisStatus
from enum import Enum
class AnalystType(str, Enum):
    """Analyst roles a client may request for a trading-analysis run."""
    MARKET = "market"
    SOCIAL = "social"
    NEWS = "news"
    FUNDAMENTALS = "fundamentals"
class TradingAnalysisRequest(BaseModel):
    """Client request payload for starting a trading analysis."""
    ticker: str                 # stock symbol to analyze
    analysis_date: str          # target date, as a string
    analysts: List[AnalystType] # which analyst agents to run
    research_depth: int = 3
    llm_provider: str = "openai"
    backend_url: str = "https://api.openai.com/v1"
    shallow_thinker: str = "gpt-4o-mini"  # model for quick passes
    deep_thinker: str = "gpt-4o"          # model for deep reasoning
class AnalysisSessionResponse(BaseModel):
    """Slim summary of an analysis session for list/create responses."""
    id : str
    ticker : str
    status : AnalysisStatus
class AnalysisProgressUpdate(BaseModel):
    """Incremental progress event emitted while an analysis is running."""
    analysis_id: str
    current_agent: str
    status: str
    progress_percentage: float
    current_report_section: str | None = None
    message: str | None = None
class AnalysisResultResponse(BaseModel):
    """Full result payload of an analysis session."""
    id: str
    ticker: str
    analysis_date: str
    status: AnalysisStatus
    market_report: str | None = None
    sentiment_report: str | None = None
    news_report: str | None = None
    fundamentals_report: str | None = None
    investment_debate_state: dict | None = None
    trader_investment_plan: str | None = None
    risk_debate_state: dict | None = None
    final_trade_decision: str | None = None
    final_report: str | None = None
    created_at: str
    completed_at: str | None = None
    # Fix: the controller populates error_message, but the model never
    # declared it, so the value was silently dropped from responses.
    error_message: str | None = None
from pydantic import BaseModel
from datetime import date
from typing import List
from analysis.infra.db_models.analysis import AnalysisStatus
from enum import Enum
class AnalystType(str, Enum):
    """Analyst roles a client may request for a trading-analysis run."""
    MARKET = "market"
    SOCIAL = "social"
    NEWS = "news"
    FUNDAMENTALS = "fundamentals"
class TradingAnalysisRequest(BaseModel):
    """Client request payload for starting a trading analysis.

    NOTE(review): every field carries a development-friendly default
    (including a fixed analysis_date) — confirm these defaults are not
    intended for production.
    """
    ticker: str = "NVDA"
    analysis_date: str = "2025-07-07"
    analysts: List[AnalystType] = [AnalystType.MARKET, AnalystType.SOCIAL, AnalystType.NEWS, AnalystType.FUNDAMENTALS]
    research_depth: int = 3
    llm_provider: str = "openai"
    backend_url: str = "https://api.openai.com/v1"
    shallow_thinker: str = "gpt-4o-mini"
    deep_thinker: str = "gpt-4o-mini"
class AnalysisSessionResponse(BaseModel):
    """Slim summary of an analysis session for list/create responses."""
    id : str
    ticker : str
    status : AnalysisStatus
class AnalysisProgressUpdate(BaseModel):
    """Incremental progress event emitted while an analysis is running."""
    analysis_id: str
    current_agent: str
    status: str
    progress_percentage: float
    current_report_section: str | None = None
    message: str | None = None
class AnalysisResultResponse(BaseModel):
    """Full result payload of an analysis session."""
    id: str
    ticker: str
    analysis_date: str
    status: AnalysisStatus
    market_report: str | None = None
    sentiment_report: str | None = None
    news_report: str | None = None
    fundamentals_report: str | None = None
    investment_debate_state: dict | None = None
    trader_investment_plan: str | None = None
    risk_debate_state: dict | None = None
    final_trade_decision: str | None = None
    final_report: str | None = None
    created_at: str                    # ISO-8601 timestamp
    completed_at: str | None = None    # ISO-8601 timestamp, None while running
    error_message: str | None = None   # populated only on failure

View File

@ -1,20 +1,20 @@
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application settings loaded from the environment / .env file."""
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
    )
    # MySQL database settings
    DB_HOST: str
    DB_PORT: int
    DB_USER: str
    DB_PASSWORD: str
    DB_NAME: str
    SECRET_KEY: str  # JWT signing key (see utils.auth)
@lru_cache
def get_settings():
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application settings loaded from the environment / .env file."""
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
    )
    # MySQL database settings
    DB_HOST: str
    DB_PORT: int
    DB_USER: str
    DB_PASSWORD: str
    DB_NAME: str
    SECRET_KEY: str  # JWT signing key (see utils.auth)
@lru_cache
def get_settings():
    """Return a cached (effectively singleton) Settings instance."""
    return Settings()

View File

@ -1,20 +1,30 @@
from fastapi import FastAPI
from utils.database import create_db_and_tables
from utils.containers import Container
from analysis.interface.controller.analysis_controller import router as analysis_router
from member.interface.controller.member_controller import router as member_router
app = FastAPI()
app.container = Container()
app.include_router(analysis_router)
app.include_router(member_router)
@app.on_event("startup")
def startup_db_client():
from fastapi import FastAPI
from utils.database import create_db_and_tables
from utils.containers import Container
from analysis.interface.controller.analysis_controller import router as analysis_router
from member.interface.controller.member_controller import router as member_router
import logging
# Logging configuration: INFO level, timestamped, to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),  # console output
    ]
)
app = FastAPI()
# Attach the DI container so @inject-wired endpoints can resolve providers.
app.container = Container()
app.include_router(analysis_router)
app.include_router(member_router)
# NOTE(review): @app.on_event is deprecated in recent FastAPI versions —
# consider migrating to a lifespan handler.
@app.on_event("startup")
def startup_db_client():
    """Create database tables once at application startup."""
    create_db_and_tables()

View File

@ -1,97 +1,97 @@
from sqlmodel import Session
from utils.crypto import Crypto
from member.domain.repository.member_repo import IMemberRepository
from utils.auth import Role
from member.domain.member import Member as MemberVO
from fastapi import HTTPException, status
from datetime import datetime
from utils.auth import create_access_token
from ulid import ULID
from analysis.domain.analysis import Analysis as AnalysisVO
class MemberService:
    """Application service for member signup, lookup, and authentication."""
    def __init__(
        self,
        member_repo: IMemberRepository,
        crypto: Crypto,
        db_session: Session,
        ulid: ULID
    ):
        self.member_repo = member_repo
        self.crypto = crypto
        self.db_session = db_session
        self.ulid = ulid
    def create_member(
        self,
        name: str,
        email: str,
        password: str,
        role: Role
    ):
        """Register a new member; raises 409 when the email is already taken."""
        try:
            if self.member_repo.find_by_email(email):
                raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already exists")
        except Exception as e:
            # Roll back the session on any failure (including the 409 above)
            # before propagating, so the connection is left clean.
            self.db_session.rollback()
            raise e
        now = datetime.now()
        member_vo = MemberVO(
            id=self.ulid.generate(),
            name=name,
            email=email,
            password=self.crypto.encrypt(password),  # never store the raw password
            created_at=now,
            updated_at=now,
            role=role
        )
        saved_member = self.member_repo.save(member_vo)
        self.db_session.commit()
        return saved_member
    def get_members(
        self,
        page: int,
        items_per_page: int
    )->tuple[int, list[MemberVO]] :
        """Return (total_count, members) for a 1-based page."""
        return self.member_repo.get_members(page, items_per_page)
    def get_member(
        self,
        id: str
    )->MemberVO | None:
        """Fetch one member by id; raises 404 when absent."""
        member = self.member_repo.find_by_id(id)
        if not member:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
        return member
    def login(
        self,
        email: str,
        password: str
    ):
        """Verify credentials and return a signed access token (401 on failure)."""
        member = self.member_repo.find_by_email(email)
        if not member:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
        if not self.crypto.verify(password, member.password):
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
        access_token = create_access_token(
            payload={"member_id": member.id, "role": member.role},
            role=member.role,
        )
        return access_token
    def get_analysis_sessions_by_member(
        self,
        member_id: str
    )->list[AnalysisVO]:
        """List the analysis sessions owned by *member_id*."""
        analysis_sessions = self.member_repo.find_analysis_sessions_by_member(member_id)
        return analysis_sessions
from sqlmodel import Session
from utils.crypto import Crypto
from member.domain.repository.member_repo import IMemberRepository
from utils.auth import Role
from member.domain.member import Member as MemberVO
from fastapi import HTTPException, status
from datetime import datetime
from utils.auth import create_access_token
from ulid import ULID
from analysis.domain.analysis import Analysis as AnalysisVO
class MemberService:
    """Application service for member signup, lookup, and authentication."""
    def __init__(
        self,
        member_repo: IMemberRepository,
        crypto: Crypto,
        session: Session,
        ulid: ULID
    ):
        self.member_repo = member_repo
        self.crypto = crypto
        self.db_session = session
        self.ulid = ulid
    def create_member(
        self,
        name: str,
        email: str,
        password: str,
        role: Role
    ):
        """Register a new member; raises 409 when the email is already taken."""
        try:
            if self.member_repo.find_by_email(email):
                raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already exists")
        except Exception as e:
            # Roll back the session on any failure (including the 409 above)
            # before propagating, so the connection is left clean.
            self.db_session.rollback()
            raise e
        now = datetime.now()
        member_vo = MemberVO(
            id=self.ulid.generate(),
            name=name,
            email=email,
            password=self.crypto.encrypt(password),  # never store the raw password
            created_at=now,
            updated_at=now,
            role=role
        )
        saved_member = self.member_repo.save(member_vo)
        self.db_session.commit()
        return saved_member
    def get_members(
        self,
        page: int,
        items_per_page: int
    )->tuple[int, list[MemberVO]] :
        """Return (total_count, members) for a 1-based page."""
        return self.member_repo.get_members(page, items_per_page)
    def get_member(
        self,
        id: str
    )->MemberVO | None:
        """Fetch one member by id; raises 404 when absent."""
        member = self.member_repo.find_by_id(id)
        if not member:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
        return member
    def login(
        self,
        email: str,
        password: str
    ):
        """Verify credentials and return a signed access token (401 on failure)."""
        member = self.member_repo.find_by_email(email)
        if not member:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
        if not self.crypto.verify(password, member.password):
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
        access_token = create_access_token(
            payload={"member_id": member.id, "role": member.role},
            role=member.role,
        )
        return access_token
    def get_analysis_sessions_by_member(
        self,
        member_id: str
    )->list[AnalysisVO]:
        """List the analysis sessions owned by *member_id*."""
        analysis_sessions = self.member_repo.find_analysis_sessions_by_member(member_id)
        return analysis_sessions

View File

@ -1,12 +1,12 @@
from pydantic import BaseModel
from utils.auth import Role
from datetime import datetime
class Member(BaseModel):
    """Domain value object for a member account."""
    id: str | None = None  # None until the row is persisted
    name: str
    email: str
    password: str  # stored in encrypted form (encrypted by the service layer)
    role: Role
    created_at: datetime
from pydantic import BaseModel
from utils.auth import Role
from datetime import datetime
class Member(BaseModel):
    """Domain value object for a member account."""
    id: str | None = None  # None until the row is persisted
    name: str
    email: str
    password: str  # stored in encrypted form (encrypted by the service layer)
    role: Role
    created_at: datetime
    updated_at: datetime

View File

@ -1,4 +1,4 @@
from abc import ABC, abstractmethod
class IMemberRepository(ABC):
from abc import ABC, abstractmethod
class IMemberRepository(ABC):
    """Abstract member-repository marker; declares no methods here.

    NOTE(review): a fuller interface with the same name exists elsewhere in
    the project — confirm whether this placeholder module is still needed.
    """
    pass

View File

@ -1,24 +1,24 @@
from abc import ABC, abstractmethod
from member.domain.member import Member as MemberVO
from analysis.domain.analysis import Analysis as AnalysisVO
class IMemberRepository(ABC):
@abstractmethod
def find_by_email(self, email: str) -> MemberVO | None:
raise NotImplementedError()
@abstractmethod
def save(self, member: MemberVO) -> MemberVO:
raise NotImplementedError()
@abstractmethod
def find_by_id(self, id: str) -> MemberVO | None:
raise NotImplementedError()
@abstractmethod
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
raise NotImplementedError()
@abstractmethod
def find_analysis_sessions_by_member(self, member_id: str) -> list[AnalysisVO]:
from abc import ABC, abstractmethod
from member.domain.member import Member as MemberVO
from analysis.domain.analysis import Analysis as AnalysisVO
class IMemberRepository(ABC):
    """Persistence-port interface for member aggregates."""
    @abstractmethod
    def find_by_email(self, email: str) -> MemberVO | None:
        """Return the member with *email*, or None."""
        raise NotImplementedError()
    @abstractmethod
    def save(self, member: MemberVO) -> MemberVO:
        """Persist a new member and return it."""
        raise NotImplementedError()
    @abstractmethod
    def find_by_id(self, id: str) -> MemberVO | None:
        """Return the member with primary key *id*, or None."""
        raise NotImplementedError()
    @abstractmethod
    def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
        """Return (total_count, members) for a 1-based page."""
        raise NotImplementedError()
    @abstractmethod
    def find_analysis_sessions_by_member(self, member_id: str) -> list[AnalysisVO]:
        """Return the analysis sessions owned by *member_id*."""
        raise NotImplementedError()

View File

@ -1,66 +1,66 @@
from member.domain.repository import IMemberRepository
from sqlmodel import Session, select
from member.domain.member import Member as MemberVO
from member.infra.db_models.member import Member
from utils.db_utils import row_to_dict
from sqlalchemy import func
class MemberRepository(IMemberRepository):
    """SQLModel-backed persistence adapter for Member aggregates."""
    def __init__(self, session: Session):
        self.session = session
    def find_by_email(self, email: str) -> MemberVO | None:
        """Return the member with *email*, or None when absent."""
        query = select(Member).where(Member.email == email)
        member = self.session.exec(query).first()
        if not member:
            return None
        return MemberVO(**row_to_dict(member))
    def save(self, member: MemberVO) -> MemberVO:
        """Insert a new member row; flush so constraint errors surface here."""
        new_member = Member(
            id=member.id,
            email=member.email,
            name=member.name,
            password=member.password,
            role=member.role,
            created_at=member.created_at,
            updated_at=member.updated_at
        )
        self.session.add(new_member)
        self.session.flush()
        self.session.refresh(new_member)
        member.id = new_member.id
        return member
    def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
        """Return (total_count, members) for a 1-based page, newest first."""
        offset = (page - 1) * items_per_page
        total_count_query = select(func.count(Member.id))
        total_count = self.session.exec(total_count_query).one()
        if total_count == 0:
            return 0, []
        query = (
            select(Member)
            .order_by(Member.created_at.desc())
            .offset(offset)
            .limit(items_per_page)
        )
        members = self.session.exec(query).all()
        return total_count, [MemberVO(**row_to_dict(member)) for member in members]
    def find_by_id(self, id: str) -> MemberVO | None:
        """Return the member with primary key *id*, or None when absent."""
        query = select(Member).where(Member.id == id)
        member = self.session.exec(query).first()
        if not member:
            return None
        return MemberVO(**row_to_dict(member))
from member.domain.repository import IMemberRepository
from sqlmodel import Session, select
from member.domain.member import Member as MemberVO
from member.infra.db_models.member import Member
from utils.db_utils import row_to_dict
from sqlalchemy import func
class MemberRepository(IMemberRepository):
    """SQLModel-backed persistence adapter for Member aggregates."""
    def __init__(self, session: Session):
        self.session = session
    def find_by_email(self, email: str) -> MemberVO | None:
        """Return the member with *email*, or None when absent."""
        query = select(Member).where(Member.email == email)
        member = self.session.exec(query).first()
        if not member:
            return None
        return MemberVO(**row_to_dict(member))
    def save(self, member: MemberVO) -> MemberVO:
        """Insert a new member row; flush so constraint errors surface here."""
        new_member = Member(
            id=member.id,
            email=member.email,
            name=member.name,
            password=member.password,
            role=member.role,
            created_at=member.created_at,
            updated_at=member.updated_at
        )
        self.session.add(new_member)
        self.session.flush()
        self.session.refresh(new_member)
        member.id = new_member.id
        return member
    def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
        """Return (total_count, members) for a 1-based page, newest first."""
        offset = (page - 1) * items_per_page
        total_count_query = select(func.count(Member.id))
        total_count = self.session.exec(total_count_query).one()
        if total_count == 0:
            return 0, []
        query = (
            select(Member)
            .order_by(Member.created_at.desc())
            .offset(offset)
            .limit(items_per_page)
        )
        members = self.session.exec(query).all()
        return total_count, [MemberVO(**row_to_dict(member)) for member in members]
    def find_by_id(self, id: str) -> MemberVO | None:
        """Return the member with primary key *id*, or None when absent."""
        query = select(Member).where(Member.id == id)
        member = self.session.exec(query).first()
        if not member:
            return None
        return MemberVO(**row_to_dict(member))

View File

@ -1,78 +1,71 @@
from fastapi import APIRouter, status, Depends,HTTPException
from member.interface.dto import CreateUserBody, MemberResponse
from member.application.member_service import MemberService
from typing import Annotated
from utils.containers import Container
from dependency_injector.wiring import inject, Provide
from fastapi.security import OAuth2PasswordRequestForm
from utils.auth import get_current_member, CurrentMember, get_admin_member
router = APIRouter(prefix="/members", tags=["members"])
@router.post("", status_code=status.HTTP_201_CREATED, response_model=MemberResponse)
@inject
async def create_user(
member: CreateUserBody,
member_service: MemberService = Depends(Provide[Container.member_service])
):
created_member = member_service.create_member(
member.name,
member.email,
member.password,
member.role
)
return created_member
@router.post("/login")
@inject
def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
member_service: MemberService = Depends(Provide[Container.member_service])
):
access_token = member_service.login(
email=form_data.username,
password=form_data.password
)
return {
"access_token" : access_token,
"token_type" : "Bearer"
}
@router.get("/me", response_model=dict)
def get_current_user_info(
current_user: CurrentMember = Depends(get_current_member)
):
"""
현재 로그인한 사용자 정보를 조회합니다.
엔드포인트는 JWT 토큰이 필요하며, Swagger UI에서 Authorize 버튼을 활성화합니다.
"""
return {
"user_id": current_user.id,
"role": current_user.role,
"message": "Successfully authenticated"
}
@router.get("/{member_id}", response_model=MemberResponse)
@inject
def get_member(
member_id: str,
current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
):
member = member_service.get_member(member_id)
if not member:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
return member
# @router.get("/analysis-sessions", response_model=list[AnalysisSessionResponse])
# @inject
# def get_member_analysis_sessions(
# current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
# member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
# ):
# result = member_service.get_analysis_sessions_by_member(current_member.id)
# return result
from fastapi import APIRouter, status, Depends,HTTPException
from member.interface.dto import CreateUserBody, MemberResponse
from member.application.member_service import MemberService
from typing import Annotated
from utils.containers import Container
from dependency_injector.wiring import inject, Provide
from fastapi.security import OAuth2PasswordRequestForm
from utils.auth import get_current_member, CurrentMember, get_admin_member
from analysis.interface.dto import AnalysisSessionResponse
from analysis.application.analysis_service import AnalysisService
router = APIRouter(prefix="/members", tags=["members"])
@router.post("", status_code=status.HTTP_201_CREATED, response_model=MemberResponse)
@inject
async def create_user(
member: CreateUserBody,
member_service: MemberService = Depends(Provide[Container.member_service])
):
created_member = member_service.create_member(
member.name,
member.email,
member.password,
member.role
)
return created_member
@router.post("/login")
@inject
def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
member_service: MemberService = Depends(Provide[Container.member_service])
):
access_token = member_service.login(
email=form_data.username,
password=form_data.password
)
return {
"access_token" : access_token,
"token_type" : "Bearer"
}
@router.get("/me", response_model=dict)
def get_current_user_info(
current_user: CurrentMember = Depends(get_current_member)
):
"""
현재 로그인한 사용자 정보를 조회합니다.
엔드포인트는 JWT 토큰이 필요하며, Swagger UI에서 Authorize 버튼을 활성화합니다.
"""
return {
"user_id": current_user.id,
"role": current_user.role,
"message": "Successfully authenticated"
}
@router.get("/{member_id}", response_model=MemberResponse)
@inject
def get_member(
member_id: str,
current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
):
member = member_service.get_member(member_id)
if not member:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
return member

View File

@ -1,18 +1,18 @@
from typing import Annotated
from pydantic import BaseModel, Field, EmailStr
from utils.auth import Role
from datetime import datetime
class CreateUserBody(BaseModel):
    """Signup request payload."""
    name : Annotated[str, Field(min_length=1, max_length=32)]
    email : Annotated[EmailStr, Field(max_length=32)]
    password : Annotated[str, Field(max_length=32)]
    role : Annotated[Role, Field(default=Role.USER)]  # defaults to a regular user
class MemberResponse(BaseModel):
    """Public member representation (never exposes the password)."""
    id : str
    name : str | None = None
    email : str
    created_at : datetime
    updated_at : datetime
    role : Role
from typing import Annotated
from pydantic import BaseModel, Field, EmailStr
from utils.auth import Role
from datetime import datetime
class CreateUserBody(BaseModel):
    """Signup request payload."""
    name : Annotated[str, Field(min_length=1, max_length=32)]
    email : Annotated[EmailStr, Field(max_length=32)]
    password : Annotated[str, Field(max_length=32)]
    role : Annotated[Role, Field(default=Role.USER)]  # defaults to a regular user
class MemberResponse(BaseModel):
    """Public member representation (never exposes the password)."""
    id : str
    name : str | None = None
    email : str
    created_at : datetime
    updated_at : datetime
    role : Role

View File

@ -1,69 +1,69 @@
from datetime import datetime, timedelta
from fastapi import HTTPException, status, Depends
from jose import jwt, JWTError
from fastapi.security import OAuth2PasswordBearer
import os
from dotenv import load_dotenv
from enum import StrEnum
from pydantic import BaseModel
from typing import Annotated
from config import get_settings
settings = get_settings()
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = "HS256"
class Role(StrEnum):
    """Authorization roles embedded in JWT claims."""
    ADMIN = "ADMIN"
    USER = "USER"
class CurrentMember(BaseModel):
    """Authenticated caller identity resolved from a JWT."""
    id : str
    role : Role
    def __str__(self):
        return f"{self.id}({self.role})"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/members/login")
def create_access_token(
    payload: dict,
    role: Role,
    expires_delta: timedelta = timedelta(hours=6)
):
    """Sign *payload* as a JWT, adding ``exp`` and ``role`` claims.

    Note: mutates the caller's *payload* dict in place (pre-existing behavior).
    """
    from datetime import timezone  # local import keeps this fix self-contained
    # Fix: datetime.utcnow() is deprecated and naive; use an explicitly
    # UTC-aware timestamp. python-jose converts both forms to the same
    # epoch seconds for the exp claim.
    expire = datetime.now(timezone.utc) + expires_delta
    payload.update({"exp": expire, "role": role})
    encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
def decode_access_token(token: str):
    """Decode and validate *token*; raises a 401 HTTPException on any JWT error."""
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
    return payload
# ✅ Fixed: correct usage of Annotated
def get_current_member(token: Annotated[str, Depends(oauth2_scheme)]):
    """FastAPI dependency: resolve the bearer token into a CurrentMember.

    Raises 401 for an undecodable token and 403 when required claims are missing.
    """
    payload = decode_access_token(token)
    member_id = payload.get("member_id")
    role = payload.get("role")
    if not member_id or not role:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
    return CurrentMember(id=member_id, role=Role(role))
def get_admin_member(token: Annotated[str, Depends(oauth2_scheme)]):
payload = decode_access_token(token)
member_id = payload.get("member_id")
role = payload.get("role")
if not role or role != Role.ADMIN:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
from datetime import datetime, timedelta
from fastapi import HTTPException, status, Depends
from jose import jwt, JWTError
from fastapi.security import OAuth2PasswordBearer
import os
from dotenv import load_dotenv
from enum import StrEnum
from pydantic import BaseModel
from typing import Annotated
from config import get_settings
settings = get_settings()
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = "HS256"
class Role(StrEnum):
    """Authorization roles embedded in JWT claims."""
    ADMIN = "ADMIN"
    USER = "USER"
class CurrentMember(BaseModel):
    """Authenticated caller identity resolved from a JWT."""
    id : str
    role : Role
    def __str__(self):
        return f"{self.id}({self.role})"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/members/login")
def create_access_token(
    payload: dict,
    role: Role,
    expires_delta: timedelta = timedelta(hours=6)
):
    """Sign *payload* as a JWT, adding ``exp`` and ``role`` claims.

    Note: mutates the caller's *payload* dict in place (pre-existing behavior).
    """
    from datetime import timezone  # local import keeps this fix self-contained
    # Fix: datetime.utcnow() is deprecated and naive; use an explicitly
    # UTC-aware timestamp. python-jose converts both forms to the same
    # epoch seconds for the exp claim.
    expire = datetime.now(timezone.utc) + expires_delta
    payload.update({"exp": expire, "role": role})
    encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
def decode_access_token(token: str):
    """Decode and validate *token*; raises a 401 HTTPException on any JWT error."""
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
    return payload
# ✅ Fixed: correct usage of Annotated
def get_current_member(token: Annotated[str, Depends(oauth2_scheme)]):
    """FastAPI dependency: resolve the bearer token into a CurrentMember.

    Raises 401 for an undecodable token and 403 when required claims are missing.
    """
    payload = decode_access_token(token)
    member_id = payload.get("member_id")
    role = payload.get("role")
    if not member_id or not role:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
    return CurrentMember(id=member_id, role=Role(role))
def get_admin_member(token: Annotated[str, Depends(oauth2_scheme)]):
    """FastAPI dependency: resolve the bearer token, requiring the ADMIN role.

    Raises 401 for an undecodable token and 403 when claims are missing or
    the role is not ADMIN.
    """
    payload = decode_access_token(token)
    member_id = payload.get("member_id")
    role = payload.get("role")
    # Fix: also reject tokens without a member_id claim — previously such a
    # token reached CurrentMember(id=None) and failed with a validation
    # error (500) instead of a clean 403.
    if not member_id or not role or role != Role.ADMIN:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
    return CurrentMember(id=member_id, role=Role(role))

View File

@ -1,43 +1,49 @@
from dependency_injector import containers, providers
from utils.database import get_session
from utils.crypto import Crypto
from member.infra.repository.member_repo import MemberRepository
from member.application.member_service import MemberService
from analysis.application.analysis_service import AnalysisService
from analysis.infra.repository.analysis_repo import AnalysisRepository
from ulid import ULID
class Container(containers.DeclarativeContainer):
    """Dependency-injection container wiring repositories and services
    into the `member` and `analysis` packages."""
    wiring_config = containers.WiringConfiguration(
        packages=["member", "analysis"]
    )
    # Resource provider: runs get_session's generator setup/teardown
    # around each injection scope (session opened/closed per use).
    db_session = providers.Resource(get_session)
    crypto = providers.Factory(Crypto)
    ulid = providers.Factory(ULID)
    member_repo = providers.Factory(
        MemberRepository,
        session=db_session
    )
    member_service = providers.Factory(
        MemberService,
        member_repo=member_repo,
        crypto=crypto,
        db_session=db_session,
        ulid=ulid
    )
    analysis_repo = providers.Factory(
        AnalysisRepository,
        session=db_session
    )
    analysis_service = providers.Factory(
        AnalysisService,
        analysis_repo=analysis_repo,
        db_session=db_session,
        ulid=ulid
    )
from dependency_injector import containers, providers
from utils.database import get_session
from utils.crypto import Crypto
from member.infra.repository.member_repo import MemberRepository
from member.application.member_service import MemberService
from analysis.application.analysis_service import AnalysisService
from analysis.infra.repository.analysis_repo import AnalysisRepository
from analysis.application.websocket_manager import WebSocketManager
from ulid import ULID
class Container(containers.DeclarativeContainer):
    """Dependency-injection container wiring repositories, services, and
    the shared WebSocket manager into the `member` and `analysis` packages."""
    wiring_config = containers.WiringConfiguration(
        packages=["member", "analysis"]
    )
    # Resource provider: runs get_session's generator setup/teardown
    # around each injection scope (session opened/closed per use).
    session = providers.Resource(get_session)
    crypto = providers.Factory(Crypto)
    ulid = providers.Factory(ULID)
    member_repo = providers.Factory(
        MemberRepository,
        session=session
    )
    member_service = providers.Factory(
        MemberService,
        member_repo=member_repo,
        crypto=crypto,
        session=session,
        ulid=ulid
    )
    analysis_repo = providers.Factory(
        AnalysisRepository,
        session=session
    )
    # Singleton: one WebSocket connection registry shared by the whole process.
    websocket_manager = providers.Singleton(
        WebSocketManager
    )
    analysis_service = providers.Factory(
        AnalysisService,
        analysis_repo=analysis_repo,
        session=session,
        ulid=ulid,
        websocket_manager=websocket_manager
    )

View File

@ -1,12 +1,12 @@
from passlib.context import CryptContext
class Crypto:
    """Thin wrapper around passlib's CryptContext for bcrypt password hashing."""
    def __init__(self):
        self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
    def encrypt(self, secret):
        """Hash *secret* with bcrypt (a fresh salt is generated per call)."""
        return self.pwd_context.hash(secret)
    def verify(self, secret, hash):
        """Return True if plaintext *secret* matches the stored bcrypt *hash*."""
        return self.pwd_context.verify(secret, hash)
from passlib.context import CryptContext
class Crypto:
    """Password-hashing helper backed by passlib's bcrypt context."""

    def __init__(self):
        self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

    def encrypt(self, secret):
        """Return a salted bcrypt hash of *secret*."""
        hashed = self.pwd_context.hash(secret)
        return hashed

    def verify(self, secret, hash):
        """Check whether plaintext *secret* matches the stored bcrypt *hash*."""
        is_match = self.pwd_context.verify(secret, hash)
        return is_match

View File

@ -1,32 +1,32 @@
import os
from pathlib import Path
from sqlmodel import SQLModel, create_engine, Session
from config.config import get_settings
from member.infra.db_models.member import Member
from analysis.infra.db_models.analysis import Analysis
settings = get_settings()
BASE_DIR = Path(__file__).resolve().parent.parent
# MySQL 데이터베이스 URL 구성
DATABASE_URL = f"mysql+pymysql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset=utf8mb4"
# MySQL 엔진 생성
engine = create_engine(
DATABASE_URL,
echo=True
)
def get_session():
with Session(engine) as session:
yield session
def create_db_and_tables():
# 테이블 생성
# SQLModel.metadata.drop_all(engine)
SQLModel.metadata.create_all(engine)
if __name__ == "__main__":
create_db_and_tables()
import os
from pathlib import Path
from sqlmodel import SQLModel, create_engine, Session
from config.config import get_settings
from member.infra.db_models.member import Member
from analysis.infra.db_models.analysis import Analysis
settings = get_settings()
BASE_DIR = Path(__file__).resolve().parent.parent
# MySQL connection URL; credentials come from settings (.env), so this
# string embeds the DB password — avoid logging it verbatim.
DATABASE_URL = f"mysql+pymysql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset=utf8mb4"
# Create the MySQL engine. echo=True logs every SQL statement —
# useful in development, noisy (and slow) in production.
engine = create_engine(
    DATABASE_URL,
    echo=True
)
def get_session():
    """Yield a SQLModel Session bound to the MySQL engine (closed on exit)."""
    with Session(engine) as session:
        yield session
def create_db_and_tables():
    """Create all tables registered on SQLModel.metadata (imported models)."""
    # Uncomment to drop and recreate from scratch (destructive!):
    # SQLModel.metadata.drop_all(engine)
    SQLModel.metadata.create_all(engine)

if __name__ == "__main__":
    create_db_and_tables()
    # Security: do not print DATABASE_URL verbatim — it embeds the DB
    # password. Mask the credential portion before logging.
    print(DATABASE_URL.replace(settings.DB_PASSWORD, "****"))

View File

@ -1,4 +1,4 @@
from sqlalchemy import inspect
def row_to_dict(row)->dict:
from sqlalchemy import inspect
def row_to_dict(row)->dict:
    """Map a SQLAlchemy ORM row to a plain dict of its mapped attributes."""
    mapped_state = inspect(row)
    return {name: getattr(row, name) for name in mapped_state.attrs.keys()}

View File

@ -1,59 +1,59 @@
version: '3.8'
services:
mysql:
image: mysql:8.0
container_name: tradingagents_mysql
restart: unless-stopped
environment:
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-password}
MYSQL_DATABASE: ${DB_NAME:-tradingagents_db}
MYSQL_USER: ${DB_USER:-tradinguser}
MYSQL_PASSWORD: ${DB_PASSWORD:-password}
ports:
- "3306:3306"
volumes:
- /home/hskim/mysql_data:/var/lib/mysql
- /home/hskim/docker/mysql/init:/docker-entrypoint-initdb.d
command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
networks:
- tradingagents_network
redis:
image: redis:7-alpine
container_name: tradingagents_redis
restart: unless-stopped
ports:
- "6379:6379"
volumes:
- redis_data:/data
command: redis-server --appendonly yes
networks:
- tradingagents_network
# 개발용 phpMyAdmin (선택사항)
# phpmyadmin:
# image: phpmyadmin/phpmyadmin
# container_name: tradingagents_phpmyadmin
# restart: unless-stopped
# environment:
# PMA_HOST: mysql
# PMA_PORT: 3306
# PMA_USER: root
# PMA_PASSWORD: ${DB_PASSWORD:-password}
# ports:
# - "8080:80"
# depends_on:
# - mysql
# networks:
# - tradingagents_network
volumes:
mysql_data:
driver: local
redis_data:
driver: local
networks:
tradingagents_network:
# NOTE: the top-level `version` key is obsolete in Compose v2 and ignored.
version: '3.8'

services:
  mysql:
    image: mysql:8.0
    container_name: tradingagents_mysql
    restart: unless-stopped
    environment:
      # Falls back to weak defaults when .env is absent — set real values
      # for any non-local deployment.
      MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-password}
      MYSQL_DATABASE: ${DB_NAME:-tradingagents_db}
      MYSQL_USER: ${DB_USER:-tradinguser}
      MYSQL_PASSWORD: ${DB_PASSWORD:-password}
    ports:
      - "3306:3306"
    volumes:
      # Host bind mounts (machine-specific paths); the named volume
      # `mysql_data` declared below is therefore unused by this service.
      - /home/hskim/mysql_data:/var/lib/mysql
      - /home/hskim/docker/mysql/init:/docker-entrypoint-initdb.d
    command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
    networks:
      - tradingagents_network

  redis:
    image: redis:7-alpine
    container_name: tradingagents_redis
    restart: unless-stopped
    ports:
      - "6379:6379"
    volumes:
      - redis_data:/data
    # --appendonly yes enables AOF persistence across restarts.
    command: redis-server --appendonly yes
    networks:
      - tradingagents_network

  # phpMyAdmin for development (optional)
  # phpmyadmin:
  #   image: phpmyadmin/phpmyadmin
  #   container_name: tradingagents_phpmyadmin
  #   restart: unless-stopped
  #   environment:
  #     PMA_HOST: mysql
  #     PMA_PORT: 3306
  #     PMA_USER: root
  #     PMA_PASSWORD: ${DB_PASSWORD:-password}
  #   ports:
  #     - "8080:80"
  #   depends_on:
  #     - mysql
  #   networks:
  #     - tradingagents_network

volumes:
  mysql_data:
    driver: local
  redis_data:
    driver: local

networks:
  tradingagents_network:
    driver: bridge
53212
frontend/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -1,48 +1,48 @@
{
"name": "tradingagents-web-frontend",
"version": "0.1.0",
"private": true,
"dependencies": {
"@ant-design/icons": "^5.2.6",
"@testing-library/jest-dom": "^5.16.4",
"@testing-library/react": "^13.3.0",
"@testing-library/user-event": "^13.5.0",
"antd": "^5.10.0",
"axios": "^1.5.0",
"dayjs": "^1.11.9",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-markdown": "^8.0.7",
"react-router-dom": "^6.4.0",
"react-scripts": "5.0.1",
"recharts": "^2.8.0",
"remark-gfm": "^4.0.1",
"styled-components": "^6.0.8",
"websocket": "^1.0.34"
},
"scripts": {
"start": "react-scripts start",
"build": "react-scripts build",
"test": "react-scripts test",
"eject": "react-scripts eject"
},
"eslintConfig": {
"extends": [
"react-app",
"react-app/jest"
]
},
"browserslist": {
"production": [
">0.2%",
"not dead",
"not op_mini all"
],
"development": [
"last 1 chrome version",
"last 1 firefox version",
"last 1 safari version"
]
},
"proxy": "http://localhost:8000"
}
{
"name": "tradingagents-web-frontend",
"version": "0.1.0",
"private": true,
"dependencies": {
"@ant-design/icons": "^5.2.6",
"@testing-library/jest-dom": "^5.16.4",
"@testing-library/react": "^13.3.0",
"@testing-library/user-event": "^13.5.0",
"antd": "^5.10.0",
"axios": "^1.5.0",
"dayjs": "^1.11.9",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-markdown": "^8.0.7",
"react-router-dom": "^6.4.0",
"react-scripts": "5.0.1",
"recharts": "^2.8.0",
"remark-gfm": "^4.0.1",
"styled-components": "^6.0.8",
"websocket": "^1.0.34"
},
"scripts": {
"start": "react-scripts start",
"build": "react-scripts build",
"test": "react-scripts test",
"eject": "react-scripts eject"
},
"eslintConfig": {
"extends": [
"react-app",
"react-app/jest"
]
},
"browserslist": {
"production": [
">0.2%",
"not dead",
"not op_mini all"
],
"development": [
"last 1 chrome version",
"last 1 firefox version",
"last 1 safari version"
]
},
"proxy": "http://localhost:8000"
}

View File

@ -1,20 +1,20 @@
<!DOCTYPE html>
<html lang="ko">
<head>
<meta charset="utf-8" />
<link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="theme-color" content="#000000" />
<meta
name="description"
content="TradingAgents - Multi-Agents LLM Financial Trading Framework"
/>
<link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
<link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
<title>TradingAgents - AI 거래 분석 플랫폼</title>
</head>
<body>
<noscript>JavaScript를 활성화해야 이 앱을 실행할 수 있습니다.</noscript>
<div id="root"></div>
</body>
<!DOCTYPE html>
<html lang="ko">
<head>
<meta charset="utf-8" />
<link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="theme-color" content="#000000" />
<meta
name="description"
content="TradingAgents - Multi-Agents LLM Financial Trading Framework"
/>
<link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
<link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
<title>TradingAgents - AI 거래 분석 플랫폼</title>
</head>
<body>
<noscript>JavaScript를 활성화해야 이 앱을 실행할 수 있습니다.</noscript>
<div id="root"></div>
</body>
</html>

42
main.py
View File

@ -1,21 +1,21 @@
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google" # Use a different model
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
config["deep_think_llm"] = "gemini-2.5-pro" # Use a different model
config["quick_think_llm"] = "gemini-2.5-flash-lite-preview-06-17" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds
config["online_tools"] = True # Increase debate rounds
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
print(decision)
# Memorize mistakes and reflect
# ta.reflect_and_remember(1000) # parameter is the position returns
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG

# Create a custom config
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google"  # Use Google (Gemini) as the LLM provider
config["backend_url"] = "https://generativelanguage.googleapis.com/v1"  # Gemini API endpoint
config["deep_think_llm"] = "gemini-2.5-pro"  # Model for deep-reasoning steps
config["quick_think_llm"] = "gemini-2.5-flash-lite-preview-06-17"  # Model for fast steps
config["max_debate_rounds"] = 1  # Number of debate rounds
config["online_tools"] = True  # Enable live (online) data tools
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
print(decision)
# Memorize mistakes and reflect
# ta.reflect_and_remember(1000) # parameter is the position returns

View File

@ -1,34 +1,34 @@
[project]
name = "tradingagents"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"akshare>=1.16.98",
"backtrader>=1.9.78.123",
"chainlit>=2.5.5",
"chromadb>=1.0.12",
"eodhd>=1.0.32",
"feedparser>=6.0.11",
"finnhub-python>=2.4.23",
"langchain-anthropic>=0.3.15",
"langchain-experimental>=0.3.4",
"langchain-google-genai>=2.1.5",
"langchain-openai>=0.3.23",
"langgraph>=0.4.8",
"pandas>=2.3.0",
"parsel>=1.10.0",
"praw>=7.8.1",
"pytz>=2025.2",
"questionary>=2.1.0",
"redis>=6.2.0",
"requests>=2.32.4",
"rich>=14.0.0",
"setuptools>=80.9.0",
"stockstats>=0.6.5",
"tqdm>=4.67.1",
"tushare>=1.4.21",
"typing-extensions>=4.14.0",
"yfinance>=0.2.63",
]
[project]
name = "tradingagents"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"akshare>=1.16.98",
"backtrader>=1.9.78.123",
"chainlit>=2.5.5",
"chromadb>=1.0.12",
"eodhd>=1.0.32",
"feedparser>=6.0.11",
"finnhub-python>=2.4.23",
"langchain-anthropic>=0.3.15",
"langchain-experimental>=0.3.4",
"langchain-google-genai>=2.1.5",
"langchain-openai>=0.3.23",
"langgraph>=0.4.8",
"pandas>=2.3.0",
"parsel>=1.10.0",
"praw>=7.8.1",
"pytz>=2025.2",
"questionary>=2.1.0",
"redis>=6.2.0",
"requests>=2.32.4",
"rich>=14.0.0",
"setuptools>=80.9.0",
"stockstats>=0.6.5",
"tqdm>=4.67.1",
"tushare>=1.4.21",
"typing-extensions>=4.14.0",
"yfinance>=0.2.63",
]

View File

@ -1,60 +1,60 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
def create_news_analyst(llm, toolkit):
def news_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
if toolkit.config["online_tools"]:
tools = [toolkit.get_global_news, toolkit.get_google_news]
else:
tools = [
toolkit.get_finnhub_news,
toolkit.get_reddit_news,
toolkit.get_google_news,
]
system_message = (
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
),
MessagesPlaceholder(variable_name="messages"),
]
)
prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"news_report": report,
}
return news_analyst_node
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
def create_news_analyst(llm, toolkit):
    """Build a LangGraph node that researches recent news for a ticker.

    Args:
        llm: Chat model supporting .bind_tools().
        toolkit: Tool provider; toolkit.config["online_tools"] selects
            between live sources and cached/offline sources.

    Returns:
        A node function mapping graph state -> {"messages", "news_report"}.
    """
    def news_analyst_node(state):
        current_date = state["trade_date"]
        ticker = state["company_of_interest"]

        if toolkit.config["online_tools"]:
            tools = [toolkit.get_global_news, toolkit.get_google_news]
        else:
            tools = [
                toolkit.get_finnhub_news,
                toolkit.get_reddit_news,
                toolkit.get_google_news,
            ]

        # Typo fix: "Makrdown" -> "Markdown" in the prompt.
        system_message = (
            "**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
            + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
        )

        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are a helpful AI assistant, collaborating with other assistants."
                    " Use the provided tools to progress towards answering the question."
                    " If you are unable to fully answer, that's OK; another assistant with different tools"
                    " will help where you left off. Execute what you can to make progress."
                    " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
                    " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
                    " You have access to the following tools: {tool_names}.\n{system_message}"
                    "For your reference, the current date is {current_date}. We are looking at the company {ticker}",
                ),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )

        # Pre-fill template variables so only the chat history remains.
        prompt = prompt.partial(system_message=system_message)
        prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
        prompt = prompt.partial(current_date=current_date)
        prompt = prompt.partial(ticker=ticker)

        chain = prompt | llm.bind_tools(tools)
        result = chain.invoke(state["messages"])

        # Only treat the content as a final report once the model stops
        # requesting tool calls; otherwise the graph loops to run tools.
        report = ""
        if len(result.tool_calls) == 0:
            report = result.content

        return {
            "messages": [result],
            "news_report": report,
        }

    return news_analyst_node

View File

@ -1,60 +1,60 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
def create_social_media_analyst(llm, toolkit):
def social_media_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
if toolkit.config["online_tools"]:
tools = [toolkit.get_stock_news]
else:
tools = [
toolkit.get_reddit_stock_info,
]
system_message = (
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
),
MessagesPlaceholder(variable_name="messages"),
]
)
prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"sentiment_report": report,
}
return social_media_analyst_node
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
def create_social_media_analyst(llm, toolkit):
    """Build a LangGraph node that analyzes social sentiment for a company.

    Args:
        llm: Chat model supporting .bind_tools().
        toolkit: Tool provider; toolkit.config["online_tools"] selects
            between the live news tool and the cached Reddit tool.

    Returns:
        A node function mapping graph state -> {"messages", "sentiment_report"}.
    """
    def social_media_analyst_node(state):
        current_date = state["trade_date"]
        ticker = state["company_of_interest"]

        if toolkit.config["online_tools"]:
            tools = [toolkit.get_stock_news]
        else:
            tools = [
                toolkit.get_reddit_stock_info,
            ]

        # BUG FIX: a trailing comma after the string previously made
        # system_message a 1-tuple, so the prompt rendered the tuple's
        # repr instead of the text. Also fixed "Makrdown" -> "Markdown".
        system_message = (
            "**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
            + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
        )

        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are a helpful AI assistant, collaborating with other assistants."
                    " Use the provided tools to progress towards answering the question."
                    " If you are unable to fully answer, that's OK; another assistant with different tools"
                    " will help where you left off. Execute what you can to make progress."
                    " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
                    " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
                    " You have access to the following tools: {tool_names}.\n{system_message}"
                    "For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
                ),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )

        # Pre-fill template variables so only the chat history remains.
        prompt = prompt.partial(system_message=system_message)
        prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
        prompt = prompt.partial(current_date=current_date)
        prompt = prompt.partial(ticker=ticker)

        chain = prompt | llm.bind_tools(tools)
        result = chain.invoke(state["messages"])

        # Only treat the content as a final report once the model stops
        # requesting tool calls; otherwise the graph loops to run tools.
        report = ""
        if len(result.tool_calls) == 0:
            report = result.content

        return {
            "messages": [result],
            "sentiment_report": report,
        }

    return social_media_analyst_node

View File

@ -1,57 +1,57 @@
import time
import json
def create_research_manager(llm, memory):
def research_manager_node(state) -> dict:
history = state["investment_debate_state"].get("history", "")
market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"]
news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"]
investment_debate_state = state["investment_debate_state"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendationBuy, Sell, or Holdmust be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
Additionally, develop a detailed investment plan for the trader. This should include:
Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes:
\"{past_memory_str}\"
Here is the debate:
Debate History:
{history}"""
response = llm.invoke(prompt)
new_investment_debate_state = {
"judge_decision": response.content,
"history": investment_debate_state.get("history", ""),
"bear_history": investment_debate_state.get("bear_history", ""),
"bull_history": investment_debate_state.get("bull_history", ""),
"current_response": response.content,
"count": investment_debate_state["count"],
}
return {
"investment_debate_state": new_investment_debate_state,
"investment_plan": response.content,
}
return research_manager_node
import time
import json
def create_research_manager(llm, memory):
    """Build a LangGraph node that judges the bull/bear debate and writes
    the investment plan.

    Args:
        llm: Chat model used to produce the judgement.
        memory: Situation memory exposing get_memories(situation, n_matches).

    Returns:
        A node function mapping graph state ->
        {"investment_debate_state", "investment_plan"}.
    """
    def research_manager_node(state) -> dict:
        history = state["investment_debate_state"].get("history", "")
        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]
        investment_debate_state = state["investment_debate_state"]
        # Concatenate all analyst reports to query memory for similar
        # past situations and their post-mortem recommendations.
        curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
        past_memories = memory.get_memories(curr_situation, n_matches=2)
        past_memory_str = ""
        for i, rec in enumerate(past_memories, 1):
            past_memory_str += rec["recommendation"] + "\n\n"
        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendationBuy, Sell, or Holdmust be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
Additionally, develop a detailed investment plan for the trader. This should include:
Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes:
\"{past_memory_str}\"
Here is the debate:
Debate History:
{history}"""
        response = llm.invoke(prompt)
        # Carry the prior debate fields forward; record the judgement both
        # as judge_decision and as the latest response.
        new_investment_debate_state = {
            "judge_decision": response.content,
            "history": investment_debate_state.get("history", ""),
            "bear_history": investment_debate_state.get("bear_history", ""),
            "bull_history": investment_debate_state.get("bull_history", ""),
            "current_response": response.content,
            "count": investment_debate_state["count"],
        }
        return {
            "investment_debate_state": new_investment_debate_state,
            "investment_plan": response.content,
        }
    return research_manager_node

View File

@ -1,68 +1,68 @@
import time
import json
def create_risk_manager(llm, memory):
    """Build a LangGraph node that judges the risk debate and finalizes the
    trade decision.

    Args:
        llm: Chat model used to produce the judgement.
        memory: Situation memory exposing get_memories(situation, n_matches).

    Returns:
        A node function mapping graph state ->
        {"risk_debate_state", "final_trade_decision"}.
    """
    def risk_manager_node(state) -> dict:
        history = state["risk_debate_state"]["history"]
        risk_debate_state = state["risk_debate_state"]
        market_research_report = state["market_report"]
        news_report = state["news_report"]
        # BUG FIX: this previously read state["news_report"], so the
        # fundamentals report was a duplicate of the news report.
        fundamentals_report = state["fundamentals_report"]
        sentiment_report = state["sentiment_report"]
        trader_plan = state["investment_plan"]

        # Concatenate all analyst reports to query memory for similar
        # past situations and their post-mortem recommendations.
        curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
        past_memories = memory.get_memories(curr_situation, n_matches=2)
        past_memory_str = ""
        for rec in past_memories:
            past_memory_str += rec["recommendation"] + "\n\n"

        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analystsRisky, Neutral, and Safe/Conservativeand determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
Guidelines for Decision-Making:
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
Deliverables:
- A clear and actionable recommendation: Buy, Sell, or Hold.
- Detailed reasoning anchored in the debate and past reflections.
---
**Analysts Debate History:**
{history}
---
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""

        response = llm.invoke(prompt)

        # Carry the prior debate fields forward and record the judgement.
        new_risk_debate_state = {
            "judge_decision": response.content,
            "history": risk_debate_state["history"],
            "risky_history": risk_debate_state["risky_history"],
            "safe_history": risk_debate_state["safe_history"],
            "neutral_history": risk_debate_state["neutral_history"],
            "latest_speaker": "Judge",
            "current_risky_response": risk_debate_state["current_risky_response"],
            "current_safe_response": risk_debate_state["current_safe_response"],
            "current_neutral_response": risk_debate_state["current_neutral_response"],
            "count": risk_debate_state["count"],
        }

        return {
            "risk_debate_state": new_risk_debate_state,
            "final_trade_decision": response.content,
        }

    return risk_manager_node
import time
import json
def create_risk_manager(llm, memory):
    """Build the risk-management judge node for the trading graph.

    The returned node reads the three-way risk debate (Risky / Safe /
    Neutral) from the graph state, retrieves lessons from similar past
    situations, and asks the LLM for a final, decisive BUY/SELL/HOLD call.

    Args:
        llm: Chat model exposing ``invoke(prompt)`` and returning a message
            object with a ``.content`` attribute.
        memory: Store exposing ``get_memories(situation, n_matches)`` that
            returns dicts containing a ``"recommendation"`` key.

    Returns:
        ``risk_manager_node(state) -> dict`` with updated
        ``risk_debate_state`` and the ``final_trade_decision`` text.
    """

    def risk_manager_node(state) -> dict:
        risk_debate_state = state["risk_debate_state"]
        history = risk_debate_state["history"]

        market_research_report = state["market_report"]
        news_report = state["news_report"]
        # BUG FIX: this previously read state["news_report"] a second time,
        # so the memory lookup saw the news twice and never saw fundamentals.
        fundamentals_report = state["fundamentals_report"]
        sentiment_report = state["sentiment_report"]

        trader_plan = state["investment_plan"]

        # Retrieve reflections on the most similar past situations.
        # (An unused read of state["company_of_interest"] was removed.)
        curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
        past_memories = memory.get_memories(curr_situation, n_matches=2)
        past_memory_str = "".join(
            rec["recommendation"] + "\n\n" for rec in past_memories
        )

        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analystsRisky, Neutral, and Safe/Conservativeand determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
Guidelines for Decision-Making:
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
Deliverables:
- A clear and actionable recommendation: Buy, Sell, or Hold.
- Detailed reasoning anchored in the debate and past reflections.
---
**Analysts Debate History:**
{history}
---
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""

        response = llm.invoke(prompt)

        # Carry the debate state forward unchanged except for the verdict.
        new_risk_debate_state = {
            "judge_decision": response.content,
            "history": risk_debate_state["history"],
            "risky_history": risk_debate_state["risky_history"],
            "safe_history": risk_debate_state["safe_history"],
            "neutral_history": risk_debate_state["neutral_history"],
            "latest_speaker": "Judge",
            "current_risky_response": risk_debate_state["current_risky_response"],
            "current_safe_response": risk_debate_state["current_safe_response"],
            "current_neutral_response": risk_debate_state["current_neutral_response"],
            "count": risk_debate_state["count"],
        }

        return {
            "risk_debate_state": new_risk_debate_state,
            "final_trade_decision": response.content,
        }

    return risk_manager_node

View File

@ -1,63 +1,63 @@
from langchain_core.messages import AIMessage
import time
import json
def create_bear_researcher(llm, memory):
    """Build the bear-side researcher node for the investment debate.

    The returned node argues *against* investing, using the four analyst
    reports and reflections retrieved from memory, then appends its argument
    to the shared debate state.

    Args:
        llm: Chat model exposing ``invoke(prompt)`` returning a message with
            a ``.content`` attribute.
        memory: Store exposing ``get_memories(situation, n_matches)`` that
            returns dicts containing a ``"recommendation"`` key.

    Returns:
        ``bear_node(state) -> dict`` with an updated
        ``investment_debate_state``.
    """

    def bear_node(state) -> dict:
        investment_debate_state = state["investment_debate_state"]
        history = investment_debate_state.get("history", "")
        bear_history = investment_debate_state.get("bear_history", "")
        current_response = investment_debate_state.get("current_response", "")

        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
        past_memories = memory.get_memories(curr_situation, n_matches=2)
        # Idiom: join instead of manual += accumulation (the old loop also
        # carried an unused enumerate index).
        past_memory_str = "".join(
            rec["recommendation"] + "\n\n" for rec in past_memories
        )

        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
Key points to focus on:
- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance.
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors.
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions.
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts.
Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bull argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
"""

        response = llm.invoke(prompt)
        argument = f"Bear Analyst: {response.content}"

        # Append the new argument to both the shared and bear-only histories.
        new_investment_debate_state = {
            "history": history + "\n" + argument,
            "bear_history": bear_history + "\n" + argument,
            "bull_history": investment_debate_state.get("bull_history", ""),
            "current_response": argument,
            "count": investment_debate_state["count"] + 1,
        }

        return {"investment_debate_state": new_investment_debate_state}

    return bear_node
from langchain_core.messages import AIMessage
import time
import json
def create_bear_researcher(llm, memory):
    """Build the bear-side researcher node for the investment debate.

    The returned node argues *against* investing, using the four analyst
    reports and reflections retrieved from memory, then appends its argument
    to the shared debate state.

    Args:
        llm: Chat model exposing ``invoke(prompt)`` returning a message with
            a ``.content`` attribute.
        memory: Store exposing ``get_memories(situation, n_matches)`` that
            returns dicts containing a ``"recommendation"`` key.

    Returns:
        ``bear_node(state) -> dict`` with an updated
        ``investment_debate_state``.
    """

    def bear_node(state) -> dict:
        investment_debate_state = state["investment_debate_state"]
        history = investment_debate_state.get("history", "")
        bear_history = investment_debate_state.get("bear_history", "")
        current_response = investment_debate_state.get("current_response", "")

        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
        past_memories = memory.get_memories(curr_situation, n_matches=2)
        # Idiom: join instead of manual += accumulation (the old loop also
        # carried an unused enumerate index).
        past_memory_str = "".join(
            rec["recommendation"] + "\n\n" for rec in past_memories
        )

        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
Key points to focus on:
- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance.
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors.
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions.
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts.
Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bull argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
"""

        response = llm.invoke(prompt)
        argument = f"Bear Analyst: {response.content}"

        # Append the new argument to both the shared and bear-only histories.
        new_investment_debate_state = {
            "history": history + "\n" + argument,
            "bear_history": bear_history + "\n" + argument,
            "bull_history": investment_debate_state.get("bull_history", ""),
            "current_response": argument,
            "count": investment_debate_state["count"] + 1,
        }

        return {"investment_debate_state": new_investment_debate_state}

    return bear_node

View File

@ -1,61 +1,61 @@
from langchain_core.messages import AIMessage
import time
import json
def create_bull_researcher(llm, memory):
    """Build the bull-side researcher node for the investment debate.

    The returned node argues *for* investing, using the four analyst reports
    and reflections retrieved from memory, then appends its argument to the
    shared debate state.

    Args:
        llm: Chat model exposing ``invoke(prompt)`` returning a message with
            a ``.content`` attribute.
        memory: Store exposing ``get_memories(situation, n_matches)`` that
            returns dicts containing a ``"recommendation"`` key.

    Returns:
        ``bull_node(state) -> dict`` with an updated
        ``investment_debate_state``.
    """

    def bull_node(state) -> dict:
        investment_debate_state = state["investment_debate_state"]
        history = investment_debate_state.get("history", "")
        bull_history = investment_debate_state.get("bull_history", "")
        current_response = investment_debate_state.get("current_response", "")

        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
        past_memories = memory.get_memories(curr_situation, n_matches=2)
        # Idiom: join instead of manual += accumulation (the old loop also
        # carried an unused enumerate index).
        past_memory_str = "".join(
            rec["recommendation"] + "\n\n" for rec in past_memories
        )

        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
Key points to focus on:
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit.
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bear argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
"""

        response = llm.invoke(prompt)
        argument = f"Bull Analyst: {response.content}"

        # Append the new argument to both the shared and bull-only histories.
        new_investment_debate_state = {
            "history": history + "\n" + argument,
            "bull_history": bull_history + "\n" + argument,
            "bear_history": investment_debate_state.get("bear_history", ""),
            "current_response": argument,
            "count": investment_debate_state["count"] + 1,
        }

        return {"investment_debate_state": new_investment_debate_state}

    return bull_node
from langchain_core.messages import AIMessage
import time
import json
def create_bull_researcher(llm, memory):
    """Build the bull-side researcher node for the investment debate.

    The returned node argues *for* investing, using the four analyst reports
    and reflections retrieved from memory, then appends its argument to the
    shared debate state.

    Args:
        llm: Chat model exposing ``invoke(prompt)`` returning a message with
            a ``.content`` attribute.
        memory: Store exposing ``get_memories(situation, n_matches)`` that
            returns dicts containing a ``"recommendation"`` key.

    Returns:
        ``bull_node(state) -> dict`` with an updated
        ``investment_debate_state``.
    """

    def bull_node(state) -> dict:
        investment_debate_state = state["investment_debate_state"]
        history = investment_debate_state.get("history", "")
        bull_history = investment_debate_state.get("bull_history", "")
        current_response = investment_debate_state.get("current_response", "")

        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
        past_memories = memory.get_memories(curr_situation, n_matches=2)
        # Idiom: join instead of manual += accumulation (the old loop also
        # carried an unused enumerate index).
        past_memory_str = "".join(
            rec["recommendation"] + "\n\n" for rec in past_memories
        )

        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
Key points to focus on:
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit.
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bear argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
"""

        response = llm.invoke(prompt)
        argument = f"Bull Analyst: {response.content}"

        # Append the new argument to both the shared and bull-only histories.
        new_investment_debate_state = {
            "history": history + "\n" + argument,
            "bull_history": bull_history + "\n" + argument,
            "bear_history": investment_debate_state.get("bear_history", ""),
            "current_response": argument,
            "count": investment_debate_state["count"] + 1,
        }

        return {"investment_debate_state": new_investment_debate_state}

    return bull_node

View File

@ -1,57 +1,57 @@
import time
import json
def create_risky_debator(llm):
    """Build the high-risk ("Risky") analyst node for the risk debate.

    The returned node champions the high-reward side of the trader's plan,
    rebutting the safe and neutral analysts, and appends its argument to the
    shared risk-debate state.

    Args:
        llm: Chat model exposing ``invoke(prompt)`` returning a message with
            a ``.content`` attribute.

    Returns:
        ``risky_node(state) -> dict`` with an updated ``risk_debate_state``.
    """

    def risky_node(state) -> dict:
        risk_debate_state = state["risk_debate_state"]
        history = risk_debate_state.get("history", "")
        risky_history = risk_debate_state.get("risky_history", "")

        current_safe_response = risk_debate_state.get("current_safe_response", "")
        current_neutral_response = risk_debate_state.get("current_neutral_response", "")

        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        trader_decision = state["trader_investment_plan"]

        # FIX: corrected prompt typo "halluncinate" -> "hallucinate".
        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
{trader_decision}
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""

        response = llm.invoke(prompt)
        argument = f"Risky Analyst: {response.content}"

        # Record the turn: extend shared + risky histories, mark the speaker.
        new_risk_debate_state = {
            "history": history + "\n" + argument,
            "risky_history": risky_history + "\n" + argument,
            "safe_history": risk_debate_state.get("safe_history", ""),
            "neutral_history": risk_debate_state.get("neutral_history", ""),
            "latest_speaker": "Risky",
            "current_risky_response": argument,
            "current_safe_response": risk_debate_state.get("current_safe_response", ""),
            "current_neutral_response": risk_debate_state.get(
                "current_neutral_response", ""
            ),
            "count": risk_debate_state["count"] + 1,
        }

        return {"risk_debate_state": new_risk_debate_state}

    return risky_node
import time
import json
def create_risky_debator(llm):
    """Build the high-risk ("Risky") analyst node for the risk debate.

    The returned node champions the high-reward side of the trader's plan,
    rebutting the safe and neutral analysts, and appends its argument to the
    shared risk-debate state.

    Args:
        llm: Chat model exposing ``invoke(prompt)`` returning a message with
            a ``.content`` attribute.

    Returns:
        ``risky_node(state) -> dict`` with an updated ``risk_debate_state``.
    """

    def risky_node(state) -> dict:
        risk_debate_state = state["risk_debate_state"]
        history = risk_debate_state.get("history", "")
        risky_history = risk_debate_state.get("risky_history", "")

        current_safe_response = risk_debate_state.get("current_safe_response", "")
        current_neutral_response = risk_debate_state.get("current_neutral_response", "")

        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        trader_decision = state["trader_investment_plan"]

        # FIX: corrected prompt typo "halluncinate" -> "hallucinate".
        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
{trader_decision}
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""

        response = llm.invoke(prompt)
        argument = f"Risky Analyst: {response.content}"

        # Record the turn: extend shared + risky histories, mark the speaker.
        new_risk_debate_state = {
            "history": history + "\n" + argument,
            "risky_history": risky_history + "\n" + argument,
            "safe_history": risk_debate_state.get("safe_history", ""),
            "neutral_history": risk_debate_state.get("neutral_history", ""),
            "latest_speaker": "Risky",
            "current_risky_response": argument,
            "current_safe_response": risk_debate_state.get("current_safe_response", ""),
            "current_neutral_response": risk_debate_state.get(
                "current_neutral_response", ""
            ),
            "count": risk_debate_state["count"] + 1,
        }

        return {"risk_debate_state": new_risk_debate_state}

    return risky_node

View File

@ -1,60 +1,60 @@
from langchain_core.messages import AIMessage
import time
import json
def create_safe_debator(llm):
    """Build the conservative ("Safe") analyst node for the risk debate.

    The returned node argues for capital preservation and low volatility,
    rebutting the risky and neutral analysts, and appends its argument to
    the shared risk-debate state.

    Args:
        llm: Chat model exposing ``invoke(prompt)`` returning a message with
            a ``.content`` attribute.

    Returns:
        ``safe_node(state) -> dict`` with an updated ``risk_debate_state``.
    """

    def safe_node(state) -> dict:
        risk_debate_state = state["risk_debate_state"]
        history = risk_debate_state.get("history", "")
        safe_history = risk_debate_state.get("safe_history", "")

        current_risky_response = risk_debate_state.get("current_risky_response", "")
        current_neutral_response = risk_debate_state.get("current_neutral_response", "")

        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        trader_decision = state["trader_investment_plan"]

        # FIX: corrected prompt typo "halluncinate" -> "hallucinate".
        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
{trader_decision}
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""

        response = llm.invoke(prompt)
        argument = f"Safe Analyst: {response.content}"

        # Record the turn: extend shared + safe histories, mark the speaker.
        new_risk_debate_state = {
            "history": history + "\n" + argument,
            "risky_history": risk_debate_state.get("risky_history", ""),
            "safe_history": safe_history + "\n" + argument,
            "neutral_history": risk_debate_state.get("neutral_history", ""),
            "latest_speaker": "Safe",
            "current_risky_response": risk_debate_state.get(
                "current_risky_response", ""
            ),
            "current_safe_response": argument,
            "current_neutral_response": risk_debate_state.get(
                "current_neutral_response", ""
            ),
            "count": risk_debate_state["count"] + 1,
        }

        return {"risk_debate_state": new_risk_debate_state}

    return safe_node
from langchain_core.messages import AIMessage
import time
import json
def create_safe_debator(llm):
    """Build the conservative ("Safe") analyst node for the risk debate.

    The returned node argues for capital preservation and low volatility,
    rebutting the risky and neutral analysts, and appends its argument to
    the shared risk-debate state.

    Args:
        llm: Chat model exposing ``invoke(prompt)`` returning a message with
            a ``.content`` attribute.

    Returns:
        ``safe_node(state) -> dict`` with an updated ``risk_debate_state``.
    """

    def safe_node(state) -> dict:
        risk_debate_state = state["risk_debate_state"]
        history = risk_debate_state.get("history", "")
        safe_history = risk_debate_state.get("safe_history", "")

        current_risky_response = risk_debate_state.get("current_risky_response", "")
        current_neutral_response = risk_debate_state.get("current_neutral_response", "")

        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        trader_decision = state["trader_investment_plan"]

        # FIX: corrected prompt typo "halluncinate" -> "hallucinate".
        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
{trader_decision}
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""

        response = llm.invoke(prompt)
        argument = f"Safe Analyst: {response.content}"

        # Record the turn: extend shared + safe histories, mark the speaker.
        new_risk_debate_state = {
            "history": history + "\n" + argument,
            "risky_history": risk_debate_state.get("risky_history", ""),
            "safe_history": safe_history + "\n" + argument,
            "neutral_history": risk_debate_state.get("neutral_history", ""),
            "latest_speaker": "Safe",
            "current_risky_response": risk_debate_state.get(
                "current_risky_response", ""
            ),
            "current_safe_response": argument,
            "current_neutral_response": risk_debate_state.get(
                "current_neutral_response", ""
            ),
            "count": risk_debate_state["count"] + 1,
        }

        return {"risk_debate_state": new_risk_debate_state}

    return safe_node

View File

@ -1,57 +1,57 @@
import time
import json
def create_neutral_debator(llm):
    """Build the balanced ("Neutral") analyst node for the risk debate.

    The returned node weighs both sides of the trader's plan, challenging
    the risky and safe analysts alike, and appends its argument to the
    shared risk-debate state.

    Args:
        llm: Chat model exposing ``invoke(prompt)`` returning a message with
            a ``.content`` attribute.

    Returns:
        ``neutral_node(state) -> dict`` with an updated ``risk_debate_state``.
    """

    def neutral_node(state) -> dict:
        risk_debate_state = state["risk_debate_state"]
        history = risk_debate_state.get("history", "")
        neutral_history = risk_debate_state.get("neutral_history", "")

        current_risky_response = risk_debate_state.get("current_risky_response", "")
        current_safe_response = risk_debate_state.get("current_safe_response", "")

        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        trader_decision = state["trader_investment_plan"]

        # FIX: corrected prompt typos "halluncinate" -> "hallucinate" and
        # "strategies.Here" -> "strategies. Here".
        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies. Here is the trader's decision:
{trader_decision}
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""

        response = llm.invoke(prompt)
        argument = f"Neutral Analyst: {response.content}"

        # Record the turn: extend shared + neutral histories, mark speaker.
        new_risk_debate_state = {
            "history": history + "\n" + argument,
            "risky_history": risk_debate_state.get("risky_history", ""),
            "safe_history": risk_debate_state.get("safe_history", ""),
            "neutral_history": neutral_history + "\n" + argument,
            "latest_speaker": "Neutral",
            "current_risky_response": risk_debate_state.get(
                "current_risky_response", ""
            ),
            "current_safe_response": risk_debate_state.get("current_safe_response", ""),
            "current_neutral_response": argument,
            "count": risk_debate_state["count"] + 1,
        }

        return {"risk_debate_state": new_risk_debate_state}

    return neutral_node
import time
import json
def create_neutral_debator(llm):
    """Build the neutral-analyst node for the risk debate.

    The returned node reads the running risk-debate state plus the four
    analyst reports, asks ``llm`` for a balanced rebuttal of both the risky
    and the safe analysts, and returns the updated debate state.

    Args:
        llm: Chat model exposing ``invoke(prompt)``, whose return value
            carries the generated text in ``.content``.

    Returns:
        A callable ``neutral_node(state) -> dict`` whose result holds the
        new ``risk_debate_state``.
    """

    def neutral_node(state) -> dict:
        risk_debate_state = state["risk_debate_state"]
        history = risk_debate_state.get("history", "")
        neutral_history = risk_debate_state.get("neutral_history", "")

        # Latest utterances from the opposing debaters; empty on the first turn.
        current_risky_response = risk_debate_state.get("current_risky_response", "")
        current_safe_response = risk_debate_state.get("current_safe_response", "")

        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        trader_decision = state["trader_investment_plan"]

        # NOTE: the original prompt misspelled "hallucinate" as "halluncinate";
        # fixed so the instruction reads correctly to the model.
        prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
{trader_decision}
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""

        response = llm.invoke(prompt)
        argument = f"Neutral Analyst: {response.content}"

        # Append this turn to the shared and per-speaker histories; the other
        # speakers' latest responses are carried forward unchanged.
        new_risk_debate_state = {
            "history": history + "\n" + argument,
            "risky_history": risk_debate_state.get("risky_history", ""),
            "safe_history": risk_debate_state.get("safe_history", ""),
            "neutral_history": neutral_history + "\n" + argument,
            "latest_speaker": "Neutral",
            "current_risky_response": risk_debate_state.get(
                "current_risky_response", ""
            ),
            "current_safe_response": risk_debate_state.get("current_safe_response", ""),
            "current_neutral_response": argument,
            "count": risk_debate_state["count"] + 1,
        }

        return {"risk_debate_state": new_risk_debate_state}

    return neutral_node

View File

@ -1,45 +1,45 @@
import functools
import time
import json
def create_trader(llm, memory):
    """Build the trader node that turns the analysts' plan into a decision.

    Args:
        llm: Chat model exposing ``invoke(messages)``, whose return value
            carries the generated text in ``.content``.
        memory: Episodic memory exposing
            ``get_memories(situation, n_matches)`` returning dicts with a
            ``"recommendation"`` key.

    Returns:
        A ``functools.partial`` wrapping ``trader_node`` with
        ``name="Trader"``, usable as a graph node.
    """

    def trader_node(state, name):
        company_name = state["company_of_interest"]
        investment_plan = state["investment_plan"]
        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        # Concatenate the four reports to query memory for lessons learned in
        # similar past situations.
        curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
        past_memories = memory.get_memories(curr_situation, n_matches=2)

        # Each recommendation is followed by a blank line (trailing separator
        # kept to preserve the prompt layout of the original implementation).
        past_memory_str = "".join(
            rec["recommendation"] + "\n\n" for rec in past_memories
        )

        context = {
            "role": "user",
            "content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
        }

        # NOTE: the original system prompt misspelled "situations" as
        # "situatiosn"; fixed so the instruction reads correctly.
        messages = [
            {
                "role": "system",
                "content": f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
            },
            context,
        ]

        result = llm.invoke(messages)

        return {
            "messages": [result],
            "trader_investment_plan": result.content,
            "sender": name,
        }

    return functools.partial(trader_node, name="Trader")
import functools
import time
import json
def create_trader(llm, memory):
    """Build the trader node that turns the analysts' plan into a decision.

    Args:
        llm: Chat model exposing ``invoke(messages)``, whose return value
            carries the generated text in ``.content``.
        memory: Episodic memory exposing
            ``get_memories(situation, n_matches)`` returning dicts with a
            ``"recommendation"`` key.

    Returns:
        A ``functools.partial`` wrapping ``trader_node`` with
        ``name="Trader"``, usable as a graph node.
    """

    def trader_node(state, name):
        company_name = state["company_of_interest"]
        investment_plan = state["investment_plan"]
        market_research_report = state["market_report"]
        sentiment_report = state["sentiment_report"]
        news_report = state["news_report"]
        fundamentals_report = state["fundamentals_report"]

        # Concatenate the four reports to query memory for lessons learned in
        # similar past situations.
        curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
        past_memories = memory.get_memories(curr_situation, n_matches=2)

        # Each recommendation is followed by a blank line (trailing separator
        # kept to preserve the prompt layout of the original implementation).
        past_memory_str = "".join(
            rec["recommendation"] + "\n\n" for rec in past_memories
        )

        context = {
            "role": "user",
            "content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
        }

        # NOTE: the original system prompt misspelled "situations" as
        # "situatiosn"; fixed so the instruction reads correctly.
        messages = [
            {
                "role": "system",
                "content": f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
            },
            context,
        ]

        result = llm.invoke(messages)

        return {
            "messages": [result],
            "trader_investment_plan": result.content,
            "sender": name,
        }

    return functools.partial(trader_node, name="Trader")

View File

@ -1,20 +1,20 @@
from .embedding_providers import (
EmbeddingProvider,
OpenAIEmbeddingProvider,
GeminiEmbeddingProvider,
OllamaEmbeddingProvider
)
from typing import Any
class EmbeddingProviderFactory:
    """Chooses an embedding backend from the configured endpoint URL."""

    @staticmethod
    def create_provider(config: dict[str, Any]) -> EmbeddingProvider:
        """Return the provider matching ``config["backend_url"]``.

        Gemini endpoints and a local Ollama endpoint are recognised by
        substring match; anything else is treated as OpenAI-compatible.
        """
        url = config["backend_url"]
        if "generativelanguage.googleapis.com" in url:
            return GeminiEmbeddingProvider(url)
        if "localhost:11434" in url:
            return OllamaEmbeddingProvider(url)
        return OpenAIEmbeddingProvider(url)
from .embedding_providers import (
EmbeddingProvider,
OpenAIEmbeddingProvider,
GeminiEmbeddingProvider,
OllamaEmbeddingProvider
)
from typing import Any
class EmbeddingProviderFactory:
    """Chooses an embedding backend from the configured endpoint URL."""

    @staticmethod
    def create_provider(config: dict[str, Any]) -> EmbeddingProvider:
        """Return the provider matching ``config["backend_url"]``.

        Gemini endpoints and a local Ollama endpoint are recognised by
        substring match; anything else is treated as OpenAI-compatible.
        """
        url = config["backend_url"]
        if "generativelanguage.googleapis.com" in url:
            return GeminiEmbeddingProvider(url)
        if "localhost:11434" in url:
            return OllamaEmbeddingProvider(url)
        return OpenAIEmbeddingProvider(url)

View File

@ -1,66 +1,66 @@
from abc import ABC, abstractmethod
from openai import OpenAI
from google import genai
class EmbeddingProvider(ABC):
    """Abstract interface for text-embedding backends."""

    @abstractmethod
    def get_embedding(self, text: str) -> list[float]:
        """Return the embedding vector for ``text``."""

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Identifier of the underlying embedding model."""
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embeddings via an OpenAI-compatible ``/embeddings`` endpoint."""

    def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
        # base_url lets the SDK talk to self-hosted OpenAI-compatible servers.
        self.client = OpenAI(base_url=backend_url)
        self._embedding_model = embedding_model

    def get_embedding(self, text: str) -> list[float]:
        """Embed ``text`` and return the raw vector."""
        result = self.client.embeddings.create(
            model=self._embedding_model,
            input=text,
        )
        return result.data[0].embedding

    @property
    def model_name(self) -> str:
        """Configured model identifier."""
        return self._embedding_model
class GeminiEmbeddingProvider(EmbeddingProvider):
    """Embeddings via Google's Gemini embedding API."""

    def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
        # genai.Client() picks up credentials from the environment; the
        # backend_url parameter exists for interface symmetry and is unused.
        self.client = genai.Client()
        self._embedding_model = embedding_model

    def get_embedding(self, text: str) -> list[float]:
        """Embed ``text`` and return the raw vector."""
        result = self.client.models.embed_content(
            model=self._embedding_model,
            contents=text,
        )
        return result.embeddings[0].values

    @property
    def model_name(self) -> str:
        """Configured model identifier."""
        return self._embedding_model
class OllamaEmbeddingProvider(EmbeddingProvider):
    """Embeddings via a local Ollama server's OpenAI-compatible endpoint."""

    def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
        # Ollama exposes an OpenAI-compatible API, so the OpenAI SDK is reused.
        self.client = OpenAI(base_url=backend_url)
        self._embedding_model = embedding_model

    def get_embedding(self, text: str) -> list[float]:
        """Embed ``text`` and return the raw vector."""
        result = self.client.embeddings.create(
            model=self._embedding_model,
            input=text,
        )
        return result.data[0].embedding

    @property
    def model_name(self) -> str:
        """Configured model identifier."""
        return self._embedding_model
from abc import ABC, abstractmethod
from openai import OpenAI
from google import genai
class EmbeddingProvider(ABC):
    """Abstract interface for text-embedding backends."""

    @abstractmethod
    def get_embedding(self, text: str) -> list[float]:
        """Return the embedding vector for ``text``."""

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Identifier of the underlying embedding model."""
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embeddings via an OpenAI-compatible ``/embeddings`` endpoint."""

    def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
        # base_url lets the SDK talk to self-hosted OpenAI-compatible servers.
        self.client = OpenAI(base_url=backend_url)
        self._embedding_model = embedding_model

    def get_embedding(self, text: str) -> list[float]:
        """Embed ``text`` and return the raw vector."""
        result = self.client.embeddings.create(
            model=self._embedding_model,
            input=text,
        )
        return result.data[0].embedding

    @property
    def model_name(self) -> str:
        """Configured model identifier."""
        return self._embedding_model
class GeminiEmbeddingProvider(EmbeddingProvider):
    """Embeddings via Google's Gemini embedding API."""

    def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
        # genai.Client() picks up credentials from the environment; the
        # backend_url parameter exists for interface symmetry and is unused.
        self.client = genai.Client()
        self._embedding_model = embedding_model

    def get_embedding(self, text: str) -> list[float]:
        """Embed ``text`` and return the raw vector."""
        result = self.client.models.embed_content(
            model=self._embedding_model,
            contents=text,
        )
        return result.embeddings[0].values

    @property
    def model_name(self) -> str:
        """Configured model identifier."""
        return self._embedding_model
class OllamaEmbeddingProvider(EmbeddingProvider):
    """Embeddings via a local Ollama server's OpenAI-compatible endpoint."""

    def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
        # Ollama exposes an OpenAI-compatible API, so the OpenAI SDK is reused.
        self.client = OpenAI(base_url=backend_url)
        self._embedding_model = embedding_model

    def get_embedding(self, text: str) -> list[float]:
        """Embed ``text`` and return the raw vector."""
        result = self.client.embeddings.create(
            model=self._embedding_model,
            input=text,
        )
        return result.data[0].embedding

    @property
    def model_name(self) -> str:
        """Configured model identifier."""
        return self._embedding_model

View File

@ -1,112 +1,112 @@
import chromadb
from chromadb.config import Settings
from openai import OpenAI
import os
from .embedding_provider_factory import EmbeddingProviderFactory
from google import genai
class FinancialSituationMemory:
    """In-memory vector store of past financial situations and the advice given."""

    def __init__(self, name, config):
        self.config = config
        self.backend_url = config["backend_url"]
        self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
        self.chroma_client = chromadb.Client(Settings(allow_reset=True))
        self.situation_collection = self.chroma_client.create_collection(name=name)

    def get_embedding(self, text):
        """Embed ``text`` with the configured provider."""
        return self.embedding_provider.get_embedding(text)

    def add_situations(self, situations_and_advice):
        """Store ``(situation, recommendation)`` pairs in the collection."""
        pairs = list(situations_and_advice)
        situations = [situation for situation, _ in pairs]
        recommendations = [rec for _, rec in pairs]

        # Continue id numbering after whatever is already stored.
        offset = self.situation_collection.count()

        self.situation_collection.add(
            documents=situations,
            metadatas=[{"recommendation": rec} for rec in recommendations],
            embeddings=[self.get_embedding(s) for s in situations],
            ids=[str(offset + i) for i in range(len(situations))],
        )

    def get_memories(self, current_situation, n_matches=1):
        """Return up to ``n_matches`` stored situations similar to the query."""
        query_embedding = self.get_embedding(current_situation)

        results = self.situation_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_matches,
            include=["metadatas", "documents", "distances"],
        )

        docs = results["documents"][0]
        metas = results["metadatas"][0]
        dists = results["distances"][0]

        # Chroma returns distances; report similarity as 1 - distance.
        return [
            {
                "matched_situation": doc,
                "recommendation": meta["recommendation"],
                "similarity_score": 1 - dist,
            }
            for doc, meta, dist in zip(docs, metas, dists)
        ]
if __name__ == "__main__":
    # Example usage / smoke test.
    # BUGFIX: the original called FinancialSituationMemory() with no
    # arguments, which raises TypeError because __init__ requires a
    # collection name and a config carrying the embedding backend URL.
    example_config = {"backend_url": "https://api.openai.com/v1"}
    matcher = FinancialSituationMemory("example_memory", example_config)

    # Example data
    example_data = [
        (
            "High inflation rate with rising interest rates and declining consumer spending",
            "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
        ),
        (
            "Tech sector showing high volatility with increasing institutional selling pressure",
            "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
        ),
        (
            "Strong dollar affecting emerging markets with increasing forex volatility",
            "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
        ),
        (
            "Market showing signs of sector rotation with rising yields",
            "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
        ),
    ]

    # Add the example situations and recommendations
    matcher.add_situations(example_data)

    # Example query
    current_situation = """
    Market showing increased volatility in tech sector, with institutional investors
    reducing positions and rising interest rates affecting growth stock valuations
    """

    try:
        recommendations = matcher.get_memories(current_situation, n_matches=2)
        for i, rec in enumerate(recommendations, 1):
            print(f"\nMatch {i}:")
            print(f"Similarity Score: {rec['similarity_score']:.2f}")
            print(f"Matched Situation: {rec['matched_situation']}")
            print(f"Recommendation: {rec['recommendation']}")
    except Exception as e:
        # Best-effort demo: report the failure instead of crashing.
        print(f"Error during recommendation: {str(e)}")
import chromadb
from chromadb.config import Settings
from openai import OpenAI
import os
from .embedding_provider_factory import EmbeddingProviderFactory
from google import genai
class FinancialSituationMemory:
    """In-memory vector store of past financial situations and the advice given."""

    def __init__(self, name, config):
        self.config = config
        self.backend_url = config["backend_url"]
        self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
        self.chroma_client = chromadb.Client(Settings(allow_reset=True))
        self.situation_collection = self.chroma_client.create_collection(name=name)

    def get_embedding(self, text):
        """Embed ``text`` with the configured provider."""
        return self.embedding_provider.get_embedding(text)

    def add_situations(self, situations_and_advice):
        """Store ``(situation, recommendation)`` pairs in the collection."""
        pairs = list(situations_and_advice)
        situations = [situation for situation, _ in pairs]
        recommendations = [rec for _, rec in pairs]

        # Continue id numbering after whatever is already stored.
        offset = self.situation_collection.count()

        self.situation_collection.add(
            documents=situations,
            metadatas=[{"recommendation": rec} for rec in recommendations],
            embeddings=[self.get_embedding(s) for s in situations],
            ids=[str(offset + i) for i in range(len(situations))],
        )

    def get_memories(self, current_situation, n_matches=1):
        """Return up to ``n_matches`` stored situations similar to the query."""
        query_embedding = self.get_embedding(current_situation)

        results = self.situation_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_matches,
            include=["metadatas", "documents", "distances"],
        )

        docs = results["documents"][0]
        metas = results["metadatas"][0]
        dists = results["distances"][0]

        # Chroma returns distances; report similarity as 1 - distance.
        return [
            {
                "matched_situation": doc,
                "recommendation": meta["recommendation"],
                "similarity_score": 1 - dist,
            }
            for doc, meta, dist in zip(docs, metas, dists)
        ]
if __name__ == "__main__":
    # Example usage / smoke test.
    # BUGFIX: the original called FinancialSituationMemory() with no
    # arguments, which raises TypeError because __init__ requires a
    # collection name and a config carrying the embedding backend URL.
    example_config = {"backend_url": "https://api.openai.com/v1"}
    matcher = FinancialSituationMemory("example_memory", example_config)

    # Example data
    example_data = [
        (
            "High inflation rate with rising interest rates and declining consumer spending",
            "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
        ),
        (
            "Tech sector showing high volatility with increasing institutional selling pressure",
            "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
        ),
        (
            "Strong dollar affecting emerging markets with increasing forex volatility",
            "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
        ),
        (
            "Market showing signs of sector rotation with rising yields",
            "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
        ),
    ]

    # Add the example situations and recommendations
    matcher.add_situations(example_data)

    # Example query
    current_situation = """
    Market showing increased volatility in tech sector, with institutional investors
    reducing positions and rising interest rates affecting growth stock valuations
    """

    try:
        recommendations = matcher.get_memories(current_situation, n_matches=2)
        for i, rec in enumerate(recommendations, 1):
            print(f"\nMatch {i}:")
            print(f"Similarity Score: {rec['similarity_score']:.2f}")
            print(f"Matched Situation: {rec['matched_situation']}")
            print(f"Recommendation: {rec['recommendation']}")
    except Exception as e:
        # Best-effort demo: report the failure instead of crashing.
        print(f"Error during recommendation: {str(e)}")

View File

@ -1,76 +1,76 @@
from google import genai
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
from openai import OpenAI
from abc import ABC, abstractmethod
class SearchProvider(ABC):
    """Abstract interface for LLM-backed web search."""

    @abstractmethod
    def search(self, query: str) -> str:
        """Run ``query`` through the backend and return the result text.

        BUGFIX: the abstract signature previously declared ``ticker`` and
        ``curr_date`` parameters that no concrete provider accepted, so any
        caller relying on the interface would crash; the signature now
        matches the implementations (query only).
        """
class GoogleSearchProvider(SearchProvider):
    """Web search via Gemini's Google Search grounding tool."""

    def __init__(self, model: str):
        # genai.Client() picks up credentials from the environment.
        self.client = genai.Client()
        self.model = model

    def search(self, query: str) -> str:
        """Ask the model to answer ``query`` with Google Search grounding."""
        google_search_tool = Tool(
            google_search=GoogleSearch()
        )

        response = self.client.models.generate_content(
            model=self.model,
            contents=query,
            config=GenerateContentConfig(
                tools=[google_search_tool],
                response_modalities=["TEXT"],
            ),
        )

        # BUGFIX: concatenate only truthy text parts.  Grounded responses can
        # contain parts whose ``text`` attribute exists but is None, which
        # previously made ``result_text += part.text`` raise TypeError.
        result_text = ""
        for part in response.candidates[0].content.parts:
            text = getattr(part, "text", None)
            if text:
                result_text += text
        return result_text
class OpenAISearchProvider(SearchProvider):
    """Web search via the OpenAI Responses API's web_search_preview tool."""

    def __init__(self, model: str, backend_url: str):
        # base_url allows pointing at any OpenAI-compatible endpoint.
        self.client = OpenAI(base_url=backend_url)
        self.model = model

    def search(self, query: str) -> str:
        """Issue ``query`` through the Responses API with web search enabled.

        NOTE(review): this copy of the class never returns the response text
        (a duplicate elsewhere in the file ends with
        ``return response.output[1].content[0].text``); as written, the
        request is made and None is implicitly returned — confirm intent.
        """
        response = self.client.responses.create(
            model=self.model,
            input=[
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "input_text",
                            "text": query
                        }
                    ],
                }
            ],
            text={"format": {"type": "text"}},
            reasoning={},
            tools=[
                {
                    # Enables the hosted web-search tool for this request.
                    "type": "web_search_preview",
                    "user_location": {"type": "approximate"},
                    "search_context_size": "low",
                }
            ],
            temperature=1,
            max_output_tokens=4096,
            top_p=1,
            store=True,
        )
from google import genai
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
from openai import OpenAI
from abc import ABC, abstractmethod
class SearchProvider(ABC):
    """Abstract interface for LLM-backed web search."""

    @abstractmethod
    def search(self, query: str) -> str:
        """Run ``query`` through the backend and return the result text.

        BUGFIX: the abstract signature previously declared ``ticker`` and
        ``curr_date`` parameters that no concrete provider accepted, so any
        caller relying on the interface would crash; the signature now
        matches the implementations (query only).
        """
class GoogleSearchProvider(SearchProvider):
    """Web search via Gemini's Google Search grounding tool."""

    def __init__(self, model: str):
        # genai.Client() picks up credentials from the environment.
        self.client = genai.Client()
        self.model = model

    def search(self, query: str) -> str:
        """Ask the model to answer ``query`` with Google Search grounding."""
        google_search_tool = Tool(
            google_search=GoogleSearch()
        )

        response = self.client.models.generate_content(
            model=self.model,
            contents=query,
            config=GenerateContentConfig(
                tools=[google_search_tool],
                response_modalities=["TEXT"],
            ),
        )

        # BUGFIX: concatenate only truthy text parts.  Grounded responses can
        # contain parts whose ``text`` attribute exists but is None, which
        # previously made ``result_text += part.text`` raise TypeError.
        result_text = ""
        for part in response.candidates[0].content.parts:
            text = getattr(part, "text", None)
            if text:
                result_text += text
        return result_text
class OpenAISearchProvider(SearchProvider):
    """Web search via the OpenAI Responses API's web_search_preview tool."""

    def __init__(self, model: str, backend_url: str):
        # base_url allows pointing at any OpenAI-compatible endpoint.
        self.client = OpenAI(base_url=backend_url)
        self.model = model

    def search(self, query: str) -> str:
        """Issue ``query`` through the Responses API with web search enabled."""
        request_input = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "input_text",
                        "text": query
                    }
                ],
            }
        ]
        web_search_tool = {
            "type": "web_search_preview",
            "user_location": {"type": "approximate"},
            "search_context_size": "low",
        }

        response = self.client.responses.create(
            model=self.model,
            input=request_input,
            text={"format": {"type": "text"}},
            reasoning={},
            tools=[web_search_tool],
            temperature=1,
            max_output_tokens=4096,
            top_p=1,
            store=True,
        )

        # NOTE(review): assumes output[1] is the assistant message (output[0]
        # presumably being the web-search call record) — confirm against the
        # Responses API output structure.
        return response.output[1].content[0].text

View File

@ -1,47 +1,47 @@
from .search_provider import (
SearchProvider,
GoogleSearchProvider,
OpenAISearchProvider
)
import hashlib
import json
class SearchProviderFactory:
    """Creates and caches SearchProvider instances keyed by configuration."""

    # Class-level cache: (backend_url, model) -> provider instance.
    _cache = {}

    @staticmethod
    def create_provider(config: dict) -> SearchProvider:
        """Create (or reuse) a SearchProvider for the given config.

        FIX: the previous implementation MD5-hashed a JSON dump of the
        relevant config values to form the cache key, and annotated the
        parameter as ``dict[str, any]`` (the builtin ``any`` function, not
        ``typing.Any``).  A plain tuple key is equivalent, cheaper, and
        collision-free, and the annotation is corrected.
        """
        backend_url = config["backend_url"]
        model = config["quick_think_llm"]
        cache_key = (backend_url, model)

        # Return cached instance if one exists for this config.
        cached = SearchProviderFactory._cache.get(cache_key)
        if cached is not None:
            return cached

        # Gemini endpoints get the Google provider; everything else is
        # treated as OpenAI-compatible.
        if "generativelanguage.googleapis.com" in backend_url:
            provider = GoogleSearchProvider(model)
        else:
            provider = OpenAISearchProvider(model, backend_url)

        SearchProviderFactory._cache[cache_key] = provider
        return provider

    @staticmethod
    def clear_cache():
        """Clear the provider cache (useful for testing or config changes)."""
        SearchProviderFactory._cache.clear()
from .search_provider import (
SearchProvider,
GoogleSearchProvider,
OpenAISearchProvider
)
import hashlib
import json
class SearchProviderFactory:
    """Creates and caches SearchProvider instances keyed by configuration."""

    # Class-level cache: (backend_url, model) -> provider instance.
    _cache = {}

    @staticmethod
    def create_provider(config: dict) -> SearchProvider:
        """Create (or reuse) a SearchProvider for the given config.

        FIX: the previous implementation MD5-hashed a JSON dump of the
        relevant config values to form the cache key, and annotated the
        parameter as ``dict[str, any]`` (the builtin ``any`` function, not
        ``typing.Any``).  A plain tuple key is equivalent, cheaper, and
        collision-free, and the annotation is corrected.
        """
        backend_url = config["backend_url"]
        model = config["quick_think_llm"]
        cache_key = (backend_url, model)

        # Return cached instance if one exists for this config.
        cached = SearchProviderFactory._cache.get(cache_key)
        if cached is not None:
            return cached

        # Gemini endpoints get the Google provider; everything else is
        # treated as OpenAI-compatible.
        if "generativelanguage.googleapis.com" in backend_url:
            provider = GoogleSearchProvider(model)
        else:
            provider = OpenAISearchProvider(model, backend_url)

        SearchProviderFactory._cache[cache_key] = provider
        return provider

    @staticmethod
    def clear_cache():
        """Clear the provider cache (useful for testing or config changes)."""
        SearchProviderFactory._cache.clear()

View File

@ -1,31 +1,31 @@
# TradingAgents/graph/signal_processing.py
from langchain_openai import ChatOpenAI
class SignalProcessor:
    """Processes trading signals to extract actionable decisions."""

    def __init__(self, quick_thinking_llm: ChatOpenAI):
        """Store the fast LLM used to distil full signals into a decision."""
        self.quick_thinking_llm = quick_thinking_llm

    def process_signal(self, full_signal: str) -> str:
        """Extract the core decision from a full trading-signal text.

        Args:
            full_signal: Complete trading signal text.

        Returns:
            The extracted decision (BUY, SELL, or HOLD).
        """
        system_prompt = "**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information."
        conversation = [
            ("system", system_prompt),
            ("human", full_signal),
        ]
        response = self.quick_thinking_llm.invoke(conversation)
        return response.content
# TradingAgents/graph/signal_processing.py
from langchain_openai import ChatOpenAI
class SignalProcessor:
    """Processes trading signals to extract actionable decisions."""

    def __init__(self, quick_thinking_llm: ChatOpenAI):
        """Store the fast LLM used to distil full signals into a decision."""
        self.quick_thinking_llm = quick_thinking_llm

    def process_signal(self, full_signal: str) -> str:
        """Extract the core decision from a full trading-signal text.

        Args:
            full_signal: Complete trading signal text.

        Returns:
            The extracted decision (BUY, SELL, or HOLD).
        """
        system_prompt = "**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information."
        conversation = [
            ("system", system_prompt),
            ("human", full_signal),
        ]
        response = self.quick_thinking_llm.invoke(conversation)
        return response.content

10810
uv.lock

File diff suppressed because it is too large Load Diff