dev/web
This commit is contained in:
parent
fbd96e9c18
commit
ab1b0120c2
|
|
@ -1,10 +1,11 @@
|
||||||
env/
|
env/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
.DS_Store
|
.DS_Store
|
||||||
*.csv
|
*.csv
|
||||||
src/
|
src/
|
||||||
eval_results/
|
eval_results/
|
||||||
eval_data/
|
eval_data/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
results/
|
results/
|
||||||
.env
|
.env
|
||||||
|
tradingagents/dataflows/data_cache/
|
||||||
|
|
@ -1 +1 @@
|
||||||
3.10
|
3.10
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,15 @@
|
||||||
{
|
{
|
||||||
// Use IntelliSense to learn about possible attributes.
|
// Use IntelliSense to learn about possible attributes.
|
||||||
// Hover to view descriptions of existing attributes.
|
// Hover to view descriptions of existing attributes.
|
||||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
{
|
||||||
"name": "Python Debugger: main.py",
|
"name": "Python Debugger: main.py",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "${workspaceFolder}/main.py",
|
"program": "${workspaceFolder}/main.py",
|
||||||
"console": "integratedTerminal"
|
"console": "integratedTerminal"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
430
README.md
430
README.md
|
|
@ -1,215 +1,215 @@
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/TauricResearch.png" style="width: 60%; height: auto;">
|
<img src="assets/TauricResearch.png" style="width: 60%; height: auto;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<div align="center" style="line-height: 1;">
|
<div align="center" style="line-height: 1;">
|
||||||
<a href="https://arxiv.org/abs/2412.20138" target="_blank"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2412.20138-B31B1B?logo=arxiv"/></a>
|
<a href="https://arxiv.org/abs/2412.20138" target="_blank"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2412.20138-B31B1B?logo=arxiv"/></a>
|
||||||
<a href="https://discord.com/invite/hk9PGKShPK" target="_blank"><img alt="Discord" src="https://img.shields.io/badge/Discord-TradingResearch-7289da?logo=discord&logoColor=white&color=7289da"/></a>
|
<a href="https://discord.com/invite/hk9PGKShPK" target="_blank"><img alt="Discord" src="https://img.shields.io/badge/Discord-TradingResearch-7289da?logo=discord&logoColor=white&color=7289da"/></a>
|
||||||
<a href="./assets/wechat.png" target="_blank"><img alt="WeChat" src="https://img.shields.io/badge/WeChat-TauricResearch-brightgreen?logo=wechat&logoColor=white"/></a>
|
<a href="./assets/wechat.png" target="_blank"><img alt="WeChat" src="https://img.shields.io/badge/WeChat-TauricResearch-brightgreen?logo=wechat&logoColor=white"/></a>
|
||||||
<a href="https://x.com/TauricResearch" target="_blank"><img alt="X Follow" src="https://img.shields.io/badge/X-TauricResearch-white?logo=x&logoColor=white"/></a>
|
<a href="https://x.com/TauricResearch" target="_blank"><img alt="X Follow" src="https://img.shields.io/badge/X-TauricResearch-white?logo=x&logoColor=white"/></a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://github.com/TauricResearch/" target="_blank"><img alt="Community" src="https://img.shields.io/badge/Join_GitHub_Community-TauricResearch-14C290?logo=discourse"/></a>
|
<a href="https://github.com/TauricResearch/" target="_blank"><img alt="Community" src="https://img.shields.io/badge/Join_GitHub_Community-TauricResearch-14C290?logo=discourse"/></a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<!-- Keep these links. Translations will automatically update with the README. -->
|
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=de">Deutsch</a> |
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=de">Deutsch</a> |
|
||||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=es">Español</a> |
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=es">Español</a> |
|
||||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=fr">français</a> |
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=fr">français</a> |
|
||||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ja">日本語</a> |
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ja">日本語</a> |
|
||||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ko">한국어</a> |
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ko">한국어</a> |
|
||||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=pt">Português</a> |
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=pt">Português</a> |
|
||||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ru">Русский</a> |
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ru">Русский</a> |
|
||||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=zh">中文</a>
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=zh">中文</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
# TradingAgents: Multi-Agents LLM Financial Trading Framework
|
# TradingAgents: Multi-Agents LLM Financial Trading Framework
|
||||||
|
|
||||||
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
|
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
|
||||||
>
|
>
|
||||||
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
|
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<a href="https://www.star-history.com/#TauricResearch/TradingAgents&Date">
|
<a href="https://www.star-history.com/#TauricResearch/TradingAgents&Date">
|
||||||
<picture>
|
<picture>
|
||||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date&theme=dark" />
|
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date&theme=dark" />
|
||||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" />
|
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" />
|
||||||
<img alt="TradingAgents Star History" src="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" style="width: 80%; height: auto;" />
|
<img alt="TradingAgents Star History" src="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" style="width: 80%; height: auto;" />
|
||||||
</picture>
|
</picture>
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
|
🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
## TradingAgents Framework
|
## TradingAgents Framework
|
||||||
|
|
||||||
TradingAgents is a multi-agent trading framework that mirrors the dynamics of real-world trading firms. By deploying specialized LLM-powered agents: from fundamental analysts, sentiment experts, and technical analysts, to trader, risk management team, the platform collaboratively evaluates market conditions and informs trading decisions. Moreover, these agents engage in dynamic discussions to pinpoint the optimal strategy.
|
TradingAgents is a multi-agent trading framework that mirrors the dynamics of real-world trading firms. By deploying specialized LLM-powered agents: from fundamental analysts, sentiment experts, and technical analysts, to trader, risk management team, the platform collaboratively evaluates market conditions and informs trading decisions. Moreover, these agents engage in dynamic discussions to pinpoint the optimal strategy.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/schema.png" style="width: 100%; height: auto;">
|
<img src="assets/schema.png" style="width: 100%; height: auto;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
> TradingAgents framework is designed for research purposes. Trading performance may vary based on many factors, including the chosen backbone language models, model temperature, trading periods, the quality of data, and other non-deterministic factors. [It is not intended as financial, investment, or trading advice.](https://tauric.ai/disclaimer/)
|
> TradingAgents framework is designed for research purposes. Trading performance may vary based on many factors, including the chosen backbone language models, model temperature, trading periods, the quality of data, and other non-deterministic factors. [It is not intended as financial, investment, or trading advice.](https://tauric.ai/disclaimer/)
|
||||||
|
|
||||||
Our framework decomposes complex trading tasks into specialized roles. This ensures the system achieves a robust, scalable approach to market analysis and decision-making.
|
Our framework decomposes complex trading tasks into specialized roles. This ensures the system achieves a robust, scalable approach to market analysis and decision-making.
|
||||||
|
|
||||||
### Analyst Team
|
### Analyst Team
|
||||||
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
|
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
|
||||||
- Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood.
|
- Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood.
|
||||||
- News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions.
|
- News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions.
|
||||||
- Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements.
|
- Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/analyst.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/analyst.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
### Researcher Team
|
### Researcher Team
|
||||||
- Comprises both bullish and bearish researchers who critically assess the insights provided by the Analyst Team. Through structured debates, they balance potential gains against inherent risks.
|
- Comprises both bullish and bearish researchers who critically assess the insights provided by the Analyst Team. Through structured debates, they balance potential gains against inherent risks.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/researcher.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/researcher.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
### Trader Agent
|
### Trader Agent
|
||||||
- Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights.
|
- Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
### Risk Management and Portfolio Manager
|
### Risk Management and Portfolio Manager
|
||||||
- Continuously evaluates portfolio risk by assessing market volatility, liquidity, and other risk factors. The risk management team evaluates and adjusts trading strategies, providing assessment reports to the Portfolio Manager for final decision.
|
- Continuously evaluates portfolio risk by assessing market volatility, liquidity, and other risk factors. The risk management team evaluates and adjusts trading strategies, providing assessment reports to the Portfolio Manager for final decision.
|
||||||
- The Portfolio Manager approves/rejects the transaction proposal. If approved, the order will be sent to the simulated exchange and executed.
|
- The Portfolio Manager approves/rejects the transaction proposal. If approved, the order will be sent to the simulated exchange and executed.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/risk.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/risk.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
## Installation and CLI
|
## Installation and CLI
|
||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
Clone TradingAgents:
|
Clone TradingAgents:
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/TauricResearch/TradingAgents.git
|
git clone https://github.com/TauricResearch/TradingAgents.git
|
||||||
cd TradingAgents
|
cd TradingAgents
|
||||||
```
|
```
|
||||||
|
|
||||||
Create a virtual environment in any of your favorite environment managers:
|
Create a virtual environment in any of your favorite environment managers:
|
||||||
```bash
|
```bash
|
||||||
conda create -n tradingagents python=3.13
|
conda create -n tradingagents python=3.13
|
||||||
conda activate tradingagents
|
conda activate tradingagents
|
||||||
```
|
```
|
||||||
|
|
||||||
Install dependencies:
|
Install dependencies:
|
||||||
```bash
|
```bash
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
### Required APIs
|
### Required APIs
|
||||||
|
|
||||||
You will also need the FinnHub API for financial data. All of our code is implemented with the free tier.
|
You will also need the FinnHub API for financial data. All of our code is implemented with the free tier.
|
||||||
```bash
|
```bash
|
||||||
export FINNHUB_API_KEY=$YOUR_FINNHUB_API_KEY
|
export FINNHUB_API_KEY=$YOUR_FINNHUB_API_KEY
|
||||||
```
|
```
|
||||||
|
|
||||||
You will need the OpenAI API or GEMINI API for all the agents.
|
You will need the OpenAI API or GEMINI API for all the agents.
|
||||||
```bash
|
```bash
|
||||||
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
|
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
|
||||||
export GEMINI_API_KEY=$YOUR_GEMINI_API_KEY
|
export GEMINI_API_KEY=$YOUR_GEMINI_API_KEY
|
||||||
export GOOGLE_API_KEY=$YOUR_GEMINI_API_KEY
|
export GOOGLE_API_KEY=$YOUR_GEMINI_API_KEY
|
||||||
```
|
```
|
||||||
|
|
||||||
### CLI Usage
|
### CLI Usage
|
||||||
|
|
||||||
You can also try out the CLI directly by running:
|
You can also try out the CLI directly by running:
|
||||||
```bash
|
```bash
|
||||||
python -m cli.main
|
python -m cli.main
|
||||||
```
|
```
|
||||||
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
|
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
An interface will appear showing results as they load, letting you track the agent's progress as it runs.
|
An interface will appear showing results as they load, letting you track the agent's progress as it runs.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/cli/cli_news.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/cli/cli_news.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/cli/cli_transaction.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/cli/cli_transaction.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
## TradingAgents Package
|
## TradingAgents Package
|
||||||
|
|
||||||
### Implementation Details
|
### Implementation Details
|
||||||
|
|
||||||
We built TradingAgents with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
|
We built TradingAgents with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
|
||||||
|
|
||||||
### Python Usage
|
### Python Usage
|
||||||
|
|
||||||
To use TradingAgents inside your code, you can import the `tradingagents` module and initialize a `TradingAgentsGraph()` object. The `.propagate()` function will return a decision. You can run `main.py`, here's also a quick example:
|
To use TradingAgents inside your code, you can import the `tradingagents` module and initialize a `TradingAgentsGraph()` object. The `.propagate()` function will return a decision. You can run `main.py`, here's also a quick example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy())
|
ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy())
|
||||||
|
|
||||||
# forward propagate
|
# forward propagate
|
||||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||||
print(decision)
|
print(decision)
|
||||||
```
|
```
|
||||||
|
|
||||||
You can also adjust the default configuration to set your own choice of LLMs, debate rounds, etc.
|
You can also adjust the default configuration to set your own choice of LLMs, debate rounds, etc.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
# Create a custom config
|
# Create a custom config
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
|
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
|
||||||
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
|
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||||
config["online_tools"] = True # Use online tools or cached data
|
config["online_tools"] = True # Use online tools or cached data
|
||||||
|
|
||||||
# Initialize with custom config
|
# Initialize with custom config
|
||||||
ta = TradingAgentsGraph(debug=True, config=config)
|
ta = TradingAgentsGraph(debug=True, config=config)
|
||||||
|
|
||||||
# forward propagate
|
# forward propagate
|
||||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||||
print(decision)
|
print(decision)
|
||||||
```
|
```
|
||||||
|
|
||||||
> For `online_tools`, we recommend enabling them for experimentation, as they provide access to real-time data. The agents' offline tools rely on cached data from our **Tauric TradingDB**, a curated dataset we use for backtesting. We're currently in the process of refining this dataset, and we plan to release it soon alongside our upcoming projects. Stay tuned!
|
> For `online_tools`, we recommend enabling them for experimentation, as they provide access to real-time data. The agents' offline tools rely on cached data from our **Tauric TradingDB**, a curated dataset we use for backtesting. We're currently in the process of refining this dataset, and we plan to release it soon alongside our upcoming projects. Stay tuned!
|
||||||
|
|
||||||
You can view the full list of configurations in `tradingagents/default_config.py`.
|
You can view the full list of configurations in `tradingagents/default_config.py`.
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
Please reference our work if you find *TradingAgents* provides you with some help :)
|
Please reference our work if you find *TradingAgents* provides you with some help :)
|
||||||
|
|
||||||
```
|
```
|
||||||
@misc{xiao2025tradingagentsmultiagentsllmfinancial,
|
@misc{xiao2025tradingagentsmultiagentsllmfinancial,
|
||||||
title={TradingAgents: Multi-Agents LLM Financial Trading Framework},
|
title={TradingAgents: Multi-Agents LLM Financial Trading Framework},
|
||||||
author={Yijia Xiao and Edward Sun and Di Luo and Wei Wang},
|
author={Yijia Xiao and Edward Sun and Di Luo and Wei Wang},
|
||||||
year={2025},
|
year={2025},
|
||||||
eprint={2412.20138},
|
eprint={2412.20138},
|
||||||
archivePrefix={arXiv},
|
archivePrefix={arXiv},
|
||||||
primaryClass={q-fin.TR},
|
primaryClass={q-fin.TR},
|
||||||
url={https://arxiv.org/abs/2412.20138},
|
url={https://arxiv.org/abs/2412.20138},
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
|
||||||
134
app/api/deps.py
134
app/api/deps.py
|
|
@ -1,67 +1,67 @@
|
||||||
from typing import Generator, Optional
|
from typing import Generator, Optional
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from jose import jwt, JWTError
|
from jose import jwt, JWTError
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.infrastructure.database import get_db
|
from app.infrastructure.database import get_db
|
||||||
from app.domain.models import User
|
from app.domain.models import User
|
||||||
from app.infrastructure.repositories.user import UserRepository
|
from app.infrastructure.repositories.user import UserRepository
|
||||||
from app.core.services.trading_analysis import TradingAnalysisService
|
from app.core.services.trading_analysis import TradingAnalysisService
|
||||||
|
|
||||||
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
|
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
|
||||||
|
|
||||||
class TokenData(BaseModel):
|
class TokenData(BaseModel):
|
||||||
username: Optional[str] = None
|
username: Optional[str] = None
|
||||||
|
|
||||||
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
|
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
|
||||||
return UserRepository(db)
|
return UserRepository(db)
|
||||||
|
|
||||||
def get_user_from_token(token: str, db: Session) -> Optional[User]:
|
def get_user_from_token(token: str, db: Session) -> Optional[User]:
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
)
|
)
|
||||||
token_data = TokenData(username=payload.get("sub"))
|
token_data = TokenData(username=payload.get("sub"))
|
||||||
except JWTError:
|
except JWTError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
user_repo = UserRepository(db)
|
user_repo = UserRepository(db)
|
||||||
user = user_repo.get_by_email(email=token_data.username)
|
user = user_repo.get_by_email(email=token_data.username)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
def get_current_user(
|
def get_current_user(
|
||||||
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
|
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
|
||||||
) -> User:
|
) -> User:
|
||||||
user = get_user_from_token(token=token, db=db)
|
user = get_user_from_token(token=token, db=db)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Could not validate credentials",
|
detail="Could not validate credentials",
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
def get_current_active_user(
|
def get_current_active_user(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> User:
|
) -> User:
|
||||||
if not current_user.is_active:
|
if not current_user.is_active:
|
||||||
raise HTTPException(status_code=400, detail="Inactive user")
|
raise HTTPException(status_code=400, detail="Inactive user")
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
def get_current_active_superuser(
|
def get_current_active_superuser(
|
||||||
current_user: User = Depends(get_current_active_user),
|
current_user: User = Depends(get_current_active_user),
|
||||||
) -> User:
|
) -> User:
|
||||||
if not current_user.is_superuser:
|
if not current_user.is_superuser:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403, detail="The user doesn't have enough privileges"
|
status_code=403, detail="The user doesn't have enough privileges"
|
||||||
)
|
)
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
def get_analysis_service(
|
def get_analysis_service(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
user: User = Depends(get_current_active_user)
|
user: User = Depends(get_current_active_user)
|
||||||
) -> TradingAnalysisService:
|
) -> TradingAnalysisService:
|
||||||
return TradingAnalysisService(user=user, db=db)
|
return TradingAnalysisService(user=user, db=db)
|
||||||
|
|
|
||||||
|
|
@ -1,94 +1,94 @@
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
|
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
|
||||||
from app.api import deps
|
from app.api import deps
|
||||||
from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate
|
from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate
|
||||||
from app.domain.models import User as UserModel
|
from app.domain.models import User as UserModel
|
||||||
from app.core.services.trading_analysis import TradingAnalysisService
|
from app.core.services.trading_analysis import TradingAnalysisService
|
||||||
from app.core.websocket_manager import WebSocketManager
|
from app.core.websocket_manager import WebSocketManager
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS
|
from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
manager = WebSocketManager()
|
manager = WebSocketManager()
|
||||||
|
|
||||||
@router.post("/start", response_model=AnalysisSession)
|
@router.post("/start", response_model=AnalysisSession)
|
||||||
def start_analysis(
|
def start_analysis(
|
||||||
*,
|
*,
|
||||||
analysis_in: AnalysisSessionCreate,
|
analysis_in: AnalysisSessionCreate,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Start a new analysis session.
|
Start a new analysis session.
|
||||||
"""
|
"""
|
||||||
session = service.create_session(analysis_in=analysis_in)
|
session = service.create_session(analysis_in=analysis_in)
|
||||||
background_tasks.add_task(service.run_analysis, session_id=session.id)
|
background_tasks.add_task(service.run_analysis, session_id=session.id)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
@router.get("/history", response_model=List[AnalysisSession])
|
@router.get("/history", response_model=List[AnalysisSession])
|
||||||
def get_analysis_history(
|
def get_analysis_history(
|
||||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get analysis history for the current user.
|
Get analysis history for the current user.
|
||||||
"""
|
"""
|
||||||
return service.get_user_sessions(skip=skip, limit=limit)
|
return service.get_user_sessions(skip=skip, limit=limit)
|
||||||
|
|
||||||
@router.get("/options")
|
@router.get("/options")
|
||||||
def get_analysis_options():
|
def get_analysis_options():
|
||||||
"""
|
"""
|
||||||
Get available options for analysis.
|
Get available options for analysis.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
'analysts': [
|
'analysts': [
|
||||||
{'value': 'market', 'label': 'Market Analyst'},
|
{'value': 'market', 'label': 'Market Analyst'},
|
||||||
{'value': 'social', 'label': 'Social Analyst'},
|
{'value': 'social', 'label': 'Social Analyst'},
|
||||||
{'value': 'news', 'label': 'News Analyst'},
|
{'value': 'news', 'label': 'News Analyst'},
|
||||||
{'value': 'fundamentals', 'label': 'Fundamentals Analyst'},
|
{'value': 'fundamentals', 'label': 'Fundamentals Analyst'},
|
||||||
],
|
],
|
||||||
'research_depths': [
|
'research_depths': [
|
||||||
{'value': 1, 'label': 'Shallow'},
|
{'value': 1, 'label': 'Shallow'},
|
||||||
{'value': 3, 'label': 'Medium'},
|
{'value': 3, 'label': 'Medium'},
|
||||||
{'value': 5, 'label': 'Deep'},
|
{'value': 5, 'label': 'Deep'},
|
||||||
],
|
],
|
||||||
'llm_providers': [{'name': p[0], 'url': p[1]} for p in BASE_URLS],
|
'llm_providers': [{'name': p[0], 'url': p[1]} for p in BASE_URLS],
|
||||||
'shallow_thinkers': SHALLOW_AGENT_OPTIONS,
|
'shallow_thinkers': SHALLOW_AGENT_OPTIONS,
|
||||||
'deep_thinkers': DEEP_AGENT_OPTIONS,
|
'deep_thinkers': DEEP_AGENT_OPTIONS,
|
||||||
}
|
}
|
||||||
|
|
||||||
@router.get("/{session_id}", response_model=AnalysisSession)
|
@router.get("/{session_id}", response_model=AnalysisSession)
|
||||||
def get_analysis_session(
|
def get_analysis_session(
|
||||||
session_id: int,
|
session_id: int,
|
||||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get a specific analysis session by ID.
|
Get a specific analysis session by ID.
|
||||||
"""
|
"""
|
||||||
session = service.get_session(session_id=session_id)
|
session = service.get_session(session_id=session_id)
|
||||||
if not session:
|
if not session:
|
||||||
raise HTTPException(status_code=404, detail="Analysis session not found")
|
raise HTTPException(status_code=404, detail="Analysis session not found")
|
||||||
return session
|
return session
|
||||||
|
|
||||||
@router.websocket("/ws")
|
@router.websocket("/ws")
|
||||||
async def websocket_endpoint(
|
async def websocket_endpoint(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
token: str,
|
token: str,
|
||||||
db: Session = Depends(deps.get_db)
|
db: Session = Depends(deps.get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
WebSocket endpoint for real-time analysis updates.
|
WebSocket endpoint for real-time analysis updates.
|
||||||
"""
|
"""
|
||||||
user = deps.get_user_from_token(token=token, db=db)
|
user = deps.get_user_from_token(token=token, db=db)
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
await websocket.close(code=1008)
|
await websocket.close(code=1008)
|
||||||
return
|
return
|
||||||
|
|
||||||
await manager.connect(user.id, websocket)
|
await manager.connect(user.id, websocket)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
# Keep the connection alive
|
# Keep the connection alive
|
||||||
await websocket.receive_text()
|
await websocket.receive_text()
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
manager.disconnect(user.id, websocket)
|
manager.disconnect(user.id, websocket)
|
||||||
|
|
@ -1,35 +1,35 @@
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
|
|
||||||
from app.api import deps
|
from app.api import deps
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.schemas.token import Token
|
from app.core.schemas.token import Token
|
||||||
from app.core import security
|
from app.core import security
|
||||||
from app.infrastructure.repositories.user import UserRepository
|
from app.infrastructure.repositories.user import UserRepository
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.post("/login/access-token", response_model=Token)
|
@router.post("/login/access-token", response_model=Token)
|
||||||
def login_access_token(
|
def login_access_token(
|
||||||
db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()
|
db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
OAuth2 compatible token login, get an access token for future requests
|
OAuth2 compatible token login, get an access token for future requests
|
||||||
"""
|
"""
|
||||||
user_repo = UserRepository(db)
|
user_repo = UserRepository(db)
|
||||||
user = user_repo.get_by_email(email=form_data.username)
|
user = user_repo.get_by_email(email=form_data.username)
|
||||||
|
|
||||||
if not user or not security.verify_password(form_data.password, user.hashed_password):
|
if not user or not security.verify_password(form_data.password, user.hashed_password):
|
||||||
raise HTTPException(status_code=400, detail="Incorrect email or password")
|
raise HTTPException(status_code=400, detail="Incorrect email or password")
|
||||||
elif not user.is_active:
|
elif not user.is_active:
|
||||||
raise HTTPException(status_code=400, detail="Inactive user")
|
raise HTTPException(status_code=400, detail="Inactive user")
|
||||||
|
|
||||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
return {
|
return {
|
||||||
"access_token": security.create_access_token(
|
"access_token": security.create_access_token(
|
||||||
user.email, expires_delta=access_token_expires
|
user.email, expires_delta=access_token_expires
|
||||||
),
|
),
|
||||||
"token_type": "bearer",
|
"token_type": "bearer",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,89 +1,89 @@
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from app.api import deps
|
from app.api import deps
|
||||||
from app.core.schemas.user import User, UserCreate, UserUpdate
|
from app.core.schemas.user import User, UserCreate, UserUpdate
|
||||||
from app.domain.models import User as UserModel
|
from app.domain.models import User as UserModel
|
||||||
from app.domain.repositories import IUserRepository
|
from app.domain.repositories import IUserRepository
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.get("/", response_model=List[User])
|
@router.get("/", response_model=List[User])
|
||||||
def read_users(
|
def read_users(
|
||||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Retrieve users.
|
Retrieve users.
|
||||||
"""
|
"""
|
||||||
users = repo.get_multi(skip=skip, limit=limit)
|
users = repo.get_multi(skip=skip, limit=limit)
|
||||||
return users
|
return users
|
||||||
|
|
||||||
@router.post("/", response_model=User)
|
@router.post("/", response_model=User)
|
||||||
def create_user(
|
def create_user(
|
||||||
*,
|
*,
|
||||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||||
user_in: UserCreate,
|
user_in: UserCreate,
|
||||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Create new user.
|
Create new user.
|
||||||
"""
|
"""
|
||||||
user = repo.get_by_email(email=user_in.email)
|
user = repo.get_by_email(email=user_in.email)
|
||||||
if user:
|
if user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail="The user with this username already exists in the system.",
|
detail="The user with this username already exists in the system.",
|
||||||
)
|
)
|
||||||
user = repo.create(obj_in=user_in)
|
user = repo.create(obj_in=user_in)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@router.get("/me", response_model=User)
|
@router.get("/me", response_model=User)
|
||||||
def read_user_me(
|
def read_user_me(
|
||||||
current_user: UserModel = Depends(deps.get_current_active_user),
|
current_user: UserModel = Depends(deps.get_current_active_user),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get current user.
|
Get current user.
|
||||||
"""
|
"""
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=User)
|
@router.get("/{user_id}", response_model=User)
|
||||||
def read_user_by_id(
|
def read_user_by_id(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||||
current_user: UserModel = Depends(deps.get_current_active_user),
|
current_user: UserModel = Depends(deps.get_current_active_user),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get a specific user by id.
|
Get a specific user by id.
|
||||||
"""
|
"""
|
||||||
user = repo.get(id=user_id)
|
user = repo.get(id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
if user == current_user:
|
if user == current_user:
|
||||||
return user
|
return user
|
||||||
if not repo.is_superuser(user=current_user):
|
if not repo.is_superuser(user=current_user):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403, detail="The user doesn't have enough privileges"
|
status_code=403, detail="The user doesn't have enough privileges"
|
||||||
)
|
)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@router.put("/{user_id}", response_model=User)
|
@router.put("/{user_id}", response_model=User)
|
||||||
def update_user(
|
def update_user(
|
||||||
*,
|
*,
|
||||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||||
user_id: int,
|
user_id: int,
|
||||||
user_in: UserUpdate,
|
user_in: UserUpdate,
|
||||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update a user.
|
Update a user.
|
||||||
"""
|
"""
|
||||||
user = repo.get(id=user_id)
|
user = repo.get(id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail="The user with this username does not exist in the system",
|
detail="The user with this username does not exist in the system",
|
||||||
)
|
)
|
||||||
user = repo.update(db_obj=user, obj_in=user_in)
|
user = repo.update(db_obj=user, obj_in=user_in)
|
||||||
return user
|
return user
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from app.api.endpoints import login, users, analysis
|
from app.api.endpoints import login, users, analysis
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
api_router.include_router(login.router, tags=["login"])
|
api_router.include_router(login.router, tags=["login"])
|
||||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||||
api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"])
|
api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"])
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,26 @@
|
||||||
import os
|
import os
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
PROJECT_NAME: str = "TradingAgents Backend"
|
PROJECT_NAME: str = "TradingAgents Backend"
|
||||||
API_V1_STR: str = "/api/v1"
|
API_V1_STR: str = "/api/v1"
|
||||||
|
|
||||||
# Security
|
# Security
|
||||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
|
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
|
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
|
||||||
ALGORITHM: str = "HS256"
|
ALGORITHM: str = "HS256"
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db")
|
DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db")
|
||||||
|
|
||||||
# OpenAI
|
# OpenAI
|
||||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||||
|
|
||||||
# CORS
|
# CORS
|
||||||
CORS_ALLOWED_ORIGINS: List[str] = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(',')
|
CORS_ALLOWED_ORIGINS: List[str] = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(',')
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
case_sensitive = True
|
case_sensitive = True
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from .user import User, UserCreate, UserUpdate
|
from .user import User, UserCreate, UserUpdate
|
||||||
from .token import Token, TokenPayload
|
from .token import Token, TokenPayload
|
||||||
from .profile import Profile, ProfileCreate, ProfileUpdate
|
from .profile import Profile, ProfileCreate, ProfileUpdate
|
||||||
from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate
|
from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate
|
||||||
|
|
|
||||||
|
|
@ -1,38 +1,38 @@
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from app.domain.models import AnalysisStatus
|
from app.domain.models import AnalysisStatus
|
||||||
|
|
||||||
class AnalysisSessionBase(BaseModel):
|
class AnalysisSessionBase(BaseModel):
|
||||||
ticker: str
|
ticker: str
|
||||||
analysts_selected: List[str]
|
analysts_selected: List[str]
|
||||||
research_depth: int
|
research_depth: int
|
||||||
llm_provider: str
|
llm_provider: str
|
||||||
backend_url: str
|
backend_url: str
|
||||||
shallow_thinker: str
|
shallow_thinker: str
|
||||||
deep_thinker: str
|
deep_thinker: str
|
||||||
|
|
||||||
class AnalysisSessionCreate(AnalysisSessionBase):
|
class AnalysisSessionCreate(AnalysisSessionBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class AnalysisSessionUpdate(BaseModel):
|
class AnalysisSessionUpdate(BaseModel):
|
||||||
status: Optional[AnalysisStatus] = None
|
status: Optional[AnalysisStatus] = None
|
||||||
final_report: Optional[str] = None
|
final_report: Optional[str] = None
|
||||||
error_message: Optional[str] = None
|
error_message: Optional[str] = None
|
||||||
|
|
||||||
class AnalysisSessionInDBBase(AnalysisSessionBase):
|
class AnalysisSessionInDBBase(AnalysisSessionBase):
|
||||||
id: int
|
id: int
|
||||||
user_id: int
|
user_id: int
|
||||||
analysis_date: date
|
analysis_date: date
|
||||||
status: AnalysisStatus
|
status: AnalysisStatus
|
||||||
final_report: Optional[str] = None
|
final_report: Optional[str] = None
|
||||||
error_message: Optional[str] = None
|
error_message: Optional[str] = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
started_at: Optional[datetime] = None
|
started_at: Optional[datetime] = None
|
||||||
completed_at: Optional[datetime] = None
|
completed_at: Optional[datetime] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
class AnalysisSession(AnalysisSessionInDBBase):
|
class AnalysisSession(AnalysisSessionInDBBase):
|
||||||
pass
|
pass
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
class ProfileBase(BaseModel):
|
class ProfileBase(BaseModel):
|
||||||
default_ticker: str = "SPY"
|
default_ticker: str = "SPY"
|
||||||
preferred_research_depth: int = 3
|
preferred_research_depth: int = 3
|
||||||
preferred_shallow_thinker: str = "gpt-4o-mini"
|
preferred_shallow_thinker: str = "gpt-4o-mini"
|
||||||
preferred_deep_thinker: str = "gpt-4o"
|
preferred_deep_thinker: str = "gpt-4o"
|
||||||
|
|
||||||
class ProfileCreate(ProfileBase):
|
class ProfileCreate(ProfileBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class ProfileUpdate(ProfileBase):
|
class ProfileUpdate(ProfileBase):
|
||||||
openai_api_key: Optional[str] = None
|
openai_api_key: Optional[str] = None
|
||||||
|
|
||||||
class Profile(ProfileBase):
|
class Profile(ProfileBase):
|
||||||
has_openai_api_key: bool
|
has_openai_api_key: bool
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
|
|
||||||
class TokenPayload(BaseModel):
|
class TokenPayload(BaseModel):
|
||||||
sub: Optional[int] = None
|
sub: Optional[int] = None
|
||||||
|
|
|
||||||
|
|
@ -1,28 +1,28 @@
|
||||||
from pydantic import BaseModel, EmailStr
|
from pydantic import BaseModel, EmailStr
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
class UserBase(BaseModel):
|
class UserBase(BaseModel):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
username: str
|
username: str
|
||||||
first_name: Optional[str] = None
|
first_name: Optional[str] = None
|
||||||
last_name: Optional[str] = None
|
last_name: Optional[str] = None
|
||||||
|
|
||||||
class UserCreate(UserBase):
|
class UserCreate(UserBase):
|
||||||
password: str
|
password: str
|
||||||
|
|
||||||
class UserUpdate(UserBase):
|
class UserUpdate(UserBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class UserInDBBase(UserBase):
|
class UserInDBBase(UserBase):
|
||||||
id: int
|
id: int
|
||||||
is_active: bool
|
is_active: bool
|
||||||
is_superuser: bool
|
is_superuser: bool
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
class User(UserInDBBase):
|
class User(UserInDBBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class UserInDB(UserInDBBase):
|
class UserInDB(UserInDBBase):
|
||||||
hashed_password: str
|
hashed_password: str
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,23 @@
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Union, Optional
|
from typing import Any, Union, Optional
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||||
if expires_delta:
|
if expires_delta:
|
||||||
expire = datetime.utcnow() + expires_delta
|
expire = datetime.utcnow() + expires_delta
|
||||||
else:
|
else:
|
||||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
|
||||||
to_encode = {"exp": expire, "sub": str(subject)}
|
to_encode = {"exp": expire, "sub": str(subject)}
|
||||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
return pwd_context.verify(plain_password, hashed_password)
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
def get_password_hash(password: str) -> str:
|
def get_password_hash(password: str) -> str:
|
||||||
return pwd_context.hash(password)
|
return pwd_context.hash(password)
|
||||||
|
|
|
||||||
|
|
@ -1,128 +1,128 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from app.domain.models import User, AnalysisSession, AnalysisStatus
|
from app.domain.models import User, AnalysisSession, AnalysisStatus
|
||||||
from app.core.schemas.analysis import AnalysisSessionCreate
|
from app.core.schemas.analysis import AnalysisSessionCreate
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from app.api.deps import get_db
|
from app.api.deps import get_db
|
||||||
from app.core.websocket_manager import WebSocketManager
|
from app.core.websocket_manager import WebSocketManager
|
||||||
|
|
||||||
class TradingAnalysisService:
|
class TradingAnalysisService:
|
||||||
def __init__(self, user: User, db: Session):
|
def __init__(self, user: User, db: Session):
|
||||||
self.user = user
|
self.user = user
|
||||||
self.db = db
|
self.db = db
|
||||||
self.websocket_manager = WebSocketManager()
|
self.websocket_manager = WebSocketManager()
|
||||||
|
|
||||||
async def run_analysis(self, session_id: int):
|
async def run_analysis(self, session_id: int):
|
||||||
"""분석 실행"""
|
"""분석 실행"""
|
||||||
session = self.get_session(session_id=session_id)
|
session = self.get_session(session_id=session_id)
|
||||||
if not session:
|
if not session:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session.status = AnalysisStatus.RUNNING
|
session.status = AnalysisStatus.RUNNING
|
||||||
session.started_at = datetime.datetime.utcnow()
|
session.started_at = datetime.datetime.utcnow()
|
||||||
self.db.add(session)
|
self.db.add(session)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(session)
|
self.db.refresh(session)
|
||||||
|
|
||||||
await self.websocket_manager.send_to_user(
|
await self.websocket_manager.send_to_user(
|
||||||
self.user.id,
|
self.user.id,
|
||||||
{
|
{
|
||||||
'type': 'analysis_started',
|
'type': 'analysis_started',
|
||||||
'session_id': session.id,
|
'session_id': session.id,
|
||||||
'message': '분석을 시작합니다...'
|
'message': '분석을 시작합니다...'
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare config for TradingAgentsGraph
|
# Prepare config for TradingAgentsGraph
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config.update({
|
config.update({
|
||||||
'openai_api_key': settings.OPENAI_API_KEY,
|
'openai_api_key': settings.OPENAI_API_KEY,
|
||||||
'llm_provider': session.llm_provider,
|
'llm_provider': session.llm_provider,
|
||||||
'backend_url': session.backend_url,
|
'backend_url': session.backend_url,
|
||||||
'shallow_thinking_model': session.shallow_thinker,
|
'shallow_thinking_model': session.shallow_thinker,
|
||||||
'deep_thinking_model': session.deep_thinker,
|
'deep_thinking_model': session.deep_thinker,
|
||||||
})
|
})
|
||||||
|
|
||||||
# Progress callback for websocket
|
# Progress callback for websocket
|
||||||
async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0):
|
async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0):
|
||||||
progress_percent = int((step / total) * 99) if total > 0 else 0
|
progress_percent = int((step / total) * 99) if total > 0 else 0
|
||||||
await self.websocket_manager.send_to_user(self.user.id, {
|
await self.websocket_manager.send_to_user(self.user.id, {
|
||||||
'type': 'analysis_progress',
|
'type': 'analysis_progress',
|
||||||
'session_id': session.id,
|
'session_id': session.id,
|
||||||
'message_type': message_type,
|
'message_type': message_type,
|
||||||
'content': content,
|
'content': content,
|
||||||
'agent': agent,
|
'agent': agent,
|
||||||
'progress': progress_percent,
|
'progress': progress_percent,
|
||||||
})
|
})
|
||||||
|
|
||||||
trading_graph = TradingAgentsGraph(
|
trading_graph = TradingAgentsGraph(
|
||||||
config=config,
|
config=config,
|
||||||
selected_analysts=session.analysts_selected,
|
selected_analysts=session.analysts_selected,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_data = {
|
input_data = {
|
||||||
'company_of_interest': session.ticker,
|
'company_of_interest': session.ticker,
|
||||||
'trade_date': session.analysis_date.strftime('%Y-%m-%d'),
|
'trade_date': session.analysis_date.strftime('%Y-%m-%d'),
|
||||||
}
|
}
|
||||||
|
|
||||||
final_state, result = await asyncio.to_thread(
|
final_state, result = await asyncio.to_thread(
|
||||||
trading_graph.propagate,
|
trading_graph.propagate,
|
||||||
input_data['company_of_interest'],
|
input_data['company_of_interest'],
|
||||||
input_data['trade_date']
|
input_data['trade_date']
|
||||||
)
|
)
|
||||||
|
|
||||||
session.status = AnalysisStatus.COMPLETED
|
session.status = AnalysisStatus.COMPLETED
|
||||||
session.completed_at = datetime.datetime.utcnow()
|
session.completed_at = datetime.datetime.utcnow()
|
||||||
session.final_report = json.dumps(final_state) # Store full state as JSON
|
session.final_report = json.dumps(final_state) # Store full state as JSON
|
||||||
self.db.add(session)
|
self.db.add(session)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
await self.websocket_manager.send_to_user(
|
await self.websocket_manager.send_to_user(
|
||||||
self.user.id,
|
self.user.id,
|
||||||
{
|
{
|
||||||
'type': 'analysis_completed',
|
'type': 'analysis_completed',
|
||||||
'session_id': session.id,
|
'session_id': session.id,
|
||||||
'message': '분석이 완료되었습니다.',
|
'message': '분석이 완료되었습니다.',
|
||||||
'result': result
|
'result': result
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
session.status = AnalysisStatus.FAILED
|
session.status = AnalysisStatus.FAILED
|
||||||
session.error_message = str(e)
|
session.error_message = str(e)
|
||||||
self.db.add(session)
|
self.db.add(session)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
await self.websocket_manager.send_to_user(
|
await self.websocket_manager.send_to_user(
|
||||||
self.user.id,
|
self.user.id,
|
||||||
{
|
{
|
||||||
'type': 'analysis_failed',
|
'type': 'analysis_failed',
|
||||||
'session_id': session.id,
|
'session_id': session.id,
|
||||||
'message': f'분석 중 오류가 발생했습니다: {str(e)}'
|
'message': f'분석 중 오류가 발생했습니다: {str(e)}'
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
|
def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
|
||||||
session = AnalysisSession(
|
session = AnalysisSession(
|
||||||
**analysis_in.dict(),
|
**analysis_in.dict(),
|
||||||
user_id=self.user.id,
|
user_id=self.user.id,
|
||||||
analysis_date=datetime.date.today()
|
analysis_date=datetime.date.today()
|
||||||
)
|
)
|
||||||
self.db.add(session)
|
self.db.add(session)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(session)
|
self.db.refresh(session)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
|
def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
|
||||||
statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id)
|
statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id)
|
||||||
return self.db.exec(statement).first()
|
return self.db.exec(statement).first()
|
||||||
|
|
||||||
def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
|
def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
|
||||||
statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit)
|
statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit)
|
||||||
return self.db.exec(statement).all()
|
return self.db.exec(statement).all()
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,23 @@
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
|
|
||||||
class WebSocketManager:
|
class WebSocketManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.active_connections: Dict[int, List[WebSocket]] = {}
|
self.active_connections: Dict[int, List[WebSocket]] = {}
|
||||||
|
|
||||||
async def connect(self, user_id: int, websocket: WebSocket):
|
async def connect(self, user_id: int, websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
if user_id not in self.active_connections:
|
if user_id not in self.active_connections:
|
||||||
self.active_connections[user_id] = []
|
self.active_connections[user_id] = []
|
||||||
self.active_connections[user_id].append(websocket)
|
self.active_connections[user_id].append(websocket)
|
||||||
|
|
||||||
def disconnect(self, user_id: int, websocket: WebSocket):
|
def disconnect(self, user_id: int, websocket: WebSocket):
|
||||||
if user_id in self.active_connections:
|
if user_id in self.active_connections:
|
||||||
self.active_connections[user_id].remove(websocket)
|
self.active_connections[user_id].remove(websocket)
|
||||||
if not self.active_connections[user_id]:
|
if not self.active_connections[user_id]:
|
||||||
del self.active_connections[user_id]
|
del self.active_connections[user_id]
|
||||||
|
|
||||||
async def send_to_user(self, user_id: int, message: dict):
|
async def send_to_user(self, user_id: int, message: dict):
|
||||||
if user_id in self.active_connections:
|
if user_id in self.active_connections:
|
||||||
for connection in self.active_connections[user_id]:
|
for connection in self.active_connections[user_id]:
|
||||||
await connection.send_json(message)
|
await connection.send_json(message)
|
||||||
|
|
|
||||||
|
|
@ -1,56 +1,56 @@
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from sqlmodel import Field, SQLModel, JSON, Column
|
from sqlmodel import Field, SQLModel, JSON, Column
|
||||||
import enum
|
import enum
|
||||||
|
|
||||||
|
|
||||||
class User(SQLModel, table=True):
|
class User(SQLModel, table=True):
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
email: str = Field(unique=True, index=True)
|
email: str = Field(unique=True, index=True)
|
||||||
username: str = Field(unique=True, index=True)
|
username: str = Field(unique=True, index=True)
|
||||||
hashed_password: str
|
hashed_password: str
|
||||||
first_name: Optional[str] = None
|
first_name: Optional[str] = None
|
||||||
last_name: Optional[str] = None
|
last_name: Optional[str] = None
|
||||||
is_active: bool = Field(default=True)
|
is_active: bool = Field(default=True)
|
||||||
is_superuser: bool = Field(default=False)
|
is_superuser: bool = Field(default=False)
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
|
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
|
||||||
|
|
||||||
|
|
||||||
class UserProfile(SQLModel, table=True):
|
class UserProfile(SQLModel, table=True):
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
user_id: int = Field(foreign_key="user.id", unique=True)
|
user_id: int = Field(foreign_key="user.id", unique=True)
|
||||||
encrypted_openai_api_key: Optional[str] = None
|
encrypted_openai_api_key: Optional[str] = None
|
||||||
default_ticker: str = Field(default="SPY")
|
default_ticker: str = Field(default="SPY")
|
||||||
preferred_research_depth: int = Field(default=3)
|
preferred_research_depth: int = Field(default=3)
|
||||||
preferred_shallow_thinker: str = Field(default="gpt-4o-mini")
|
preferred_shallow_thinker: str = Field(default="gpt-4o-mini")
|
||||||
preferred_deep_thinker: str = Field(default="gpt-4o")
|
preferred_deep_thinker: str = Field(default="gpt-4o")
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
|
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
|
||||||
|
|
||||||
|
|
||||||
class AnalysisStatus(str, enum.Enum):
|
class AnalysisStatus(str, enum.Enum):
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
CANCELLED = "cancelled"
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
class AnalysisSession(SQLModel, table=True):
|
class AnalysisSession(SQLModel, table=True):
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
user_id: int = Field(foreign_key="user.id")
|
user_id: int = Field(foreign_key="user.id")
|
||||||
ticker: str
|
ticker: str
|
||||||
analysis_date: date
|
analysis_date: date
|
||||||
analysts_selected: List[str] = Field(sa_column=Column(JSON))
|
analysts_selected: List[str] = Field(sa_column=Column(JSON))
|
||||||
research_depth: int
|
research_depth: int
|
||||||
llm_provider: str
|
llm_provider: str
|
||||||
backend_url: str
|
backend_url: str
|
||||||
shallow_thinker: str
|
shallow_thinker: str
|
||||||
deep_thinker: str
|
deep_thinker: str
|
||||||
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
|
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
|
||||||
final_report: Optional[str] = None
|
final_report: Optional[str] = None
|
||||||
error_message: Optional[str] = None
|
error_message: Optional[str] = None
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
started_at: Optional[datetime] = None
|
started_at: Optional[datetime] = None
|
||||||
completed_at: Optional[datetime] = None
|
completed_at: Optional[datetime] = None
|
||||||
|
|
@ -1,48 +1,48 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Generic, TypeVar, Optional, List
|
from typing import Generic, TypeVar, Optional, List
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
from app.core.schemas.user import UserCreate, UserUpdate
|
from app.core.schemas.user import UserCreate, UserUpdate
|
||||||
from app.domain.models import User
|
from app.domain.models import User
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=SQLModel)
|
ModelType = TypeVar("ModelType", bound=SQLModel)
|
||||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel)
|
CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel)
|
||||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel)
|
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel)
|
||||||
|
|
||||||
class IRepository(Generic[ModelType], ABC):
|
class IRepository(Generic[ModelType], ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, id: int) -> Optional[ModelType]:
|
def get(self, id: int) -> Optional[ModelType]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
|
def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create(self, *, obj_in: CreateSchemaType) -> ModelType:
|
def create(self, *, obj_in: CreateSchemaType) -> ModelType:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
|
def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def remove(self, *, id: int) -> ModelType:
|
def remove(self, *, id: int) -> ModelType:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class IUserRepository(IRepository[User], ABC):
|
class IUserRepository(IRepository[User], ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_by_email(self, *, email: str) -> Optional[User]:
|
def get_by_email(self, *, email: str) -> Optional[User]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create(self, *, obj_in: UserCreate) -> User:
|
def create(self, *, obj_in: UserCreate) -> User:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
|
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_superuser(self, *, user: User) -> bool:
|
def is_superuser(self, *, user: User) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
from sqlmodel import create_engine, Session
|
from sqlmodel import create_engine, Session
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
engine = create_engine(settings.DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
|
engine = create_engine(settings.DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
|
||||||
|
|
||||||
def get_db():
|
def get_db():
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
|
||||||
|
|
@ -1,53 +1,53 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from app.domain.models import User
|
from app.domain.models import User
|
||||||
from app.core.schemas.user import UserCreate, UserUpdate
|
from app.core.schemas.user import UserCreate, UserUpdate
|
||||||
from app.domain.repositories import IUserRepository
|
from app.domain.repositories import IUserRepository
|
||||||
from app.core.security import get_password_hash
|
from app.core.security import get_password_hash
|
||||||
|
|
||||||
class UserRepository(IUserRepository):
|
class UserRepository(IUserRepository):
|
||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def get(self, id: int) -> Optional[User]:
|
def get(self, id: int) -> Optional[User]:
|
||||||
return self.db.get(User, id)
|
return self.db.get(User, id)
|
||||||
|
|
||||||
def get_by_email(self, *, email: str) -> Optional[User]:
|
def get_by_email(self, *, email: str) -> Optional[User]:
|
||||||
statement = select(User).where(User.email == email)
|
statement = select(User).where(User.email == email)
|
||||||
return self.db.exec(statement).first()
|
return self.db.exec(statement).first()
|
||||||
|
|
||||||
def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]:
|
def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]:
|
||||||
statement = select(User).offset(skip).limit(limit)
|
statement = select(User).offset(skip).limit(limit)
|
||||||
return self.db.exec(statement).all()
|
return self.db.exec(statement).all()
|
||||||
|
|
||||||
def create(self, *, obj_in: UserCreate) -> User:
|
def create(self, *, obj_in: UserCreate) -> User:
|
||||||
db_obj = User(
|
db_obj = User(
|
||||||
email=obj_in.email,
|
email=obj_in.email,
|
||||||
username=obj_in.username,
|
username=obj_in.username,
|
||||||
hashed_password=get_password_hash(obj_in.password),
|
hashed_password=get_password_hash(obj_in.password),
|
||||||
first_name=obj_in.first_name,
|
first_name=obj_in.first_name,
|
||||||
last_name=obj_in.last_name,
|
last_name=obj_in.last_name,
|
||||||
)
|
)
|
||||||
self.db.add(db_obj)
|
self.db.add(db_obj)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(db_obj)
|
self.db.refresh(db_obj)
|
||||||
return db_obj
|
return db_obj
|
||||||
|
|
||||||
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
|
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
|
||||||
update_data = obj_in.dict(exclude_unset=True)
|
update_data = obj_in.dict(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(db_obj, field, value)
|
setattr(db_obj, field, value)
|
||||||
|
|
||||||
self.db.add(db_obj)
|
self.db.add(db_obj)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(db_obj)
|
self.db.refresh(db_obj)
|
||||||
return db_obj
|
return db_obj
|
||||||
|
|
||||||
def remove(self, *, id: int) -> User:
|
def remove(self, *, id: int) -> User:
|
||||||
db_obj = self.db.get(User, id)
|
db_obj = self.db.get(User, id)
|
||||||
self.db.delete(db_obj)
|
self.db.delete(db_obj)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
return db_obj
|
return db_obj
|
||||||
|
|
||||||
def is_superuser(self, *, user: User) -> bool:
|
def is_superuser(self, *, user: User) -> bool:
|
||||||
return user.is_superuser
|
return user.is_superuser
|
||||||
|
|
|
||||||
70
app/main.py
70
app/main.py
|
|
@ -1,36 +1,36 @@
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
# Add project root to path to allow importing tradingagents
|
# Add project root to path to allow importing tradingagents
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
|
||||||
|
|
||||||
from app.api.router import api_router
|
from app.api.router import api_router
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.infrastructure.database import engine
|
from app.infrastructure.database import engine
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
def create_tables():
|
def create_tables():
|
||||||
SQLModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title=settings.PROJECT_NAME,
|
title=settings.PROJECT_NAME,
|
||||||
openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
def on_startup():
|
def on_startup():
|
||||||
create_tables()
|
create_tables()
|
||||||
|
|
||||||
# Set all CORS enabled origins
|
# Set all CORS enabled origins
|
||||||
if settings.CORS_ALLOWED_ORIGINS:
|
if settings.CORS_ALLOWED_ORIGINS:
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS],
|
allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS],
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||||
|
|
@ -1,2 +1,2 @@
|
||||||
.env
|
.env
|
||||||
wallet/
|
wallet/
|
||||||
|
|
@ -1,247 +1,299 @@
|
||||||
from sqlmodel import Session
|
import sys
|
||||||
from analysis.domain.repository.analysis_repo import IAnalysisRepository
|
import os
|
||||||
from ulid import ULID
|
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))
|
||||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
|
||||||
from analysis.interface.dto import TradingAnalysisRequest, AnalysisProgressUpdate
|
import logging
|
||||||
from fastapi import HTTPException, status, BackgroundTasks
|
from sqlmodel import Session
|
||||||
import asyncio
|
from analysis.domain.repository.analysis_repo import IAnalysisRepository
|
||||||
from datetime import datetime
|
from ulid import ULID
|
||||||
|
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from analysis.interface.dto import TradingAnalysisRequest, AnalysisProgressUpdate
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from fastapi import HTTPException, status, BackgroundTasks
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
class AnalysisService:
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
def __init__(
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
self,
|
from analysis.application.websocket_manager import WebSocketManager
|
||||||
analysis_repo: IAnalysisRepository,
|
from analysis.infra.db_models.analysis import AnalysisStatus
|
||||||
session: Session,
|
|
||||||
ulid: ULID
|
logger = logging.getLogger(__name__)
|
||||||
):
|
|
||||||
self.analysis_repo = analysis_repo
|
class AnalysisService:
|
||||||
self.session = session
|
def __init__(
|
||||||
self.ulid = ulid
|
self,
|
||||||
|
analysis_repo: IAnalysisRepository,
|
||||||
def get_analysis_list(
|
session: Session,
|
||||||
self,
|
ulid: ULID,
|
||||||
member_id: str
|
websocket_manager: WebSocketManager
|
||||||
) -> list[AnalysisVO]:
|
):
|
||||||
analyses = self.analysis_repo.find_by_member_id(member_id)
|
self.analysis_repo = analysis_repo
|
||||||
if not analyses:
|
self.session = session
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
self.ulid = ulid
|
||||||
return analyses
|
self.websocket_manager = websocket_manager
|
||||||
|
|
||||||
def get_analysis_by_id(
|
def get_analysis_list(
|
||||||
self,
|
self,
|
||||||
analysis_id: str,
|
member_id: str
|
||||||
member_id: str
|
) -> list[AnalysisVO]:
|
||||||
) -> AnalysisVO:
|
analyses = self.analysis_repo.find_by_member_id(member_id)
|
||||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
if not analyses:
|
||||||
if not analysis:
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
return analyses
|
||||||
|
|
||||||
if analysis.member_id != member_id:
|
def get_analysis_by_id(
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
self,
|
||||||
|
analysis_id: str,
|
||||||
return analysis
|
member_id: str
|
||||||
|
) -> AnalysisVO:
|
||||||
def create_analysis(
|
analysis = self.analysis_repo.find_by_id(analysis_id)
|
||||||
self,
|
if not analysis:
|
||||||
member_id: str,
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
||||||
request: TradingAnalysisRequest,
|
|
||||||
background_tasks: BackgroundTasks
|
if analysis.member_id != member_id:
|
||||||
) -> AnalysisVO:
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||||
# 분석 요청 생성
|
|
||||||
analysis_id = self.ulid.generate()
|
return analysis
|
||||||
now = datetime.now()
|
|
||||||
|
def get_analysis_sessions_by_member(
|
||||||
analysis_vo = AnalysisVO(
|
self,
|
||||||
id=analysis_id,
|
member_id: str
|
||||||
member_id=member_id,
|
) -> list[AnalysisVO]:
|
||||||
ticker=request.ticker,
|
analyses = self.analysis_repo.find_by_member_id(member_id)
|
||||||
analysis_date=request.analysis_date,
|
if not analyses:
|
||||||
analysts_selected=[analyst.value for analyst in request.analysts],
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
||||||
research_depth=request.research_depth,
|
return analyses
|
||||||
llm_provider=request.llm_provider,
|
|
||||||
backend_url=request.backend_url,
|
def create_analysis(
|
||||||
shallow_thinker=request.shallow_thinker,
|
self,
|
||||||
deep_thinker=request.deep_thinker,
|
member_id: str,
|
||||||
status="pending",
|
request: TradingAnalysisRequest,
|
||||||
created_at=now,
|
background_tasks: BackgroundTasks
|
||||||
updated_at=now
|
) -> AnalysisVO:
|
||||||
)
|
# 분석 요청 생성
|
||||||
|
analysis_id = self.ulid.generate()
|
||||||
saved_analysis = self.analysis_repo.save(analysis_vo)
|
now = datetime.now()
|
||||||
self.session.commit()
|
|
||||||
|
analysis_vo = AnalysisVO(
|
||||||
# 백그라운드에서 분석 실행
|
id=analysis_id,
|
||||||
background_tasks.add_task(self._run_analysis, saved_analysis.id)
|
member_id=member_id,
|
||||||
|
ticker=request.ticker,
|
||||||
return saved_analysis
|
analysis_date=request.analysis_date,
|
||||||
|
analysts_selected=[analyst.value for analyst in request.analysts],
|
||||||
async def _run_analysis(self, analysis_id: str):
|
research_depth=request.research_depth,
|
||||||
"""백그라운드에서 실제 분석을 실행하는 메서드"""
|
llm_provider=request.llm_provider,
|
||||||
try:
|
backend_url=request.backend_url,
|
||||||
# 분석 상태를 RUNNING으로 변경
|
shallow_thinker=request.shallow_thinker,
|
||||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
deep_thinker=request.deep_thinker,
|
||||||
if analysis:
|
status=AnalysisStatus.PENDING,
|
||||||
analysis.status = "running"
|
created_at=now,
|
||||||
analysis.updated_at = datetime.now()
|
updated_at=now
|
||||||
self.analysis_repo.update(analysis)
|
)
|
||||||
self.session.commit()
|
|
||||||
|
saved_analysis = self.analysis_repo.save(analysis_vo)
|
||||||
# 분석 정보 조회
|
if not saved_analysis:
|
||||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to save analysis")
|
||||||
if not analysis:
|
|
||||||
return
|
self.session.commit()
|
||||||
|
|
||||||
# TradingAgentsGraph 설정 및 실행
|
# Register analysis with websocket manager
|
||||||
config = self._create_config(analysis)
|
self.websocket_manager.register_analysis(saved_analysis.id, member_id)
|
||||||
|
|
||||||
# 분석 실행 (실제 구현)
|
# 백그라운드에서 분석 실행
|
||||||
await self._execute_trading_analysis(analysis_id, analysis, config)
|
background_tasks.add_task(self._run_analysis, saved_analysis.id)
|
||||||
|
|
||||||
# 분석 완료 상태로 변경
|
return saved_analysis
|
||||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
|
||||||
if analysis:
|
async def _run_analysis(self, analysis_id: str):
|
||||||
analysis.status = "completed"
|
"""백그라운드에서 실제 분석을 실행하는 메서드"""
|
||||||
analysis.completed_at = datetime.now()
|
try:
|
||||||
analysis.updated_at = datetime.now()
|
analysis = AnalysisVO(
|
||||||
self.analysis_repo.update(analysis)
|
id=analysis_id,
|
||||||
self.session.commit()
|
status=AnalysisStatus.RUNNING,
|
||||||
|
updated_at=datetime.now()
|
||||||
except Exception as e:
|
)
|
||||||
# 에러 발생 시 실패 상태로 변경
|
|
||||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
analysis = self.analysis_repo.update(analysis)
|
||||||
if analysis:
|
if not analysis:
|
||||||
analysis.status = "failed"
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
||||||
analysis.error_message = str(e)
|
|
||||||
analysis.completed_at = datetime.now()
|
await self.websocket_manager.send_analysis_update(
|
||||||
analysis.updated_at = datetime.now()
|
analysis_id=analysis_id,
|
||||||
self.analysis_repo.update(analysis)
|
update_type="status_changed",
|
||||||
self.session.commit()
|
data={"status": "running", "message": "Analysis started"}
|
||||||
|
)
|
||||||
def _create_config(self, analysis: AnalysisVO) -> dict:
|
|
||||||
"""분석 설정을 생성하는 메서드"""
|
|
||||||
config = DEFAULT_CONFIG.copy() if DEFAULT_CONFIG else {}
|
|
||||||
config.update({
|
# TradingAgentsGraph 설정 및 실행
|
||||||
"max_debate_rounds": analysis.research_depth,
|
if analysis:
|
||||||
"max_risk_discuss_rounds": analysis.research_depth,
|
config = self._create_config(analysis)
|
||||||
"quick_think_llm": analysis.shallow_thinker,
|
|
||||||
"deep_think_llm": analysis.deep_thinker,
|
# 분석 실행 (실제 구현)
|
||||||
"backend_url": analysis.backend_url,
|
await self._execute_trading_analysis(analysis_id, analysis, config)
|
||||||
"llm_provider": analysis.llm_provider.lower(),
|
|
||||||
})
|
# 완료 상태로 업데이트
|
||||||
return config
|
completed_analysis = AnalysisVO(
|
||||||
|
id=analysis_id,
|
||||||
async def _execute_trading_analysis(self, analysis_id: str, analysis: AnalysisVO, config: dict):
|
status=AnalysisStatus.COMPLETED,
|
||||||
"""실제 TradingAgentsGraph를 실행하는 메서드"""
|
completed_at=datetime.now(),
|
||||||
try:
|
updated_at=datetime.now()
|
||||||
# TradingAgentsGraph 초기화
|
)
|
||||||
graph = TradingAgentsGraph(
|
self.analysis_repo.update(completed_analysis)
|
||||||
analysis.analysts_selected,
|
self.session.commit()
|
||||||
config=config,
|
|
||||||
debug=True
|
|
||||||
)
|
except Exception as e:
|
||||||
|
now = datetime.now()
|
||||||
# 초기 상태 생성
|
updates = AnalysisVO(
|
||||||
init_agent_state = graph.propagator.create_initial_state(
|
status=AnalysisStatus.FAILED,
|
||||||
analysis.ticker,
|
error_message=str(e),
|
||||||
analysis.analysis_date
|
completed_at = now,
|
||||||
)
|
updated_at = now
|
||||||
args = graph.propagator.get_graph_args()
|
)
|
||||||
|
|
||||||
# 분석 실행 및 결과 처리
|
self.analysis_repo.update(updates)
|
||||||
trace = []
|
self.session.commit()
|
||||||
async for chunk in graph.graph.astream(init_agent_state, **args):
|
|
||||||
trace.append(chunk)
|
|
||||||
|
def _create_config(self, analysis: AnalysisVO) -> dict:
|
||||||
# 실시간으로 분석 결과 업데이트
|
"""분석 설정을 생성하는 메서드"""
|
||||||
await self._process_analysis_chunk(analysis_id, chunk)
|
config = {}
|
||||||
|
config.update({
|
||||||
# 최종 결과 처리
|
"max_debate_rounds": analysis.research_depth,
|
||||||
if trace:
|
"max_risk_discuss_rounds": analysis.research_depth,
|
||||||
final_state = trace[-1]
|
"quick_think_llm": analysis.shallow_thinker,
|
||||||
final_decision = graph.process_signal(final_state.get("final_trade_decision", ""))
|
"deep_think_llm": analysis.deep_thinker,
|
||||||
|
"backend_url": analysis.backend_url,
|
||||||
# 최종 보고서 생성
|
"llm_provider": analysis.llm_provider.lower(),
|
||||||
final_report = self._generate_final_report(final_state)
|
})
|
||||||
|
return config
|
||||||
# 최종 결과 저장
|
|
||||||
self.analysis_repo.update(analysis_id, {
|
async def _execute_trading_analysis(self, analysis_id: str, analysis: AnalysisVO, config: dict):
|
||||||
"final_trade_decision": final_decision,
|
"""실제 TradingAgentsGraph를 실행하는 메서드"""
|
||||||
"final_report": final_report
|
try:
|
||||||
})
|
logger.info(f"Starting trading analysis for {analysis_id} with ticker {analysis.ticker}")
|
||||||
self.session.commit()
|
logger.info(f"Analysts selected: {analysis.analysts_selected}")
|
||||||
|
logger.info(f"Config: {config}")
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"Analysis execution failed: {str(e)}")
|
# TradingAgentsGraph 초기화
|
||||||
|
graph = TradingAgentsGraph(
|
||||||
async def _process_analysis_chunk(self, analysis_id: str, chunk: dict):
|
analysis.analysts_selected,
|
||||||
"""분석 중간 결과를 처리하고 저장하는 메서드"""
|
config=config,
|
||||||
updates = {}
|
debug=True
|
||||||
|
)
|
||||||
# 개별 분석가 보고서 업데이트
|
logger.info("TradingAgentsGraph initialized successfully")
|
||||||
if "market_report" in chunk and chunk["market_report"]:
|
|
||||||
updates["market_report"] = chunk["market_report"]
|
# 초기 상태 생성
|
||||||
|
init_agent_state = graph.propagator.create_initial_state(
|
||||||
if "sentiment_report" in chunk and chunk["sentiment_report"]:
|
analysis.ticker,
|
||||||
updates["sentiment_report"] = chunk["sentiment_report"]
|
analysis.analysis_date
|
||||||
|
)
|
||||||
if "news_report" in chunk and chunk["news_report"]:
|
args = graph.propagator.get_graph_args()
|
||||||
updates["news_report"] = chunk["news_report"]
|
|
||||||
|
# 분석 실행 및 결과 처리
|
||||||
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
|
logger.info("Starting graph execution...")
|
||||||
updates["fundamentals_report"] = chunk["fundamentals_report"]
|
trace = []
|
||||||
|
chunk_count = 0
|
||||||
# 팀별 의사결정 과정 업데이트
|
async for chunk in graph.graph.astream(init_agent_state, **args):
|
||||||
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
|
chunk_count += 1
|
||||||
updates["investment_debate_state"] = chunk["investment_debate_state"]
|
logger.info(f"Processing chunk {chunk_count}: {list(chunk.keys()) if chunk else 'Empty chunk'}")
|
||||||
|
trace.append(chunk)
|
||||||
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
|
|
||||||
updates["trader_investment_plan"] = chunk["trader_investment_plan"]
|
# 실시간으로 분석 결과 업데이트
|
||||||
|
await self._process_analysis_chunk(analysis_id, chunk)
|
||||||
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
|
|
||||||
updates["risk_debate_state"] = chunk["risk_debate_state"]
|
# 최종 결과 처리
|
||||||
|
if trace:
|
||||||
# 업데이트가 있는 경우 저장
|
final_state = trace[-1]
|
||||||
if updates:
|
final_decision = graph.process_signal(final_state.get("final_trade_decision", ""))
|
||||||
self.analysis_repo.update(analysis_id, updates)
|
|
||||||
self.session.commit()
|
# 최종 보고서 생성
|
||||||
|
final_report = self._generate_final_report(final_state)
|
||||||
def _generate_final_report(self, final_state: dict) -> str:
|
analysis.final_trade_decision = final_decision
|
||||||
"""최종 통합 보고서를 생성하는 메서드"""
|
analysis.final_report = final_report
|
||||||
report_parts = []
|
|
||||||
|
# 최종 결과 저장
|
||||||
# Analyst Team Reports
|
updates = AnalysisVO(
|
||||||
if any(final_state.get(section) for section in ["market_report", "sentiment_report", "news_report", "fundamentals_report"]):
|
id=analysis_id,
|
||||||
report_parts.append("## Analyst Team Reports")
|
final_trade_decision=final_decision,
|
||||||
|
final_report=final_report
|
||||||
if final_state.get("market_report"):
|
)
|
||||||
report_parts.append(f"### Market Analysis\n{final_state['market_report']}")
|
self.analysis_repo.update(updates)
|
||||||
if final_state.get("sentiment_report"):
|
|
||||||
report_parts.append(f"### Social Sentiment\n{final_state['sentiment_report']}")
|
self.session.commit()
|
||||||
if final_state.get("news_report"):
|
|
||||||
report_parts.append(f"### News Analysis\n{final_state['news_report']}")
|
except Exception as e:
|
||||||
if final_state.get("fundamentals_report"):
|
raise Exception(f"Analysis execution failed: {str(e)}")
|
||||||
report_parts.append(f"### Fundamentals Analysis\n{final_state['fundamentals_report']}")
|
|
||||||
|
async def _process_analysis_chunk(self, analysis_id: str, chunk: dict):
|
||||||
# Research Team Reports
|
"""분석 중간 결과를 처리하고 저장하는 메서드"""
|
||||||
if final_state.get("investment_debate_state"):
|
updates = {}
|
||||||
report_parts.append("## Research Team Decision")
|
|
||||||
debate_state = final_state["investment_debate_state"]
|
# 개별 분석가 보고서 업데이트
|
||||||
if debate_state.get("judge_decision"):
|
if "market_report" in chunk and chunk["market_report"]:
|
||||||
report_parts.append(f"{debate_state['judge_decision']}")
|
updates["market_report"] = chunk["market_report"]
|
||||||
|
|
||||||
# Trading Team Reports
|
if "sentiment_report" in chunk and chunk["sentiment_report"]:
|
||||||
if final_state.get("trader_investment_plan"):
|
updates["sentiment_report"] = chunk["sentiment_report"]
|
||||||
report_parts.append("## Trading Team Plan")
|
|
||||||
report_parts.append(f"{final_state['trader_investment_plan']}")
|
if "news_report" in chunk and chunk["news_report"]:
|
||||||
|
updates["news_report"] = chunk["news_report"]
|
||||||
# Portfolio Management Decision
|
|
||||||
if final_state.get("risk_debate_state") and final_state["risk_debate_state"].get("judge_decision"):
|
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
|
||||||
report_parts.append("## Portfolio Management Decision")
|
updates["fundamentals_report"] = chunk["fundamentals_report"]
|
||||||
report_parts.append(f"{final_state['risk_debate_state']['judge_decision']}")
|
|
||||||
|
# 팀별 의사결정 과정 업데이트
|
||||||
return "\n\n".join(report_parts) if report_parts else "No analysis results available."
|
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
|
||||||
|
updates["investment_debate_state"] = chunk["investment_debate_state"]
|
||||||
|
|
||||||
|
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
|
||||||
|
updates["trader_investment_plan"] = chunk["trader_investment_plan"]
|
||||||
|
|
||||||
|
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
|
||||||
|
updates["risk_debate_state"] = chunk["risk_debate_state"]
|
||||||
|
|
||||||
|
# 업데이트가 있는 경우 저장
|
||||||
|
if updates:
|
||||||
|
# analysis_id를 포함한 AnalysisVO 객체 생성
|
||||||
|
updates["id"] = analysis_id
|
||||||
|
updates_vo = AnalysisVO(**updates)
|
||||||
|
self.analysis_repo.update(updates_vo)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
def _generate_final_report(self, final_state: dict) -> str:
|
||||||
|
"""최종 통합 보고서를 생성하는 메서드"""
|
||||||
|
report_parts = []
|
||||||
|
|
||||||
|
# Analyst Team Reports
|
||||||
|
if any(final_state.get(section) for section in ["market_report", "sentiment_report", "news_report", "fundamentals_report"]):
|
||||||
|
report_parts.append("## Analyst Team Reports")
|
||||||
|
|
||||||
|
if final_state.get("market_report"):
|
||||||
|
report_parts.append(f"### Market Analysis\n{final_state['market_report']}")
|
||||||
|
if final_state.get("sentiment_report"):
|
||||||
|
report_parts.append(f"### Social Sentiment\n{final_state['sentiment_report']}")
|
||||||
|
if final_state.get("news_report"):
|
||||||
|
report_parts.append(f"### News Analysis\n{final_state['news_report']}")
|
||||||
|
if final_state.get("fundamentals_report"):
|
||||||
|
report_parts.append(f"### Fundamentals Analysis\n{final_state['fundamentals_report']}")
|
||||||
|
|
||||||
|
# Research Team Reports
|
||||||
|
if final_state.get("investment_debate_state"):
|
||||||
|
report_parts.append("## Research Team Decision")
|
||||||
|
debate_state = final_state["investment_debate_state"]
|
||||||
|
if debate_state.get("judge_decision"):
|
||||||
|
report_parts.append(f"{debate_state['judge_decision']}")
|
||||||
|
|
||||||
|
# Trading Team Reports
|
||||||
|
if final_state.get("trader_investment_plan"):
|
||||||
|
report_parts.append("## Trading Team Plan")
|
||||||
|
report_parts.append(f"{final_state['trader_investment_plan']}")
|
||||||
|
|
||||||
|
# Portfolio Management Decision
|
||||||
|
if final_state.get("risk_debate_state") and final_state["risk_debate_state"].get("judge_decision"):
|
||||||
|
report_parts.append("## Portfolio Management Decision")
|
||||||
|
report_parts.append(f"{final_state['risk_debate_state']['judge_decision']}")
|
||||||
|
|
||||||
|
return "\n\n".join(report_parts) if report_parts else "No analysis results available."
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,65 @@
|
||||||
|
from typing import Dict, Set
|
||||||
|
from fastapi import WebSocket
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketManager:
|
||||||
|
def __init__(self):
|
||||||
|
# Store active connections by member_id
|
||||||
|
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
||||||
|
# Store analysis_id to member_id mapping
|
||||||
|
self.analysis_member_map: Dict[str, str] = {}
|
||||||
|
|
||||||
|
async def connect(self, websocket: WebSocket, member_id: str):
|
||||||
|
await websocket.accept()
|
||||||
|
if member_id not in self.active_connections:
|
||||||
|
self.active_connections[member_id] = set()
|
||||||
|
self.active_connections[member_id].add(websocket)
|
||||||
|
|
||||||
|
def disconnect(self, websocket: WebSocket, member_id: str):
|
||||||
|
if member_id in self.active_connections:
|
||||||
|
self.active_connections[member_id].discard(websocket)
|
||||||
|
if not self.active_connections[member_id]:
|
||||||
|
del self.active_connections[member_id]
|
||||||
|
|
||||||
|
def register_analysis(self, analysis_id: str, member_id: str):
|
||||||
|
"""Register which member owns which analysis"""
|
||||||
|
self.analysis_member_map[analysis_id] = member_id
|
||||||
|
|
||||||
|
async def send_analysis_update(self, analysis_id: str, update_type: str, data: dict):
|
||||||
|
"""Send analysis update to the member who owns the analysis"""
|
||||||
|
member_id = self.analysis_member_map.get(analysis_id)
|
||||||
|
if not member_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
message = {
|
||||||
|
"type": "analysis_update",
|
||||||
|
"analysis_id": analysis_id,
|
||||||
|
"update_type": update_type,
|
||||||
|
"data": data,
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.send_to_member(member_id, message)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def send_to_member(self, member_id: str, message: dict|str):
|
||||||
|
"""Send message to all connections of a specific member"""
|
||||||
|
if member_id not in self.active_connections:
|
||||||
|
return
|
||||||
|
|
||||||
|
dead_connections = set()
|
||||||
|
for connection in self.active_connections[member_id]:
|
||||||
|
try:
|
||||||
|
if isinstance(message, dict):
|
||||||
|
await connection.send_json(message)
|
||||||
|
else:
|
||||||
|
await connection.send_text(message)
|
||||||
|
except Exception:
|
||||||
|
dead_connections.add(connection)
|
||||||
|
|
||||||
|
# Clean up dead connections
|
||||||
|
for connection in dead_connections:
|
||||||
|
self.disconnect(connection, member_id)
|
||||||
|
|
@ -1,19 +1,20 @@
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, field_validator
|
||||||
from datetime import datetime
|
from datetime import datetime, date
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Union
|
||||||
|
from analysis.infra.db_models.analysis import AnalysisStatus
|
||||||
|
|
||||||
class Analysis(BaseModel):
|
class Analysis(BaseModel):
|
||||||
id: str | None = None
|
id: str
|
||||||
member_id: str
|
member_id: str | None = None
|
||||||
ticker: str
|
ticker: str | None = None
|
||||||
analysis_date: str
|
analysis_date: date | None = None
|
||||||
analysts_selected: List[str] = []
|
analysts_selected: list[str] = []
|
||||||
research_depth: int = 3
|
research_depth: int = 3
|
||||||
llm_provider: str = "openai"
|
llm_provider: str = "openai"
|
||||||
backend_url: str = "https://api.openai.com/v1"
|
backend_url: str = "https://api.openai.com/v1"
|
||||||
shallow_thinker: str = "gpt-4o-mini"
|
shallow_thinker: str = "gpt-4o"
|
||||||
deep_thinker: str = "gpt-4o"
|
deep_thinker: str = "o3"
|
||||||
status: str
|
status: AnalysisStatus = AnalysisStatus.PENDING
|
||||||
|
|
||||||
# 개별 분석가 리포트들
|
# 개별 분석가 리포트들
|
||||||
market_report: str | None = None
|
market_report: str | None = None
|
||||||
|
|
@ -33,5 +34,5 @@ class Analysis(BaseModel):
|
||||||
# 실행 결과 정보
|
# 실행 결과 정보
|
||||||
error_message: str | None = None
|
error_message: str | None = None
|
||||||
completed_at: datetime | None = None
|
completed_at: datetime | None = None
|
||||||
created_at: datetime
|
created_at: datetime | None = None
|
||||||
updated_at: datetime
|
updated_at: datetime | None = None
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||||
from analysis.interface.dto import TradingAnalysisRequest
|
from analysis.interface.dto import TradingAnalysisRequest
|
||||||
|
|
||||||
class IAnalysisRepository(ABC):
|
class IAnalysisRepository(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
|
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
|
def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(self, analysis: AnalysisVO) -> AnalysisVO | None:
|
def update(self, analysis: AnalysisVO) -> AnalysisVO | None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(self, analysis: AnalysisVO) -> AnalysisVO:
|
def save(self, analysis: AnalysisVO) -> AnalysisVO:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
||||||
|
|
@ -1,57 +1,57 @@
|
||||||
from datetime import datetime,date
|
from datetime import datetime,date
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from sqlmodel import SQLModel, Field, JSON, Relationship
|
from sqlmodel import SQLModel, Field, JSON, Relationship
|
||||||
import enum
|
import enum
|
||||||
from sqlalchemy import Column
|
from sqlalchemy import Column, Text
|
||||||
|
|
||||||
# TYPE_CHECKING을 사용해서 circular import 방지
|
# TYPE_CHECKING을 사용해서 circular import 방지
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from member.infra.db_models.member import Member
|
from member.infra.db_models.member import Member
|
||||||
|
|
||||||
class AnalysisStatus(str, enum.Enum):
|
class AnalysisStatus(str, enum.Enum):
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
CANCELLED = "cancelled"
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
class Analysis(SQLModel, table=True):
|
class Analysis(SQLModel, table=True):
|
||||||
__tablename__ = "analyses"
|
__tablename__ = "analyses"
|
||||||
id: str = Field(default=None, max_length=36, primary_key=True)
|
id: str = Field(default=None, max_length=36, primary_key=True)
|
||||||
|
|
||||||
# 기본 분석 설정 정보
|
# 기본 분석 설정 정보
|
||||||
ticker: str
|
ticker: str
|
||||||
analysis_date: date
|
analysis_date: date
|
||||||
analysts_selected: list[str] = Field(sa_column=Column(JSON))
|
analysts_selected: list[str] = Field(sa_column=Column(JSON))
|
||||||
research_depth: int
|
research_depth: int
|
||||||
llm_provider: str
|
llm_provider: str
|
||||||
backend_url: str
|
backend_url: str
|
||||||
shallow_thinker: str
|
shallow_thinker: str
|
||||||
deep_thinker: str
|
deep_thinker: str
|
||||||
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
|
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
|
||||||
|
|
||||||
# 개별 분석가 리포트들
|
# 개별 분석가 리포트들
|
||||||
market_report: str | None = Field(default=None, description="Market Analyst 리포트")
|
market_report: str | None = Field(default=None, sa_column=Column(Text), description="Market Analyst 리포트")
|
||||||
sentiment_report: str | None = Field(default=None, description="Social Analyst 리포트")
|
sentiment_report: str | None = Field(default=None, sa_column=Column(Text), description="Social Analyst 리포트")
|
||||||
news_report: str | None = Field(default=None, description="News Analyst 리포트")
|
news_report: str | None = Field(default=None, sa_column=Column(Text), description="News Analyst 리포트")
|
||||||
fundamentals_report: str | None = Field(default=None, description="Fundamentals Analyst 리포트")
|
fundamentals_report: str | None = Field(default=None, sa_column=Column(Text), description="Fundamentals Analyst 리포트")
|
||||||
|
|
||||||
# 팀별 의사결정 과정
|
# 팀별 의사결정 과정
|
||||||
investment_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Research Team 토론 과정")
|
investment_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Research Team 토론 과정")
|
||||||
trader_investment_plan: str | None = Field(default=None, description="Trading Team 계획")
|
trader_investment_plan: str | None = Field(default=None, sa_column=Column(Text), description="Trading Team 계획")
|
||||||
risk_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Risk Management Team 토론 과정")
|
risk_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Risk Management Team 토론 과정")
|
||||||
|
|
||||||
# 최종 결과물
|
# 최종 결과물
|
||||||
final_trade_decision: str | None = Field(default=None, description="최종 거래 결정")
|
final_trade_decision: str | None = Field(default=None, sa_column=Column(Text), description="최종 거래 결정")
|
||||||
final_report: str | None = Field(default=None, description="전체 통합 리포트")
|
final_report: str | None = Field(default=None, sa_column=Column(Text), description="전체 통합 리포트")
|
||||||
|
|
||||||
# 실행 결과 정보
|
# 실행 결과 정보
|
||||||
error_message: str | None = None
|
error_message: str | None = Field(default=None, sa_column=Column(Text))
|
||||||
completed_at: datetime | None = None
|
completed_at: datetime | None = None
|
||||||
created_at : datetime = Field(nullable=False)
|
created_at : datetime = Field(nullable=False)
|
||||||
updated_at : datetime = Field(nullable=False)
|
updated_at : datetime = Field(nullable=False)
|
||||||
|
|
||||||
# Foreign Key와 Relationship 설정
|
# Foreign Key와 Relationship 설정
|
||||||
member_id: str = Field(foreign_key="members.id")
|
member_id: str = Field(foreign_key="members.id")
|
||||||
member: "Member" = Relationship(back_populates="analyses")
|
member: "Member" = Relationship(back_populates="analyses")
|
||||||
|
|
@ -1,80 +1,55 @@
|
||||||
from analysis.domain.repository.analysis_repo import IAnalysisRepository
|
from analysis.domain.repository.analysis_repo import IAnalysisRepository
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||||
from analysis.infra.db_models.analysis import Analysis, AnalysisStatus
|
from analysis.infra.db_models.analysis import Analysis, AnalysisStatus
|
||||||
from analysis.interface.dto import TradingAnalysisRequest
|
from analysis.interface.dto import TradingAnalysisRequest
|
||||||
from utils.db_utils import row_to_dict
|
from utils.db_utils import row_to_dict
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
from datetime import datetime, date
|
from datetime import datetime, date
|
||||||
|
|
||||||
class AnalysisRepository(IAnalysisRepository):
|
class AnalysisRepository(IAnalysisRepository):
|
||||||
def __init__(self, session: Session):
|
def __init__(self, session: Session):
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
||||||
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
|
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
|
||||||
query = select(Analysis).where(Analysis.member_id == member_id)
|
query = select(Analysis).where(Analysis.member_id == member_id)
|
||||||
analyses = self.session.exec(query).all()
|
analyses = self.session.exec(query).all()
|
||||||
|
|
||||||
if not analyses:
|
if not analyses:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return [AnalysisVO(**row_to_dict(analysis)) for analysis in analyses]
|
return [AnalysisVO(**row_to_dict(analysis)) for analysis in analyses]
|
||||||
|
|
||||||
def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
|
def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
|
||||||
analysis = self.session.get(Analysis, analysis_id)
|
analysis = self.session.get(Analysis, analysis_id)
|
||||||
if not analysis:
|
if not analysis:
|
||||||
return None
|
return None
|
||||||
return AnalysisVO(**row_to_dict(analysis))
|
return AnalysisVO(**row_to_dict(analysis))
|
||||||
|
|
||||||
def save(self, analysis: AnalysisVO) -> AnalysisVO:
|
def save(self, analysis: AnalysisVO) -> AnalysisVO:
|
||||||
new_analysis = Analysis(
|
new_analysis = Analysis(
|
||||||
id=analysis.id,
|
**analysis.model_dump()
|
||||||
member_id=analysis.member_id,
|
)
|
||||||
ticker=analysis.ticker,
|
|
||||||
analysis_date=date.fromisoformat(analysis.analysis_date),
|
self.session.add(new_analysis)
|
||||||
analysts_selected=analysis.analysts_selected,
|
self.session.flush()
|
||||||
research_depth=analysis.research_depth,
|
self.session.refresh(new_analysis)
|
||||||
llm_provider=analysis.llm_provider,
|
|
||||||
backend_url=analysis.backend_url,
|
analysis.id = new_analysis.id
|
||||||
shallow_thinker=analysis.shallow_thinker,
|
return analysis
|
||||||
deep_thinker=analysis.deep_thinker,
|
|
||||||
status=analysis.status,
|
def update(self, analysis_vo: AnalysisVO) -> AnalysisVO | None:
|
||||||
market_report=analysis.market_report,
|
analysis = self.session.get(Analysis, analysis_vo.id)
|
||||||
sentiment_report=analysis.sentiment_report,
|
if not analysis:
|
||||||
news_report=analysis.news_report,
|
return None
|
||||||
fundamentals_report=analysis.fundamentals_report,
|
|
||||||
investment_debate_state=analysis.investment_debate_state,
|
# AnalysisVO의 데이터를 SQLModel 객체에 업데이트
|
||||||
trader_investment_plan=analysis.trader_investment_plan,
|
analysis_data = analysis_vo.model_dump(exclude_unset=True)
|
||||||
risk_debate_state=analysis.risk_debate_state,
|
|
||||||
final_trade_decision=analysis.final_trade_decision,
|
analysis.updated_at = datetime.now()
|
||||||
final_report=analysis.final_report,
|
analysis.sqlmodel_update(analysis_data)
|
||||||
error_message=analysis.error_message,
|
|
||||||
completed_at=analysis.completed_at,
|
self.session.flush()
|
||||||
created_at=analysis.created_at,
|
|
||||||
updated_at=analysis.updated_at
|
|
||||||
)
|
return AnalysisVO(**row_to_dict(analysis))
|
||||||
|
|
||||||
self.session.add(new_analysis)
|
|
||||||
self.session.flush()
|
|
||||||
self.session.refresh(new_analysis)
|
|
||||||
|
|
||||||
analysis.id = new_analysis.id
|
|
||||||
return analysis
|
|
||||||
|
|
||||||
def update(self, analysis_vo: AnalysisVO) -> AnalysisVO | None:
|
|
||||||
analysis = self.session.get(Analysis, analysis_vo.id)
|
|
||||||
if not analysis:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# AnalysisVO의 데이터를 SQLModel 객체에 업데이트
|
|
||||||
vo_data = analysis_vo.sqlmodel_dump(exclude_unset=True)
|
|
||||||
for key, value in vo_data.items():
|
|
||||||
if hasattr(analysis, key) and key != 'id': # id는 변경하지 않음
|
|
||||||
setattr(analysis, key, value)
|
|
||||||
|
|
||||||
analysis.updated_at = datetime.now()
|
|
||||||
self.session.add(analysis)
|
|
||||||
self.session.flush()
|
|
||||||
self.session.refresh(analysis)
|
|
||||||
|
|
||||||
return AnalysisVO(**row_to_dict(analysis))
|
|
||||||
|
|
|
||||||
|
|
@ -1,108 +1,137 @@
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException, status
|
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException, status, WebSocket, WebSocketDisconnect
|
||||||
from analysis.interface.dto import (
|
from analysis.interface.dto import (
|
||||||
AnalysisSessionResponse,
|
AnalysisSessionResponse,
|
||||||
TradingAnalysisRequest,
|
TradingAnalysisRequest,
|
||||||
AnalysisResultResponse
|
AnalysisResultResponse
|
||||||
)
|
)
|
||||||
from utils.auth import get_current_member, CurrentMember
|
from utils.auth import get_current_member, CurrentMember
|
||||||
from dependency_injector.wiring import inject, Provide
|
from dependency_injector.wiring import inject, Provide
|
||||||
from analysis.application.analysis_service import AnalysisService
|
from analysis.application.analysis_service import AnalysisService
|
||||||
from utils.containers import Container
|
from utils.containers import Container
|
||||||
|
from analysis.application.websocket_manager import WebSocketManager
|
||||||
router = APIRouter(prefix="/analysis", tags=["analysis"])
|
|
||||||
|
router = APIRouter(prefix="/analysis", tags=["analysis"])
|
||||||
@router.get("/", response_model=list[AnalysisSessionResponse])
|
|
||||||
@inject
|
@router.get("/", response_model=list[AnalysisSessionResponse])
|
||||||
def get_analysis_list_for_member(
|
@inject
|
||||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
def get_analysis_list_for_member(
|
||||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||||
):
|
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
||||||
"""
|
):
|
||||||
현재 로그인한 사용자의 모든 분석 세션 목록을 조회합니다.
|
"""
|
||||||
"""
|
현재 로그인한 사용자의 모든 분석 세션 목록을 조회합니다.
|
||||||
analyses = analysis_service.get_analysis_list(current_member.id)
|
"""
|
||||||
return [
|
analyses = analysis_service.get_analysis_list(current_member.id)
|
||||||
AnalysisSessionResponse(
|
return [
|
||||||
id=analysis.id,
|
AnalysisSessionResponse(
|
||||||
ticker=analysis.ticker,
|
id=analysis.id,
|
||||||
status=analysis.status
|
ticker=analysis.ticker,
|
||||||
) for analysis in analyses
|
status=analysis.status
|
||||||
]
|
) for analysis in analyses
|
||||||
|
]
|
||||||
@router.post("/start", status_code=201, response_model=AnalysisSessionResponse)
|
|
||||||
@inject
|
@router.post("/start", status_code=201, response_model=AnalysisSessionResponse)
|
||||||
def start_analysis_session(
|
@inject
|
||||||
request: TradingAnalysisRequest,
|
def start_analysis_session(
|
||||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
request: TradingAnalysisRequest,
|
||||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])],
|
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||||
background_tasks: BackgroundTasks
|
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])],
|
||||||
):
|
background_tasks: BackgroundTasks
|
||||||
"""
|
):
|
||||||
새로운 분석 세션을 시작합니다.
|
"""
|
||||||
"""
|
새로운 분석 세션을 시작합니다.
|
||||||
try:
|
|
||||||
new_analysis = analysis_service.create_analysis(current_member.id, request, background_tasks)
|
"""
|
||||||
return AnalysisSessionResponse(
|
try:
|
||||||
id=new_analysis.id,
|
new_analysis = analysis_service.create_analysis(current_member.id, request, background_tasks)
|
||||||
ticker=new_analysis.ticker,
|
return AnalysisSessionResponse(
|
||||||
status=new_analysis.status
|
id=new_analysis.id,
|
||||||
)
|
ticker=new_analysis.ticker,
|
||||||
except Exception as e:
|
status=new_analysis.status
|
||||||
raise HTTPException(
|
)
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
except Exception as e:
|
||||||
detail=f"Failed to start analysis: {str(e)}"
|
raise HTTPException(
|
||||||
)
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to start analysis: {str(e)}"
|
||||||
@router.get("/{analysis_id}", response_model=AnalysisResultResponse)
|
)
|
||||||
@inject
|
|
||||||
def get_analysis_result(
|
@router.get("/{analysis_id}", response_model=AnalysisResultResponse)
|
||||||
analysis_id: str,
|
@inject
|
||||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
def get_analysis_result(
|
||||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
analysis_id: str,
|
||||||
):
|
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||||
"""
|
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
||||||
특정 분석 세션의 결과를 조회합니다.
|
):
|
||||||
"""
|
"""
|
||||||
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
|
특정 분석 세션의 결과를 조회합니다.
|
||||||
|
"""
|
||||||
return AnalysisResultResponse(
|
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
|
||||||
id=analysis.id,
|
|
||||||
ticker=analysis.ticker,
|
return AnalysisResultResponse(
|
||||||
analysis_date=analysis.analysis_date,
|
id=analysis.id,
|
||||||
status=analysis.status,
|
ticker=analysis.ticker,
|
||||||
market_report=analysis.market_report,
|
analysis_date=analysis.analysis_date.isoformat() if hasattr(analysis.analysis_date, 'isoformat') else str(analysis.analysis_date),
|
||||||
sentiment_report=analysis.sentiment_report,
|
status=analysis.status,
|
||||||
news_report=analysis.news_report,
|
market_report=analysis.market_report,
|
||||||
fundamentals_report=analysis.fundamentals_report,
|
sentiment_report=analysis.sentiment_report,
|
||||||
investment_debate_state=analysis.investment_debate_state,
|
news_report=analysis.news_report,
|
||||||
trader_investment_plan=analysis.trader_investment_plan,
|
fundamentals_report=analysis.fundamentals_report,
|
||||||
risk_debate_state=analysis.risk_debate_state,
|
investment_debate_state=analysis.investment_debate_state,
|
||||||
final_trade_decision=analysis.final_trade_decision,
|
trader_investment_plan=analysis.trader_investment_plan,
|
||||||
final_report=analysis.final_report,
|
risk_debate_state=analysis.risk_debate_state,
|
||||||
created_at=analysis.created_at.isoformat(),
|
final_trade_decision=analysis.final_trade_decision,
|
||||||
completed_at=analysis.completed_at.isoformat() if analysis.completed_at else None,
|
final_report=analysis.final_report,
|
||||||
error_message=analysis.error_message
|
created_at=analysis.created_at.isoformat(),
|
||||||
)
|
completed_at=analysis.completed_at.isoformat() if analysis.completed_at else None,
|
||||||
|
error_message=analysis.error_message
|
||||||
@router.get("/{analysis_id}/status")
|
)
|
||||||
@inject
|
|
||||||
def get_analysis_status(
|
@router.get("/{analysis_id}/status")
|
||||||
analysis_id: str,
|
@inject
|
||||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
def get_analysis_status(
|
||||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
analysis_id: str,
|
||||||
):
|
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||||
"""
|
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
||||||
분석 진행 상황을 조회합니다.
|
):
|
||||||
"""
|
"""
|
||||||
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
|
분석 진행 상황을 조회합니다.
|
||||||
|
"""
|
||||||
return {
|
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
|
||||||
"analysis_id": analysis.id,
|
|
||||||
"status": analysis.status,
|
return {
|
||||||
"ticker": analysis.ticker,
|
"analysis_id": analysis.id,
|
||||||
"analysis_date": analysis.analysis_date,
|
"status": analysis.status,
|
||||||
"created_at": analysis.created_at.isoformat(),
|
"ticker": analysis.ticker,
|
||||||
"updated_at": analysis.updated_at.isoformat(),
|
"analysis_date": analysis.analysis_date,
|
||||||
"error_message": analysis.error_message
|
"created_at": analysis.created_at.isoformat(),
|
||||||
}
|
"updated_at": analysis.updated_at.isoformat(),
|
||||||
|
"error_message": analysis.error_message
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.websocket("/ws")
|
||||||
|
@inject
|
||||||
|
async def websocket_endpoint(
|
||||||
|
websocket: WebSocket,
|
||||||
|
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||||
|
websocket_manager: Annotated[WebSocketManager, Depends(Provide[Container.websocket_manager])]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
WebSocket endpoint for real-time analysis updates
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Connect the websocket
|
||||||
|
await websocket_manager.connect(websocket, current_member.id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Keep connection alive
|
||||||
|
while True:
|
||||||
|
# Wait for messages from client (like ping/pong)
|
||||||
|
data = await websocket.receive_text()
|
||||||
|
# Echo back for heartbeat
|
||||||
|
if data == "ping":
|
||||||
|
await websocket.send_text("pong")
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
websocket_manager.disconnect(websocket, current_member.id)
|
||||||
|
except Exception as e:
|
||||||
|
await websocket.close(code=1011, reason=str(e))
|
||||||
|
|
|
||||||
|
|
@ -1,52 +1,52 @@
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import List
|
from typing import List
|
||||||
from analysis.infra.db_models.analysis import AnalysisStatus
|
from analysis.infra.db_models.analysis import AnalysisStatus
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
class AnalystType(str, Enum):
|
class AnalystType(str, Enum):
|
||||||
MARKET = "market"
|
MARKET = "market"
|
||||||
SOCIAL = "social"
|
SOCIAL = "social"
|
||||||
NEWS = "news"
|
NEWS = "news"
|
||||||
FUNDAMENTALS = "fundamentals"
|
FUNDAMENTALS = "fundamentals"
|
||||||
|
|
||||||
class TradingAnalysisRequest(BaseModel):
|
class TradingAnalysisRequest(BaseModel):
|
||||||
ticker: str
|
ticker: str = "NVDA"
|
||||||
analysis_date: str
|
analysis_date: str = "2025-07-07"
|
||||||
analysts: List[AnalystType]
|
analysts: List[AnalystType] = [AnalystType.MARKET, AnalystType.SOCIAL, AnalystType.NEWS, AnalystType.FUNDAMENTALS]
|
||||||
research_depth: int = 3
|
research_depth: int = 3
|
||||||
llm_provider: str = "openai"
|
llm_provider: str = "openai"
|
||||||
backend_url: str = "https://api.openai.com/v1"
|
backend_url: str = "https://api.openai.com/v1"
|
||||||
shallow_thinker: str = "gpt-4o-mini"
|
shallow_thinker: str = "gpt-4o-mini"
|
||||||
deep_thinker: str = "gpt-4o"
|
deep_thinker: str = "gpt-4o-mini"
|
||||||
|
|
||||||
class AnalysisSessionResponse(BaseModel):
|
class AnalysisSessionResponse(BaseModel):
|
||||||
id : str
|
id : str
|
||||||
ticker : str
|
ticker : str
|
||||||
status : AnalysisStatus
|
status : AnalysisStatus
|
||||||
|
|
||||||
class AnalysisProgressUpdate(BaseModel):
|
class AnalysisProgressUpdate(BaseModel):
|
||||||
analysis_id: str
|
analysis_id: str
|
||||||
current_agent: str
|
current_agent: str
|
||||||
status: str
|
status: str
|
||||||
progress_percentage: float
|
progress_percentage: float
|
||||||
current_report_section: str | None = None
|
current_report_section: str | None = None
|
||||||
message: str | None = None
|
message: str | None = None
|
||||||
|
|
||||||
class AnalysisResultResponse(BaseModel):
|
class AnalysisResultResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
ticker: str
|
ticker: str
|
||||||
analysis_date: str
|
analysis_date: str
|
||||||
status: AnalysisStatus
|
status: AnalysisStatus
|
||||||
market_report: str | None = None
|
market_report: str | None = None
|
||||||
sentiment_report: str | None = None
|
sentiment_report: str | None = None
|
||||||
news_report: str | None = None
|
news_report: str | None = None
|
||||||
fundamentals_report: str | None = None
|
fundamentals_report: str | None = None
|
||||||
investment_debate_state: dict | None = None
|
investment_debate_state: dict | None = None
|
||||||
trader_investment_plan: str | None = None
|
trader_investment_plan: str | None = None
|
||||||
risk_debate_state: dict | None = None
|
risk_debate_state: dict | None = None
|
||||||
final_trade_decision: str | None = None
|
final_trade_decision: str | None = None
|
||||||
final_report: str | None = None
|
final_report: str | None = None
|
||||||
created_at: str
|
created_at: str
|
||||||
completed_at: str | None = None
|
completed_at: str | None = None
|
||||||
error_message: str | None = None
|
error_message: str | None = None
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=".env",
|
env_file=".env",
|
||||||
env_file_encoding="utf-8",
|
env_file_encoding="utf-8",
|
||||||
)
|
)
|
||||||
|
|
||||||
# MySQL 데이터베이스 설정
|
# MySQL 데이터베이스 설정
|
||||||
DB_HOST: str
|
DB_HOST: str
|
||||||
DB_PORT: int
|
DB_PORT: int
|
||||||
DB_USER: str
|
DB_USER: str
|
||||||
DB_PASSWORD: str
|
DB_PASSWORD: str
|
||||||
DB_NAME: str
|
DB_NAME: str
|
||||||
SECRET_KEY: str
|
SECRET_KEY: str
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def get_settings():
|
def get_settings():
|
||||||
return Settings()
|
return Settings()
|
||||||
|
|
@ -1,20 +1,30 @@
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from utils.database import create_db_and_tables
|
from utils.database import create_db_and_tables
|
||||||
from utils.containers import Container
|
from utils.containers import Container
|
||||||
|
|
||||||
|
|
||||||
from analysis.interface.controller.analysis_controller import router as analysis_router
|
from analysis.interface.controller.analysis_controller import router as analysis_router
|
||||||
from member.interface.controller.member_controller import router as member_router
|
from member.interface.controller.member_controller import router as member_router
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 로깅 설정
|
||||||
app = FastAPI()
|
logging.basicConfig(
|
||||||
app.container = Container()
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
app.include_router(analysis_router)
|
handlers=[
|
||||||
app.include_router(member_router)
|
logging.StreamHandler(), # 콘솔 출력
|
||||||
|
]
|
||||||
|
)
|
||||||
@app.on_event("startup")
|
|
||||||
def startup_db_client():
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.container = Container()
|
||||||
|
|
||||||
|
app.include_router(analysis_router)
|
||||||
|
app.include_router(member_router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def startup_db_client():
|
||||||
create_db_and_tables()
|
create_db_and_tables()
|
||||||
|
|
@ -1,97 +1,97 @@
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
from utils.crypto import Crypto
|
from utils.crypto import Crypto
|
||||||
from member.domain.repository.member_repo import IMemberRepository
|
from member.domain.repository.member_repo import IMemberRepository
|
||||||
from utils.auth import Role
|
from utils.auth import Role
|
||||||
from member.domain.member import Member as MemberVO
|
from member.domain.member import Member as MemberVO
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from utils.auth import create_access_token
|
from utils.auth import create_access_token
|
||||||
from ulid import ULID
|
from ulid import ULID
|
||||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||||
|
|
||||||
class MemberService:
|
class MemberService:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
member_repo: IMemberRepository,
|
member_repo: IMemberRepository,
|
||||||
crypto: Crypto,
|
crypto: Crypto,
|
||||||
db_session: Session,
|
session: Session,
|
||||||
ulid: ULID
|
ulid: ULID
|
||||||
):
|
):
|
||||||
self.member_repo = member_repo
|
self.member_repo = member_repo
|
||||||
self.crypto = crypto
|
self.crypto = crypto
|
||||||
self.db_session = db_session
|
self.db_session = session
|
||||||
self.ulid = ulid
|
self.ulid = ulid
|
||||||
|
|
||||||
def create_member(
|
def create_member(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
email: str,
|
email: str,
|
||||||
password: str,
|
password: str,
|
||||||
role: Role
|
role: Role
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if self.member_repo.find_by_email(email):
|
if self.member_repo.find_by_email(email):
|
||||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already exists")
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already exists")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.db_session.rollback()
|
self.db_session.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
member_vo = MemberVO(
|
member_vo = MemberVO(
|
||||||
id=self.ulid.generate(),
|
id=self.ulid.generate(),
|
||||||
name=name,
|
name=name,
|
||||||
email=email,
|
email=email,
|
||||||
password=self.crypto.encrypt(password),
|
password=self.crypto.encrypt(password),
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
role=role
|
role=role
|
||||||
)
|
)
|
||||||
|
|
||||||
saved_member = self.member_repo.save(member_vo)
|
saved_member = self.member_repo.save(member_vo)
|
||||||
self.db_session.commit()
|
self.db_session.commit()
|
||||||
|
|
||||||
return saved_member
|
return saved_member
|
||||||
|
|
||||||
|
|
||||||
def get_members(
|
def get_members(
|
||||||
self,
|
self,
|
||||||
page: int,
|
page: int,
|
||||||
items_per_page: int
|
items_per_page: int
|
||||||
)->tuple[int, list[MemberVO]] :
|
)->tuple[int, list[MemberVO]] :
|
||||||
return self.member_repo.get_members(page, items_per_page)
|
return self.member_repo.get_members(page, items_per_page)
|
||||||
|
|
||||||
def get_member(
|
def get_member(
|
||||||
self,
|
self,
|
||||||
id: str
|
id: str
|
||||||
)->MemberVO | None:
|
)->MemberVO | None:
|
||||||
member = self.member_repo.find_by_id(id)
|
member = self.member_repo.find_by_id(id)
|
||||||
if not member:
|
if not member:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
|
||||||
return member
|
return member
|
||||||
|
|
||||||
def login(
|
def login(
|
||||||
self,
|
self,
|
||||||
email: str,
|
email: str,
|
||||||
password: str
|
password: str
|
||||||
):
|
):
|
||||||
member = self.member_repo.find_by_email(email)
|
member = self.member_repo.find_by_email(email)
|
||||||
if not member:
|
if not member:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||||
|
|
||||||
if not self.crypto.verify(password, member.password):
|
if not self.crypto.verify(password, member.password):
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||||
|
|
||||||
access_token = create_access_token(
|
access_token = create_access_token(
|
||||||
payload={"member_id": member.id, "role": member.role},
|
payload={"member_id": member.id, "role": member.role},
|
||||||
role=member.role,
|
role=member.role,
|
||||||
)
|
)
|
||||||
|
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
def get_analysis_sessions_by_member(
|
def get_analysis_sessions_by_member(
|
||||||
self,
|
self,
|
||||||
member_id: str
|
member_id: str
|
||||||
)->list[AnalysisVO]:
|
)->list[AnalysisVO]:
|
||||||
analysis_sessions = self.member_repo.find_analysis_sessions_by_member(member_id)
|
analysis_sessions = self.member_repo.find_analysis_sessions_by_member(member_id)
|
||||||
return analysis_sessions
|
return analysis_sessions
|
||||||
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from utils.auth import Role
|
from utils.auth import Role
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
class Member(BaseModel):
|
class Member(BaseModel):
|
||||||
id: str | None = None
|
id: str | None = None
|
||||||
name: str
|
name: str
|
||||||
email: str
|
email: str
|
||||||
password: str
|
password: str
|
||||||
role: Role
|
role: Role
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
class IMemberRepository(ABC):
|
class IMemberRepository(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
@ -1,24 +1,24 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from member.domain.member import Member as MemberVO
|
from member.domain.member import Member as MemberVO
|
||||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||||
|
|
||||||
class IMemberRepository(ABC):
|
class IMemberRepository(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def find_by_email(self, email: str) -> MemberVO | None:
|
def find_by_email(self, email: str) -> MemberVO | None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(self, member: MemberVO) -> MemberVO:
|
def save(self, member: MemberVO) -> MemberVO:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def find_by_id(self, id: str) -> MemberVO | None:
|
def find_by_id(self, id: str) -> MemberVO | None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
|
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def find_analysis_sessions_by_member(self, member_id: str) -> list[AnalysisVO]:
|
def find_analysis_sessions_by_member(self, member_id: str) -> list[AnalysisVO]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
@ -1,66 +1,66 @@
|
||||||
from member.domain.repository import IMemberRepository
|
from member.domain.repository import IMemberRepository
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from member.domain.member import Member as MemberVO
|
from member.domain.member import Member as MemberVO
|
||||||
from member.infra.db_models.member import Member
|
from member.infra.db_models.member import Member
|
||||||
from utils.db_utils import row_to_dict
|
from utils.db_utils import row_to_dict
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
class MemberRepository(IMemberRepository):
|
class MemberRepository(IMemberRepository):
|
||||||
def __init__(self, session: Session):
|
def __init__(self, session: Session):
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
||||||
def find_by_email(self, email: str) -> MemberVO | None:
|
def find_by_email(self, email: str) -> MemberVO | None:
|
||||||
query = select(Member).where(Member.email == email)
|
query = select(Member).where(Member.email == email)
|
||||||
member = self.session.exec(query).first()
|
member = self.session.exec(query).first()
|
||||||
|
|
||||||
if not member:
|
if not member:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return MemberVO(**row_to_dict(member))
|
return MemberVO(**row_to_dict(member))
|
||||||
|
|
||||||
def save(self, member: MemberVO) -> MemberVO:
|
def save(self, member: MemberVO) -> MemberVO:
|
||||||
new_member = Member(
|
new_member = Member(
|
||||||
id=member.id,
|
id=member.id,
|
||||||
email=member.email,
|
email=member.email,
|
||||||
name=member.name,
|
name=member.name,
|
||||||
password=member.password,
|
password=member.password,
|
||||||
role=member.role,
|
role=member.role,
|
||||||
created_at=member.created_at,
|
created_at=member.created_at,
|
||||||
updated_at=member.updated_at
|
updated_at=member.updated_at
|
||||||
)
|
)
|
||||||
|
|
||||||
self.session.add(new_member)
|
self.session.add(new_member)
|
||||||
self.session.flush()
|
self.session.flush()
|
||||||
self.session.refresh(new_member)
|
self.session.refresh(new_member)
|
||||||
|
|
||||||
member.id = new_member.id
|
member.id = new_member.id
|
||||||
return member
|
return member
|
||||||
|
|
||||||
|
|
||||||
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
|
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
|
||||||
offset = (page - 1) * items_per_page
|
offset = (page - 1) * items_per_page
|
||||||
total_count_query = select(func.count(Member.id))
|
total_count_query = select(func.count(Member.id))
|
||||||
total_count = self.session.exec(total_count_query).one()
|
total_count = self.session.exec(total_count_query).one()
|
||||||
|
|
||||||
if total_count == 0:
|
if total_count == 0:
|
||||||
return 0, []
|
return 0, []
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
select(Member)
|
select(Member)
|
||||||
.order_by(Member.created_at.desc())
|
.order_by(Member.created_at.desc())
|
||||||
.offset(offset)
|
.offset(offset)
|
||||||
.limit(items_per_page)
|
.limit(items_per_page)
|
||||||
)
|
)
|
||||||
|
|
||||||
members = self.session.exec(query).all()
|
members = self.session.exec(query).all()
|
||||||
|
|
||||||
return total_count, [MemberVO(**row_to_dict(member)) for member in members]
|
return total_count, [MemberVO(**row_to_dict(member)) for member in members]
|
||||||
|
|
||||||
def find_by_id(self, id: str) -> MemberVO | None:
|
def find_by_id(self, id: str) -> MemberVO | None:
|
||||||
query = select(Member).where(Member.id == id)
|
query = select(Member).where(Member.id == id)
|
||||||
member = self.session.exec(query).first()
|
member = self.session.exec(query).first()
|
||||||
|
|
||||||
if not member:
|
if not member:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return MemberVO(**row_to_dict(member))
|
return MemberVO(**row_to_dict(member))
|
||||||
|
|
|
||||||
|
|
@ -1,78 +1,71 @@
|
||||||
from fastapi import APIRouter, status, Depends,HTTPException
|
from fastapi import APIRouter, status, Depends,HTTPException
|
||||||
from member.interface.dto import CreateUserBody, MemberResponse
|
from member.interface.dto import CreateUserBody, MemberResponse
|
||||||
from member.application.member_service import MemberService
|
from member.application.member_service import MemberService
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from utils.containers import Container
|
from utils.containers import Container
|
||||||
from dependency_injector.wiring import inject, Provide
|
from dependency_injector.wiring import inject, Provide
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from utils.auth import get_current_member, CurrentMember, get_admin_member
|
from utils.auth import get_current_member, CurrentMember, get_admin_member
|
||||||
|
from analysis.interface.dto import AnalysisSessionResponse
|
||||||
router = APIRouter(prefix="/members", tags=["members"])
|
from analysis.application.analysis_service import AnalysisService
|
||||||
|
|
||||||
@router.post("", status_code=status.HTTP_201_CREATED, response_model=MemberResponse)
|
router = APIRouter(prefix="/members", tags=["members"])
|
||||||
@inject
|
|
||||||
async def create_user(
|
@router.post("", status_code=status.HTTP_201_CREATED, response_model=MemberResponse)
|
||||||
member: CreateUserBody,
|
@inject
|
||||||
member_service: MemberService = Depends(Provide[Container.member_service])
|
async def create_user(
|
||||||
):
|
member: CreateUserBody,
|
||||||
created_member = member_service.create_member(
|
member_service: MemberService = Depends(Provide[Container.member_service])
|
||||||
member.name,
|
):
|
||||||
member.email,
|
created_member = member_service.create_member(
|
||||||
member.password,
|
member.name,
|
||||||
member.role
|
member.email,
|
||||||
)
|
member.password,
|
||||||
|
member.role
|
||||||
return created_member
|
)
|
||||||
|
|
||||||
@router.post("/login")
|
return created_member
|
||||||
@inject
|
|
||||||
def login(
|
@router.post("/login")
|
||||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
@inject
|
||||||
member_service: MemberService = Depends(Provide[Container.member_service])
|
def login(
|
||||||
):
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
access_token = member_service.login(
|
member_service: MemberService = Depends(Provide[Container.member_service])
|
||||||
email=form_data.username,
|
):
|
||||||
password=form_data.password
|
access_token = member_service.login(
|
||||||
)
|
email=form_data.username,
|
||||||
|
password=form_data.password
|
||||||
return {
|
)
|
||||||
"access_token" : access_token,
|
|
||||||
"token_type" : "Bearer"
|
return {
|
||||||
}
|
"access_token" : access_token,
|
||||||
|
"token_type" : "Bearer"
|
||||||
@router.get("/me", response_model=dict)
|
}
|
||||||
def get_current_user_info(
|
|
||||||
current_user: CurrentMember = Depends(get_current_member)
|
@router.get("/me", response_model=dict)
|
||||||
):
|
def get_current_user_info(
|
||||||
"""
|
current_user: CurrentMember = Depends(get_current_member)
|
||||||
현재 로그인한 사용자 정보를 조회합니다.
|
):
|
||||||
이 엔드포인트는 JWT 토큰이 필요하며, Swagger UI에서 Authorize 버튼을 활성화합니다.
|
"""
|
||||||
"""
|
현재 로그인한 사용자 정보를 조회합니다.
|
||||||
return {
|
이 엔드포인트는 JWT 토큰이 필요하며, Swagger UI에서 Authorize 버튼을 활성화합니다.
|
||||||
"user_id": current_user.id,
|
"""
|
||||||
"role": current_user.role,
|
return {
|
||||||
"message": "Successfully authenticated"
|
"user_id": current_user.id,
|
||||||
}
|
"role": current_user.role,
|
||||||
|
"message": "Successfully authenticated"
|
||||||
@router.get("/{member_id}", response_model=MemberResponse)
|
}
|
||||||
@inject
|
|
||||||
def get_member(
|
@router.get("/{member_id}", response_model=MemberResponse)
|
||||||
member_id: str,
|
@inject
|
||||||
current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
|
def get_member(
|
||||||
member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
|
member_id: str,
|
||||||
):
|
current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
|
||||||
|
member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
|
||||||
member = member_service.get_member(member_id)
|
):
|
||||||
if not member:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
|
member = member_service.get_member(member_id)
|
||||||
return member
|
if not member:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
|
||||||
# @router.get("/analysis-sessions", response_model=list[AnalysisSessionResponse])
|
return member
|
||||||
# @inject
|
|
||||||
# def get_member_analysis_sessions(
|
|
||||||
# current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
|
|
||||||
# member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
|
|
||||||
# ):
|
|
||||||
|
|
||||||
# result = member_service.get_analysis_sessions_by_member(current_member.id)
|
|
||||||
# return result
|
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,18 @@
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from pydantic import BaseModel, Field, EmailStr
|
from pydantic import BaseModel, Field, EmailStr
|
||||||
from utils.auth import Role
|
from utils.auth import Role
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
class CreateUserBody(BaseModel):
|
class CreateUserBody(BaseModel):
|
||||||
name : Annotated[str, Field(min_length=1, max_length=32)]
|
name : Annotated[str, Field(min_length=1, max_length=32)]
|
||||||
email : Annotated[EmailStr, Field(max_length=32)]
|
email : Annotated[EmailStr, Field(max_length=32)]
|
||||||
password : Annotated[str, Field(max_length=32)]
|
password : Annotated[str, Field(max_length=32)]
|
||||||
role : Annotated[Role, Field(default=Role.USER)]
|
role : Annotated[Role, Field(default=Role.USER)]
|
||||||
|
|
||||||
class MemberResponse(BaseModel):
|
class MemberResponse(BaseModel):
|
||||||
id : str
|
id : str
|
||||||
name : str | None = None
|
name : str | None = None
|
||||||
email : str
|
email : str
|
||||||
created_at : datetime
|
created_at : datetime
|
||||||
updated_at : datetime
|
updated_at : datetime
|
||||||
role : Role
|
role : Role
|
||||||
|
|
|
||||||
|
|
@ -1,69 +1,69 @@
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from fastapi import HTTPException, status, Depends
|
from fastapi import HTTPException, status, Depends
|
||||||
from jose import jwt, JWTError
|
from jose import jwt, JWTError
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from config import get_settings
|
from config import get_settings
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
SECRET_KEY = settings.SECRET_KEY
|
SECRET_KEY = settings.SECRET_KEY
|
||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
|
|
||||||
class Role(StrEnum):
|
class Role(StrEnum):
|
||||||
ADMIN = "ADMIN"
|
ADMIN = "ADMIN"
|
||||||
USER = "USER"
|
USER = "USER"
|
||||||
|
|
||||||
class CurrentMember(BaseModel):
|
class CurrentMember(BaseModel):
|
||||||
id : str
|
id : str
|
||||||
role : Role
|
role : Role
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"{self.id}({self.role})"
|
return f"{self.id}({self.role})"
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/members/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/members/login")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(
|
def create_access_token(
|
||||||
payload: dict,
|
payload: dict,
|
||||||
role: Role,
|
role: Role,
|
||||||
expires_delta: timedelta = timedelta(hours=6)
|
expires_delta: timedelta = timedelta(hours=6)
|
||||||
):
|
):
|
||||||
expire = datetime.utcnow() + expires_delta
|
expire = datetime.utcnow() + expires_delta
|
||||||
payload.update({"exp": expire, "role": role})
|
payload.update({"exp": expire, "role": role})
|
||||||
encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
def decode_access_token(token: str):
|
def decode_access_token(token: str):
|
||||||
try:
|
try:
|
||||||
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||||
|
|
||||||
|
|
||||||
# ✅ 수정된 부분: Annotated 올바른 사용법
|
# ✅ 수정된 부분: Annotated 올바른 사용법
|
||||||
def get_current_member(token: Annotated[str, Depends(oauth2_scheme)]):
|
def get_current_member(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||||
payload = decode_access_token(token)
|
payload = decode_access_token(token)
|
||||||
member_id = payload.get("member_id")
|
member_id = payload.get("member_id")
|
||||||
role = payload.get("role")
|
role = payload.get("role")
|
||||||
if not member_id or not role:
|
if not member_id or not role:
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
||||||
|
|
||||||
return CurrentMember(id=member_id, role=Role(role))
|
return CurrentMember(id=member_id, role=Role(role))
|
||||||
|
|
||||||
def get_admin_member(token: Annotated[str, Depends(oauth2_scheme)]):
|
def get_admin_member(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||||
payload = decode_access_token(token)
|
payload = decode_access_token(token)
|
||||||
member_id = payload.get("member_id")
|
member_id = payload.get("member_id")
|
||||||
role = payload.get("role")
|
role = payload.get("role")
|
||||||
|
|
||||||
if not role or role != Role.ADMIN:
|
if not role or role != Role.ADMIN:
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
||||||
|
|
||||||
return CurrentMember(id=member_id, role=Role(role))
|
return CurrentMember(id=member_id, role=Role(role))
|
||||||
|
|
@ -1,43 +1,49 @@
|
||||||
from dependency_injector import containers, providers
|
from dependency_injector import containers, providers
|
||||||
from utils.database import get_session
|
from utils.database import get_session
|
||||||
from utils.crypto import Crypto
|
from utils.crypto import Crypto
|
||||||
from member.infra.repository.member_repo import MemberRepository
|
from member.infra.repository.member_repo import MemberRepository
|
||||||
from member.application.member_service import MemberService
|
from member.application.member_service import MemberService
|
||||||
from analysis.application.analysis_service import AnalysisService
|
from analysis.application.analysis_service import AnalysisService
|
||||||
from analysis.infra.repository.analysis_repo import AnalysisRepository
|
from analysis.infra.repository.analysis_repo import AnalysisRepository
|
||||||
from ulid import ULID
|
from analysis.application.websocket_manager import WebSocketManager
|
||||||
|
from ulid import ULID
|
||||||
class Container(containers.DeclarativeContainer):
|
|
||||||
wiring_config = containers.WiringConfiguration(
|
class Container(containers.DeclarativeContainer):
|
||||||
packages=["member", "analysis"]
|
wiring_config = containers.WiringConfiguration(
|
||||||
)
|
packages=["member", "analysis"]
|
||||||
|
)
|
||||||
db_session = providers.Resource(get_session)
|
|
||||||
crypto = providers.Factory(Crypto)
|
session = providers.Resource(get_session)
|
||||||
ulid = providers.Factory(ULID)
|
crypto = providers.Factory(Crypto)
|
||||||
|
ulid = providers.Factory(ULID)
|
||||||
member_repo = providers.Factory(
|
|
||||||
MemberRepository,
|
member_repo = providers.Factory(
|
||||||
session=db_session
|
MemberRepository,
|
||||||
)
|
session=session
|
||||||
|
)
|
||||||
member_service = providers.Factory(
|
|
||||||
MemberService,
|
member_service = providers.Factory(
|
||||||
member_repo=member_repo,
|
MemberService,
|
||||||
crypto=crypto,
|
member_repo=member_repo,
|
||||||
db_session=db_session,
|
crypto=crypto,
|
||||||
ulid=ulid
|
session=session,
|
||||||
)
|
ulid=ulid
|
||||||
|
)
|
||||||
analysis_repo = providers.Factory(
|
|
||||||
AnalysisRepository,
|
analysis_repo = providers.Factory(
|
||||||
session=db_session
|
AnalysisRepository,
|
||||||
)
|
session=session
|
||||||
|
)
|
||||||
analysis_service = providers.Factory(
|
|
||||||
AnalysisService,
|
websocket_manager = providers.Singleton(
|
||||||
analysis_repo=analysis_repo,
|
WebSocketManager
|
||||||
db_session=db_session,
|
)
|
||||||
ulid=ulid
|
|
||||||
)
|
analysis_service = providers.Factory(
|
||||||
|
AnalysisService,
|
||||||
|
analysis_repo=analysis_repo,
|
||||||
|
session=session,
|
||||||
|
ulid=ulid,
|
||||||
|
websocket_manager=websocket_manager
|
||||||
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
class Crypto:
|
class Crypto:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
def encrypt(self, secret):
|
def encrypt(self, secret):
|
||||||
return self.pwd_context.hash(secret)
|
return self.pwd_context.hash(secret)
|
||||||
|
|
||||||
def verify(self, secret, hash):
|
def verify(self, secret, hash):
|
||||||
return self.pwd_context.verify(secret, hash)
|
return self.pwd_context.verify(secret, hash)
|
||||||
|
|
||||||
|
|
@ -1,32 +1,32 @@
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from sqlmodel import SQLModel, create_engine, Session
|
from sqlmodel import SQLModel, create_engine, Session
|
||||||
from config.config import get_settings
|
from config.config import get_settings
|
||||||
from member.infra.db_models.member import Member
|
from member.infra.db_models.member import Member
|
||||||
from analysis.infra.db_models.analysis import Analysis
|
from analysis.infra.db_models.analysis import Analysis
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
# MySQL 데이터베이스 URL 구성
|
# MySQL 데이터베이스 URL 구성
|
||||||
DATABASE_URL = f"mysql+pymysql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset=utf8mb4"
|
DATABASE_URL = f"mysql+pymysql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset=utf8mb4"
|
||||||
|
|
||||||
# MySQL 엔진 생성
|
# MySQL 엔진 생성
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
DATABASE_URL,
|
DATABASE_URL,
|
||||||
echo=True
|
echo=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_session():
|
def get_session():
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
def create_db_and_tables():
|
def create_db_and_tables():
|
||||||
# 테이블 생성
|
# 테이블 생성
|
||||||
# SQLModel.metadata.drop_all(engine)
|
# SQLModel.metadata.drop_all(engine)
|
||||||
SQLModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
create_db_and_tables()
|
create_db_and_tables()
|
||||||
print(DATABASE_URL)
|
print(DATABASE_URL)
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
def row_to_dict(row)->dict:
|
def row_to_dict(row)->dict:
|
||||||
return {key : getattr(row, key) for key in inspect(row).attrs.keys()}
|
return {key : getattr(row, key) for key in inspect(row).attrs.keys()}
|
||||||
|
|
@ -1,59 +1,59 @@
|
||||||
version: '3.8'
|
version: '3.8'
|
||||||
|
|
||||||
services:
|
services:
|
||||||
mysql:
|
mysql:
|
||||||
image: mysql:8.0
|
image: mysql:8.0
|
||||||
container_name: tradingagents_mysql
|
container_name: tradingagents_mysql
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
environment:
|
environment:
|
||||||
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-password}
|
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-password}
|
||||||
MYSQL_DATABASE: ${DB_NAME:-tradingagents_db}
|
MYSQL_DATABASE: ${DB_NAME:-tradingagents_db}
|
||||||
MYSQL_USER: ${DB_USER:-tradinguser}
|
MYSQL_USER: ${DB_USER:-tradinguser}
|
||||||
MYSQL_PASSWORD: ${DB_PASSWORD:-password}
|
MYSQL_PASSWORD: ${DB_PASSWORD:-password}
|
||||||
ports:
|
ports:
|
||||||
- "3306:3306"
|
- "3306:3306"
|
||||||
volumes:
|
volumes:
|
||||||
- /home/hskim/mysql_data:/var/lib/mysql
|
- /home/hskim/mysql_data:/var/lib/mysql
|
||||||
- /home/hskim/docker/mysql/init:/docker-entrypoint-initdb.d
|
- /home/hskim/docker/mysql/init:/docker-entrypoint-initdb.d
|
||||||
command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||||
networks:
|
networks:
|
||||||
- tradingagents_network
|
- tradingagents_network
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
image: redis:7-alpine
|
image: redis:7-alpine
|
||||||
container_name: tradingagents_redis
|
container_name: tradingagents_redis
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
ports:
|
ports:
|
||||||
- "6379:6379"
|
- "6379:6379"
|
||||||
volumes:
|
volumes:
|
||||||
- redis_data:/data
|
- redis_data:/data
|
||||||
command: redis-server --appendonly yes
|
command: redis-server --appendonly yes
|
||||||
networks:
|
networks:
|
||||||
- tradingagents_network
|
- tradingagents_network
|
||||||
|
|
||||||
# 개발용 phpMyAdmin (선택사항)
|
# 개발용 phpMyAdmin (선택사항)
|
||||||
# phpmyadmin:
|
# phpmyadmin:
|
||||||
# image: phpmyadmin/phpmyadmin
|
# image: phpmyadmin/phpmyadmin
|
||||||
# container_name: tradingagents_phpmyadmin
|
# container_name: tradingagents_phpmyadmin
|
||||||
# restart: unless-stopped
|
# restart: unless-stopped
|
||||||
# environment:
|
# environment:
|
||||||
# PMA_HOST: mysql
|
# PMA_HOST: mysql
|
||||||
# PMA_PORT: 3306
|
# PMA_PORT: 3306
|
||||||
# PMA_USER: root
|
# PMA_USER: root
|
||||||
# PMA_PASSWORD: ${DB_PASSWORD:-password}
|
# PMA_PASSWORD: ${DB_PASSWORD:-password}
|
||||||
# ports:
|
# ports:
|
||||||
# - "8080:80"
|
# - "8080:80"
|
||||||
# depends_on:
|
# depends_on:
|
||||||
# - mysql
|
# - mysql
|
||||||
# networks:
|
# networks:
|
||||||
# - tradingagents_network
|
# - tradingagents_network
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
mysql_data:
|
mysql_data:
|
||||||
driver: local
|
driver: local
|
||||||
redis_data:
|
redis_data:
|
||||||
driver: local
|
driver: local
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
tradingagents_network:
|
tradingagents_network:
|
||||||
driver: bridge
|
driver: bridge
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,48 +1,48 @@
|
||||||
{
|
{
|
||||||
"name": "tradingagents-web-frontend",
|
"name": "tradingagents-web-frontend",
|
||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
"private": true,
|
"private": true,
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@ant-design/icons": "^5.2.6",
|
"@ant-design/icons": "^5.2.6",
|
||||||
"@testing-library/jest-dom": "^5.16.4",
|
"@testing-library/jest-dom": "^5.16.4",
|
||||||
"@testing-library/react": "^13.3.0",
|
"@testing-library/react": "^13.3.0",
|
||||||
"@testing-library/user-event": "^13.5.0",
|
"@testing-library/user-event": "^13.5.0",
|
||||||
"antd": "^5.10.0",
|
"antd": "^5.10.0",
|
||||||
"axios": "^1.5.0",
|
"axios": "^1.5.0",
|
||||||
"dayjs": "^1.11.9",
|
"dayjs": "^1.11.9",
|
||||||
"react": "^18.2.0",
|
"react": "^18.2.0",
|
||||||
"react-dom": "^18.2.0",
|
"react-dom": "^18.2.0",
|
||||||
"react-markdown": "^8.0.7",
|
"react-markdown": "^8.0.7",
|
||||||
"react-router-dom": "^6.4.0",
|
"react-router-dom": "^6.4.0",
|
||||||
"react-scripts": "5.0.1",
|
"react-scripts": "5.0.1",
|
||||||
"recharts": "^2.8.0",
|
"recharts": "^2.8.0",
|
||||||
"remark-gfm": "^4.0.1",
|
"remark-gfm": "^4.0.1",
|
||||||
"styled-components": "^6.0.8",
|
"styled-components": "^6.0.8",
|
||||||
"websocket": "^1.0.34"
|
"websocket": "^1.0.34"
|
||||||
},
|
},
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"start": "react-scripts start",
|
"start": "react-scripts start",
|
||||||
"build": "react-scripts build",
|
"build": "react-scripts build",
|
||||||
"test": "react-scripts test",
|
"test": "react-scripts test",
|
||||||
"eject": "react-scripts eject"
|
"eject": "react-scripts eject"
|
||||||
},
|
},
|
||||||
"eslintConfig": {
|
"eslintConfig": {
|
||||||
"extends": [
|
"extends": [
|
||||||
"react-app",
|
"react-app",
|
||||||
"react-app/jest"
|
"react-app/jest"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"browserslist": {
|
"browserslist": {
|
||||||
"production": [
|
"production": [
|
||||||
">0.2%",
|
">0.2%",
|
||||||
"not dead",
|
"not dead",
|
||||||
"not op_mini all"
|
"not op_mini all"
|
||||||
],
|
],
|
||||||
"development": [
|
"development": [
|
||||||
"last 1 chrome version",
|
"last 1 chrome version",
|
||||||
"last 1 firefox version",
|
"last 1 firefox version",
|
||||||
"last 1 safari version"
|
"last 1 safari version"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"proxy": "http://localhost:8000"
|
"proxy": "http://localhost:8000"
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html lang="ko">
|
<html lang="ko">
|
||||||
<head>
|
<head>
|
||||||
<meta charset="utf-8" />
|
<meta charset="utf-8" />
|
||||||
<link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
|
<link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||||
<meta name="theme-color" content="#000000" />
|
<meta name="theme-color" content="#000000" />
|
||||||
<meta
|
<meta
|
||||||
name="description"
|
name="description"
|
||||||
content="TradingAgents - Multi-Agents LLM Financial Trading Framework"
|
content="TradingAgents - Multi-Agents LLM Financial Trading Framework"
|
||||||
/>
|
/>
|
||||||
<link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
|
<link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
|
||||||
<link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
|
<link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
|
||||||
<title>TradingAgents - AI 거래 분석 플랫폼</title>
|
<title>TradingAgents - AI 거래 분석 플랫폼</title>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<noscript>JavaScript를 활성화해야 이 앱을 실행할 수 있습니다.</noscript>
|
<noscript>JavaScript를 활성화해야 이 앱을 실행할 수 있습니다.</noscript>
|
||||||
<div id="root"></div>
|
<div id="root"></div>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
42
main.py
42
main.py
|
|
@ -1,21 +1,21 @@
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
# Create a custom config
|
# Create a custom config
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["llm_provider"] = "google" # Use a different model
|
config["llm_provider"] = "google" # Use a different model
|
||||||
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
|
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
|
||||||
config["deep_think_llm"] = "gemini-2.5-pro" # Use a different model
|
config["deep_think_llm"] = "gemini-2.5-pro" # Use a different model
|
||||||
config["quick_think_llm"] = "gemini-2.5-flash-lite-preview-06-17" # Use a different model
|
config["quick_think_llm"] = "gemini-2.5-flash-lite-preview-06-17" # Use a different model
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||||
config["online_tools"] = True # Increase debate rounds
|
config["online_tools"] = True # Increase debate rounds
|
||||||
|
|
||||||
# Initialize with custom config
|
# Initialize with custom config
|
||||||
ta = TradingAgentsGraph(debug=True, config=config)
|
ta = TradingAgentsGraph(debug=True, config=config)
|
||||||
|
|
||||||
# forward propagate
|
# forward propagate
|
||||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||||
print(decision)
|
print(decision)
|
||||||
|
|
||||||
# Memorize mistakes and reflect
|
# Memorize mistakes and reflect
|
||||||
# ta.reflect_and_remember(1000) # parameter is the position returns
|
# ta.reflect_and_remember(1000) # parameter is the position returns
|
||||||
|
|
|
||||||
|
|
@ -1,34 +1,34 @@
|
||||||
[project]
|
[project]
|
||||||
name = "tradingagents"
|
name = "tradingagents"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"akshare>=1.16.98",
|
"akshare>=1.16.98",
|
||||||
"backtrader>=1.9.78.123",
|
"backtrader>=1.9.78.123",
|
||||||
"chainlit>=2.5.5",
|
"chainlit>=2.5.5",
|
||||||
"chromadb>=1.0.12",
|
"chromadb>=1.0.12",
|
||||||
"eodhd>=1.0.32",
|
"eodhd>=1.0.32",
|
||||||
"feedparser>=6.0.11",
|
"feedparser>=6.0.11",
|
||||||
"finnhub-python>=2.4.23",
|
"finnhub-python>=2.4.23",
|
||||||
"langchain-anthropic>=0.3.15",
|
"langchain-anthropic>=0.3.15",
|
||||||
"langchain-experimental>=0.3.4",
|
"langchain-experimental>=0.3.4",
|
||||||
"langchain-google-genai>=2.1.5",
|
"langchain-google-genai>=2.1.5",
|
||||||
"langchain-openai>=0.3.23",
|
"langchain-openai>=0.3.23",
|
||||||
"langgraph>=0.4.8",
|
"langgraph>=0.4.8",
|
||||||
"pandas>=2.3.0",
|
"pandas>=2.3.0",
|
||||||
"parsel>=1.10.0",
|
"parsel>=1.10.0",
|
||||||
"praw>=7.8.1",
|
"praw>=7.8.1",
|
||||||
"pytz>=2025.2",
|
"pytz>=2025.2",
|
||||||
"questionary>=2.1.0",
|
"questionary>=2.1.0",
|
||||||
"redis>=6.2.0",
|
"redis>=6.2.0",
|
||||||
"requests>=2.32.4",
|
"requests>=2.32.4",
|
||||||
"rich>=14.0.0",
|
"rich>=14.0.0",
|
||||||
"setuptools>=80.9.0",
|
"setuptools>=80.9.0",
|
||||||
"stockstats>=0.6.5",
|
"stockstats>=0.6.5",
|
||||||
"tqdm>=4.67.1",
|
"tqdm>=4.67.1",
|
||||||
"tushare>=1.4.21",
|
"tushare>=1.4.21",
|
||||||
"typing-extensions>=4.14.0",
|
"typing-extensions>=4.14.0",
|
||||||
"yfinance>=0.2.63",
|
"yfinance>=0.2.63",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,60 +1,60 @@
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_news_analyst(llm, toolkit):
|
def create_news_analyst(llm, toolkit):
|
||||||
def news_analyst_node(state):
|
def news_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
ticker = state["company_of_interest"]
|
ticker = state["company_of_interest"]
|
||||||
|
|
||||||
if toolkit.config["online_tools"]:
|
if toolkit.config["online_tools"]:
|
||||||
tools = [toolkit.get_global_news, toolkit.get_google_news]
|
tools = [toolkit.get_global_news, toolkit.get_google_news]
|
||||||
else:
|
else:
|
||||||
tools = [
|
tools = [
|
||||||
toolkit.get_finnhub_news,
|
toolkit.get_finnhub_news,
|
||||||
toolkit.get_reddit_news,
|
toolkit.get_reddit_news,
|
||||||
toolkit.get_google_news,
|
toolkit.get_google_news,
|
||||||
]
|
]
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"system",
|
"system",
|
||||||
"You are a helpful AI assistant, collaborating with other assistants."
|
"You are a helpful AI assistant, collaborating with other assistants."
|
||||||
" Use the provided tools to progress towards answering the question."
|
" Use the provided tools to progress towards answering the question."
|
||||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||||
" will help where you left off. Execute what you can to make progress."
|
" will help where you left off. Execute what you can to make progress."
|
||||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||||
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
|
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
|
||||||
),
|
),
|
||||||
MessagesPlaceholder(variable_name="messages"),
|
MessagesPlaceholder(variable_name="messages"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = prompt.partial(system_message=system_message)
|
prompt = prompt.partial(system_message=system_message)
|
||||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||||
prompt = prompt.partial(current_date=current_date)
|
prompt = prompt.partial(current_date=current_date)
|
||||||
prompt = prompt.partial(ticker=ticker)
|
prompt = prompt.partial(ticker=ticker)
|
||||||
|
|
||||||
chain = prompt | llm.bind_tools(tools)
|
chain = prompt | llm.bind_tools(tools)
|
||||||
result = chain.invoke(state["messages"])
|
result = chain.invoke(state["messages"])
|
||||||
|
|
||||||
report = ""
|
report = ""
|
||||||
|
|
||||||
if len(result.tool_calls) == 0:
|
if len(result.tool_calls) == 0:
|
||||||
report = result.content
|
report = result.content
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [result],
|
||||||
"news_report": report,
|
"news_report": report,
|
||||||
}
|
}
|
||||||
|
|
||||||
return news_analyst_node
|
return news_analyst_node
|
||||||
|
|
|
||||||
|
|
@ -1,60 +1,60 @@
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_social_media_analyst(llm, toolkit):
|
def create_social_media_analyst(llm, toolkit):
|
||||||
def social_media_analyst_node(state):
|
def social_media_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
ticker = state["company_of_interest"]
|
ticker = state["company_of_interest"]
|
||||||
company_name = state["company_of_interest"]
|
company_name = state["company_of_interest"]
|
||||||
|
|
||||||
if toolkit.config["online_tools"]:
|
if toolkit.config["online_tools"]:
|
||||||
tools = [toolkit.get_stock_news]
|
tools = [toolkit.get_stock_news]
|
||||||
else:
|
else:
|
||||||
tools = [
|
tools = [
|
||||||
toolkit.get_reddit_stock_info,
|
toolkit.get_reddit_stock_info,
|
||||||
]
|
]
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"system",
|
"system",
|
||||||
"You are a helpful AI assistant, collaborating with other assistants."
|
"You are a helpful AI assistant, collaborating with other assistants."
|
||||||
" Use the provided tools to progress towards answering the question."
|
" Use the provided tools to progress towards answering the question."
|
||||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||||
" will help where you left off. Execute what you can to make progress."
|
" will help where you left off. Execute what you can to make progress."
|
||||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||||
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
|
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
|
||||||
),
|
),
|
||||||
MessagesPlaceholder(variable_name="messages"),
|
MessagesPlaceholder(variable_name="messages"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = prompt.partial(system_message=system_message)
|
prompt = prompt.partial(system_message=system_message)
|
||||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||||
prompt = prompt.partial(current_date=current_date)
|
prompt = prompt.partial(current_date=current_date)
|
||||||
prompt = prompt.partial(ticker=ticker)
|
prompt = prompt.partial(ticker=ticker)
|
||||||
|
|
||||||
chain = prompt | llm.bind_tools(tools)
|
chain = prompt | llm.bind_tools(tools)
|
||||||
|
|
||||||
result = chain.invoke(state["messages"])
|
result = chain.invoke(state["messages"])
|
||||||
|
|
||||||
report = ""
|
report = ""
|
||||||
|
|
||||||
if len(result.tool_calls) == 0:
|
if len(result.tool_calls) == 0:
|
||||||
report = result.content
|
report = result.content
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [result],
|
||||||
"sentiment_report": report,
|
"sentiment_report": report,
|
||||||
}
|
}
|
||||||
|
|
||||||
return social_media_analyst_node
|
return social_media_analyst_node
|
||||||
|
|
|
||||||
|
|
@ -1,57 +1,57 @@
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_research_manager(llm, memory):
|
def create_research_manager(llm, memory):
|
||||||
def research_manager_node(state) -> dict:
|
def research_manager_node(state) -> dict:
|
||||||
history = state["investment_debate_state"].get("history", "")
|
history = state["investment_debate_state"].get("history", "")
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
investment_debate_state = state["investment_debate_state"]
|
investment_debate_state = state["investment_debate_state"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||||
|
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for i, rec in enumerate(past_memories, 1):
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
past_memory_str += rec["recommendation"] + "\n\n"
|
||||||
|
|
||||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||||
|
|
||||||
As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
||||||
|
|
||||||
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
|
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
|
||||||
|
|
||||||
Additionally, develop a detailed investment plan for the trader. This should include:
|
Additionally, develop a detailed investment plan for the trader. This should include:
|
||||||
|
|
||||||
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
||||||
Rationale: An explanation of why these arguments lead to your conclusion.
|
Rationale: An explanation of why these arguments lead to your conclusion.
|
||||||
Strategic Actions: Concrete steps for implementing the recommendation.
|
Strategic Actions: Concrete steps for implementing the recommendation.
|
||||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
||||||
|
|
||||||
Here are your past reflections on mistakes:
|
Here are your past reflections on mistakes:
|
||||||
\"{past_memory_str}\"
|
\"{past_memory_str}\"
|
||||||
|
|
||||||
Here is the debate:
|
Here is the debate:
|
||||||
Debate History:
|
Debate History:
|
||||||
{history}"""
|
{history}"""
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
new_investment_debate_state = {
|
new_investment_debate_state = {
|
||||||
"judge_decision": response.content,
|
"judge_decision": response.content,
|
||||||
"history": investment_debate_state.get("history", ""),
|
"history": investment_debate_state.get("history", ""),
|
||||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||||
"current_response": response.content,
|
"current_response": response.content,
|
||||||
"count": investment_debate_state["count"],
|
"count": investment_debate_state["count"],
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"investment_debate_state": new_investment_debate_state,
|
"investment_debate_state": new_investment_debate_state,
|
||||||
"investment_plan": response.content,
|
"investment_plan": response.content,
|
||||||
}
|
}
|
||||||
|
|
||||||
return research_manager_node
|
return research_manager_node
|
||||||
|
|
|
||||||
|
|
@ -1,68 +1,68 @@
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_risk_manager(llm, memory):
|
def create_risk_manager(llm, memory):
|
||||||
def risk_manager_node(state) -> dict:
|
def risk_manager_node(state) -> dict:
|
||||||
|
|
||||||
company_name = state["company_of_interest"]
|
company_name = state["company_of_interest"]
|
||||||
|
|
||||||
history = state["risk_debate_state"]["history"]
|
history = state["risk_debate_state"]["history"]
|
||||||
risk_debate_state = state["risk_debate_state"]
|
risk_debate_state = state["risk_debate_state"]
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["news_report"]
|
fundamentals_report = state["news_report"]
|
||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
trader_plan = state["investment_plan"]
|
trader_plan = state["investment_plan"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||||
|
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for i, rec in enumerate(past_memories, 1):
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
past_memory_str += rec["recommendation"] + "\n\n"
|
||||||
|
|
||||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||||
|
|
||||||
As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
|
As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
|
||||||
|
|
||||||
Guidelines for Decision-Making:
|
Guidelines for Decision-Making:
|
||||||
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
|
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
|
||||||
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
|
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
|
||||||
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
|
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
|
||||||
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
|
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
|
||||||
|
|
||||||
Deliverables:
|
Deliverables:
|
||||||
- A clear and actionable recommendation: Buy, Sell, or Hold.
|
- A clear and actionable recommendation: Buy, Sell, or Hold.
|
||||||
- Detailed reasoning anchored in the debate and past reflections.
|
- Detailed reasoning anchored in the debate and past reflections.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**Analysts Debate History:**
|
**Analysts Debate History:**
|
||||||
{history}
|
{history}
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
|
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
new_risk_debate_state = {
|
new_risk_debate_state = {
|
||||||
"judge_decision": response.content,
|
"judge_decision": response.content,
|
||||||
"history": risk_debate_state["history"],
|
"history": risk_debate_state["history"],
|
||||||
"risky_history": risk_debate_state["risky_history"],
|
"risky_history": risk_debate_state["risky_history"],
|
||||||
"safe_history": risk_debate_state["safe_history"],
|
"safe_history": risk_debate_state["safe_history"],
|
||||||
"neutral_history": risk_debate_state["neutral_history"],
|
"neutral_history": risk_debate_state["neutral_history"],
|
||||||
"latest_speaker": "Judge",
|
"latest_speaker": "Judge",
|
||||||
"current_risky_response": risk_debate_state["current_risky_response"],
|
"current_risky_response": risk_debate_state["current_risky_response"],
|
||||||
"current_safe_response": risk_debate_state["current_safe_response"],
|
"current_safe_response": risk_debate_state["current_safe_response"],
|
||||||
"current_neutral_response": risk_debate_state["current_neutral_response"],
|
"current_neutral_response": risk_debate_state["current_neutral_response"],
|
||||||
"count": risk_debate_state["count"],
|
"count": risk_debate_state["count"],
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"risk_debate_state": new_risk_debate_state,
|
"risk_debate_state": new_risk_debate_state,
|
||||||
"final_trade_decision": response.content,
|
"final_trade_decision": response.content,
|
||||||
}
|
}
|
||||||
|
|
||||||
return risk_manager_node
|
return risk_manager_node
|
||||||
|
|
|
||||||
|
|
@ -1,63 +1,63 @@
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_bear_researcher(llm, memory):
|
def create_bear_researcher(llm, memory):
|
||||||
def bear_node(state) -> dict:
|
def bear_node(state) -> dict:
|
||||||
investment_debate_state = state["investment_debate_state"]
|
investment_debate_state = state["investment_debate_state"]
|
||||||
history = investment_debate_state.get("history", "")
|
history = investment_debate_state.get("history", "")
|
||||||
bear_history = investment_debate_state.get("bear_history", "")
|
bear_history = investment_debate_state.get("bear_history", "")
|
||||||
|
|
||||||
current_response = investment_debate_state.get("current_response", "")
|
current_response = investment_debate_state.get("current_response", "")
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||||
|
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for i, rec in enumerate(past_memories, 1):
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
past_memory_str += rec["recommendation"] + "\n\n"
|
||||||
|
|
||||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||||
|
|
||||||
You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
||||||
|
|
||||||
Key points to focus on:
|
Key points to focus on:
|
||||||
|
|
||||||
- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance.
|
- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance.
|
||||||
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors.
|
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors.
|
||||||
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
|
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
|
||||||
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions.
|
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions.
|
||||||
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts.
|
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts.
|
||||||
|
|
||||||
Resources available:
|
Resources available:
|
||||||
|
|
||||||
Market research report: {market_research_report}
|
Market research report: {market_research_report}
|
||||||
Social media sentiment report: {sentiment_report}
|
Social media sentiment report: {sentiment_report}
|
||||||
Latest world affairs news: {news_report}
|
Latest world affairs news: {news_report}
|
||||||
Company fundamentals report: {fundamentals_report}
|
Company fundamentals report: {fundamentals_report}
|
||||||
Conversation history of the debate: {history}
|
Conversation history of the debate: {history}
|
||||||
Last bull argument: {current_response}
|
Last bull argument: {current_response}
|
||||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
|
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
argument = f"Bear Analyst: {response.content}"
|
argument = f"Bear Analyst: {response.content}"
|
||||||
|
|
||||||
new_investment_debate_state = {
|
new_investment_debate_state = {
|
||||||
"history": history + "\n" + argument,
|
"history": history + "\n" + argument,
|
||||||
"bear_history": bear_history + "\n" + argument,
|
"bear_history": bear_history + "\n" + argument,
|
||||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||||
"current_response": argument,
|
"current_response": argument,
|
||||||
"count": investment_debate_state["count"] + 1,
|
"count": investment_debate_state["count"] + 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {"investment_debate_state": new_investment_debate_state}
|
return {"investment_debate_state": new_investment_debate_state}
|
||||||
|
|
||||||
return bear_node
|
return bear_node
|
||||||
|
|
|
||||||
|
|
@ -1,61 +1,61 @@
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_bull_researcher(llm, memory):
|
def create_bull_researcher(llm, memory):
|
||||||
def bull_node(state) -> dict:
|
def bull_node(state) -> dict:
|
||||||
investment_debate_state = state["investment_debate_state"]
|
investment_debate_state = state["investment_debate_state"]
|
||||||
history = investment_debate_state.get("history", "")
|
history = investment_debate_state.get("history", "")
|
||||||
bull_history = investment_debate_state.get("bull_history", "")
|
bull_history = investment_debate_state.get("bull_history", "")
|
||||||
|
|
||||||
current_response = investment_debate_state.get("current_response", "")
|
current_response = investment_debate_state.get("current_response", "")
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||||
|
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for i, rec in enumerate(past_memories, 1):
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
past_memory_str += rec["recommendation"] + "\n\n"
|
||||||
|
|
||||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||||
|
|
||||||
You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
||||||
|
|
||||||
Key points to focus on:
|
Key points to focus on:
|
||||||
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
|
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
|
||||||
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
|
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
|
||||||
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
|
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
|
||||||
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit.
|
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit.
|
||||||
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
|
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
|
||||||
|
|
||||||
Resources available:
|
Resources available:
|
||||||
Market research report: {market_research_report}
|
Market research report: {market_research_report}
|
||||||
Social media sentiment report: {sentiment_report}
|
Social media sentiment report: {sentiment_report}
|
||||||
Latest world affairs news: {news_report}
|
Latest world affairs news: {news_report}
|
||||||
Company fundamentals report: {fundamentals_report}
|
Company fundamentals report: {fundamentals_report}
|
||||||
Conversation history of the debate: {history}
|
Conversation history of the debate: {history}
|
||||||
Last bear argument: {current_response}
|
Last bear argument: {current_response}
|
||||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
|
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
argument = f"Bull Analyst: {response.content}"
|
argument = f"Bull Analyst: {response.content}"
|
||||||
|
|
||||||
new_investment_debate_state = {
|
new_investment_debate_state = {
|
||||||
"history": history + "\n" + argument,
|
"history": history + "\n" + argument,
|
||||||
"bull_history": bull_history + "\n" + argument,
|
"bull_history": bull_history + "\n" + argument,
|
||||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||||
"current_response": argument,
|
"current_response": argument,
|
||||||
"count": investment_debate_state["count"] + 1,
|
"count": investment_debate_state["count"] + 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {"investment_debate_state": new_investment_debate_state}
|
return {"investment_debate_state": new_investment_debate_state}
|
||||||
|
|
||||||
return bull_node
|
return bull_node
|
||||||
|
|
|
||||||
|
|
@ -1,57 +1,57 @@
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_risky_debator(llm):
|
def create_risky_debator(llm):
|
||||||
def risky_node(state) -> dict:
|
def risky_node(state) -> dict:
|
||||||
risk_debate_state = state["risk_debate_state"]
|
risk_debate_state = state["risk_debate_state"]
|
||||||
history = risk_debate_state.get("history", "")
|
history = risk_debate_state.get("history", "")
|
||||||
risky_history = risk_debate_state.get("risky_history", "")
|
risky_history = risk_debate_state.get("risky_history", "")
|
||||||
|
|
||||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
||||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||||
|
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
trader_decision = state["trader_investment_plan"]
|
trader_decision = state["trader_investment_plan"]
|
||||||
|
|
||||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||||
|
|
||||||
As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
|
As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
|
||||||
|
|
||||||
{trader_decision}
|
{trader_decision}
|
||||||
|
|
||||||
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
|
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
|
||||||
|
|
||||||
Market Research Report: {market_research_report}
|
Market Research Report: {market_research_report}
|
||||||
Social Media Sentiment Report: {sentiment_report}
|
Social Media Sentiment Report: {sentiment_report}
|
||||||
Latest World Affairs Report: {news_report}
|
Latest World Affairs Report: {news_report}
|
||||||
Company Fundamentals Report: {fundamentals_report}
|
Company Fundamentals Report: {fundamentals_report}
|
||||||
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||||
|
|
||||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
|
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
argument = f"Risky Analyst: {response.content}"
|
argument = f"Risky Analyst: {response.content}"
|
||||||
|
|
||||||
new_risk_debate_state = {
|
new_risk_debate_state = {
|
||||||
"history": history + "\n" + argument,
|
"history": history + "\n" + argument,
|
||||||
"risky_history": risky_history + "\n" + argument,
|
"risky_history": risky_history + "\n" + argument,
|
||||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
"safe_history": risk_debate_state.get("safe_history", ""),
|
||||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||||
"latest_speaker": "Risky",
|
"latest_speaker": "Risky",
|
||||||
"current_risky_response": argument,
|
"current_risky_response": argument,
|
||||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||||
"current_neutral_response": risk_debate_state.get(
|
"current_neutral_response": risk_debate_state.get(
|
||||||
"current_neutral_response", ""
|
"current_neutral_response", ""
|
||||||
),
|
),
|
||||||
"count": risk_debate_state["count"] + 1,
|
"count": risk_debate_state["count"] + 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {"risk_debate_state": new_risk_debate_state}
|
return {"risk_debate_state": new_risk_debate_state}
|
||||||
|
|
||||||
return risky_node
|
return risky_node
|
||||||
|
|
|
||||||
|
|
@ -1,60 +1,60 @@
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_safe_debator(llm):
|
def create_safe_debator(llm):
|
||||||
def safe_node(state) -> dict:
|
def safe_node(state) -> dict:
|
||||||
risk_debate_state = state["risk_debate_state"]
|
risk_debate_state = state["risk_debate_state"]
|
||||||
history = risk_debate_state.get("history", "")
|
history = risk_debate_state.get("history", "")
|
||||||
safe_history = risk_debate_state.get("safe_history", "")
|
safe_history = risk_debate_state.get("safe_history", "")
|
||||||
|
|
||||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
||||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||||
|
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
trader_decision = state["trader_investment_plan"]
|
trader_decision = state["trader_investment_plan"]
|
||||||
|
|
||||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||||
|
|
||||||
As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
|
As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
|
||||||
|
|
||||||
{trader_decision}
|
{trader_decision}
|
||||||
|
|
||||||
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
|
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
|
||||||
|
|
||||||
Market Research Report: {market_research_report}
|
Market Research Report: {market_research_report}
|
||||||
Social Media Sentiment Report: {sentiment_report}
|
Social Media Sentiment Report: {sentiment_report}
|
||||||
Latest World Affairs Report: {news_report}
|
Latest World Affairs Report: {news_report}
|
||||||
Company Fundamentals Report: {fundamentals_report}
|
Company Fundamentals Report: {fundamentals_report}
|
||||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||||
|
|
||||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
|
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
argument = f"Safe Analyst: {response.content}"
|
argument = f"Safe Analyst: {response.content}"
|
||||||
|
|
||||||
new_risk_debate_state = {
|
new_risk_debate_state = {
|
||||||
"history": history + "\n" + argument,
|
"history": history + "\n" + argument,
|
||||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
"risky_history": risk_debate_state.get("risky_history", ""),
|
||||||
"safe_history": safe_history + "\n" + argument,
|
"safe_history": safe_history + "\n" + argument,
|
||||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||||
"latest_speaker": "Safe",
|
"latest_speaker": "Safe",
|
||||||
"current_risky_response": risk_debate_state.get(
|
"current_risky_response": risk_debate_state.get(
|
||||||
"current_risky_response", ""
|
"current_risky_response", ""
|
||||||
),
|
),
|
||||||
"current_safe_response": argument,
|
"current_safe_response": argument,
|
||||||
"current_neutral_response": risk_debate_state.get(
|
"current_neutral_response": risk_debate_state.get(
|
||||||
"current_neutral_response", ""
|
"current_neutral_response", ""
|
||||||
),
|
),
|
||||||
"count": risk_debate_state["count"] + 1,
|
"count": risk_debate_state["count"] + 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {"risk_debate_state": new_risk_debate_state}
|
return {"risk_debate_state": new_risk_debate_state}
|
||||||
|
|
||||||
return safe_node
|
return safe_node
|
||||||
|
|
|
||||||
|
|
@ -1,57 +1,57 @@
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_neutral_debator(llm):
|
def create_neutral_debator(llm):
|
||||||
def neutral_node(state) -> dict:
|
def neutral_node(state) -> dict:
|
||||||
risk_debate_state = state["risk_debate_state"]
|
risk_debate_state = state["risk_debate_state"]
|
||||||
history = risk_debate_state.get("history", "")
|
history = risk_debate_state.get("history", "")
|
||||||
neutral_history = risk_debate_state.get("neutral_history", "")
|
neutral_history = risk_debate_state.get("neutral_history", "")
|
||||||
|
|
||||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
||||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
||||||
|
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
trader_decision = state["trader_investment_plan"]
|
trader_decision = state["trader_investment_plan"]
|
||||||
|
|
||||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||||
|
|
||||||
As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
|
As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
|
||||||
|
|
||||||
{trader_decision}
|
{trader_decision}
|
||||||
|
|
||||||
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
|
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
|
||||||
|
|
||||||
Market Research Report: {market_research_report}
|
Market Research Report: {market_research_report}
|
||||||
Social Media Sentiment Report: {sentiment_report}
|
Social Media Sentiment Report: {sentiment_report}
|
||||||
Latest World Affairs Report: {news_report}
|
Latest World Affairs Report: {news_report}
|
||||||
Company Fundamentals Report: {fundamentals_report}
|
Company Fundamentals Report: {fundamentals_report}
|
||||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||||
|
|
||||||
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
|
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
argument = f"Neutral Analyst: {response.content}"
|
argument = f"Neutral Analyst: {response.content}"
|
||||||
|
|
||||||
new_risk_debate_state = {
|
new_risk_debate_state = {
|
||||||
"history": history + "\n" + argument,
|
"history": history + "\n" + argument,
|
||||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
"risky_history": risk_debate_state.get("risky_history", ""),
|
||||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
"safe_history": risk_debate_state.get("safe_history", ""),
|
||||||
"neutral_history": neutral_history + "\n" + argument,
|
"neutral_history": neutral_history + "\n" + argument,
|
||||||
"latest_speaker": "Neutral",
|
"latest_speaker": "Neutral",
|
||||||
"current_risky_response": risk_debate_state.get(
|
"current_risky_response": risk_debate_state.get(
|
||||||
"current_risky_response", ""
|
"current_risky_response", ""
|
||||||
),
|
),
|
||||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||||
"current_neutral_response": argument,
|
"current_neutral_response": argument,
|
||||||
"count": risk_debate_state["count"] + 1,
|
"count": risk_debate_state["count"] + 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {"risk_debate_state": new_risk_debate_state}
|
return {"risk_debate_state": new_risk_debate_state}
|
||||||
|
|
||||||
return neutral_node
|
return neutral_node
|
||||||
|
|
|
||||||
|
|
@ -1,45 +1,45 @@
|
||||||
import functools
|
import functools
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def create_trader(llm, memory):
|
def create_trader(llm, memory):
|
||||||
def trader_node(state, name):
|
def trader_node(state, name):
|
||||||
company_name = state["company_of_interest"]
|
company_name = state["company_of_interest"]
|
||||||
investment_plan = state["investment_plan"]
|
investment_plan = state["investment_plan"]
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||||
|
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for i, rec in enumerate(past_memories, 1):
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
past_memory_str += rec["recommendation"] + "\n\n"
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
"content": f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||||
|
|
||||||
You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""",
|
You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""",
|
||||||
},
|
},
|
||||||
context,
|
context,
|
||||||
]
|
]
|
||||||
|
|
||||||
result = llm.invoke(messages)
|
result = llm.invoke(messages)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [result],
|
||||||
"trader_investment_plan": result.content,
|
"trader_investment_plan": result.content,
|
||||||
"sender": name,
|
"sender": name,
|
||||||
}
|
}
|
||||||
|
|
||||||
return functools.partial(trader_node, name="Trader")
|
return functools.partial(trader_node, name="Trader")
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
from .embedding_providers import (
|
from .embedding_providers import (
|
||||||
EmbeddingProvider,
|
EmbeddingProvider,
|
||||||
OpenAIEmbeddingProvider,
|
OpenAIEmbeddingProvider,
|
||||||
GeminiEmbeddingProvider,
|
GeminiEmbeddingProvider,
|
||||||
OllamaEmbeddingProvider
|
OllamaEmbeddingProvider
|
||||||
)
|
)
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
class EmbeddingProviderFactory:
|
class EmbeddingProviderFactory:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_provider(config : dict[str, Any])->EmbeddingProvider:
|
def create_provider(config : dict[str, Any])->EmbeddingProvider:
|
||||||
backend_url = config["backend_url"]
|
backend_url = config["backend_url"]
|
||||||
|
|
||||||
if "generativelanguage.googleapis.com" in backend_url:
|
if "generativelanguage.googleapis.com" in backend_url:
|
||||||
return GeminiEmbeddingProvider(backend_url)
|
return GeminiEmbeddingProvider(backend_url)
|
||||||
elif "localhost:11434" in backend_url:
|
elif "localhost:11434" in backend_url:
|
||||||
return OllamaEmbeddingProvider(backend_url)
|
return OllamaEmbeddingProvider(backend_url)
|
||||||
else:
|
else:
|
||||||
return OpenAIEmbeddingProvider(backend_url)
|
return OpenAIEmbeddingProvider(backend_url)
|
||||||
|
|
||||||
|
|
@ -1,66 +1,66 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from google import genai
|
from google import genai
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingProvider(ABC):
|
class EmbeddingProvider(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_embedding(self, text: str)->list[float]:
|
def get_embedding(self, text: str)->list[float]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def model_name(self)->str:
|
def model_name(self)->str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||||
def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
|
def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
|
||||||
self.client = OpenAI(base_url=backend_url)
|
self.client = OpenAI(base_url=backend_url)
|
||||||
self._embedding_model = embedding_model
|
self._embedding_model = embedding_model
|
||||||
|
|
||||||
|
|
||||||
def get_embedding(self, text: str)->list[float]:
|
def get_embedding(self, text: str)->list[float]:
|
||||||
response = self.client.embeddings.create(
|
response = self.client.embeddings.create(
|
||||||
model=self._embedding_model,
|
model=self._embedding_model,
|
||||||
input=text
|
input=text
|
||||||
)
|
)
|
||||||
return response.data[0].embedding
|
return response.data[0].embedding
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_name(self)->str:
|
def model_name(self)->str:
|
||||||
return self._embedding_model
|
return self._embedding_model
|
||||||
|
|
||||||
|
|
||||||
class GeminiEmbeddingProvider(EmbeddingProvider):
|
class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||||
def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
|
def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
|
||||||
self.client = genai.Client()
|
self.client = genai.Client()
|
||||||
self._embedding_model = embedding_model
|
self._embedding_model = embedding_model
|
||||||
|
|
||||||
def get_embedding(self, text: str)->list[float]:
|
def get_embedding(self, text: str)->list[float]:
|
||||||
response = self.client.models.embed_content(
|
response = self.client.models.embed_content(
|
||||||
model=self._embedding_model,
|
model=self._embedding_model,
|
||||||
contents=text
|
contents=text
|
||||||
)
|
)
|
||||||
return response.embeddings[0].values
|
return response.embeddings[0].values
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_name(self)->str:
|
def model_name(self)->str:
|
||||||
return self._embedding_model
|
return self._embedding_model
|
||||||
|
|
||||||
class OllamaEmbeddingProvider(EmbeddingProvider):
|
class OllamaEmbeddingProvider(EmbeddingProvider):
|
||||||
def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
|
def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
|
||||||
self.client = OpenAI(base_url=backend_url)
|
self.client = OpenAI(base_url=backend_url)
|
||||||
self._embedding_model = embedding_model
|
self._embedding_model = embedding_model
|
||||||
|
|
||||||
def get_embedding(self, text: str)->list[float]:
|
def get_embedding(self, text: str)->list[float]:
|
||||||
response = self.client.embeddings.create(
|
response = self.client.embeddings.create(
|
||||||
model=self._embedding_model,
|
model=self._embedding_model,
|
||||||
input=text
|
input=text
|
||||||
)
|
)
|
||||||
return response.data[0].embedding
|
return response.data[0].embedding
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_name(self)->str:
|
def model_name(self)->str:
|
||||||
return self._embedding_model
|
return self._embedding_model
|
||||||
|
|
||||||
|
|
@ -1,112 +1,112 @@
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import os
|
import os
|
||||||
from .embedding_provider_factory import EmbeddingProviderFactory
|
from .embedding_provider_factory import EmbeddingProviderFactory
|
||||||
from google import genai
|
from google import genai
|
||||||
|
|
||||||
class FinancialSituationMemory:
|
class FinancialSituationMemory:
|
||||||
def __init__(self, name, config):
|
def __init__(self, name, config):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.backend_url = config["backend_url"]
|
self.backend_url = config["backend_url"]
|
||||||
|
|
||||||
self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
|
self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
|
||||||
|
|
||||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||||
|
|
||||||
def get_embedding(self, text):
|
def get_embedding(self, text):
|
||||||
"""Get embedding for a text using the appropriate API"""
|
"""Get embedding for a text using the appropriate API"""
|
||||||
|
|
||||||
return self.embedding_provider.get_embedding(text)
|
return self.embedding_provider.get_embedding(text)
|
||||||
|
|
||||||
def add_situations(self, situations_and_advice):
|
def add_situations(self, situations_and_advice):
|
||||||
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
||||||
|
|
||||||
situations = []
|
situations = []
|
||||||
advice = []
|
advice = []
|
||||||
ids = []
|
ids = []
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
offset = self.situation_collection.count()
|
offset = self.situation_collection.count()
|
||||||
|
|
||||||
for i, (situation, recommendation) in enumerate(situations_and_advice):
|
for i, (situation, recommendation) in enumerate(situations_and_advice):
|
||||||
situations.append(situation)
|
situations.append(situation)
|
||||||
advice.append(recommendation)
|
advice.append(recommendation)
|
||||||
ids.append(str(offset + i))
|
ids.append(str(offset + i))
|
||||||
embeddings.append(self.get_embedding(situation))
|
embeddings.append(self.get_embedding(situation))
|
||||||
|
|
||||||
self.situation_collection.add(
|
self.situation_collection.add(
|
||||||
documents=situations,
|
documents=situations,
|
||||||
metadatas=[{"recommendation": rec} for rec in advice],
|
metadatas=[{"recommendation": rec} for rec in advice],
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
ids=ids,
|
ids=ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_memories(self, current_situation, n_matches=1):
|
def get_memories(self, current_situation, n_matches=1):
|
||||||
"""Find matching recommendations using embeddings"""
|
"""Find matching recommendations using embeddings"""
|
||||||
query_embedding = self.get_embedding(current_situation)
|
query_embedding = self.get_embedding(current_situation)
|
||||||
|
|
||||||
results = self.situation_collection.query(
|
results = self.situation_collection.query(
|
||||||
query_embeddings=[query_embedding],
|
query_embeddings=[query_embedding],
|
||||||
n_results=n_matches,
|
n_results=n_matches,
|
||||||
include=["metadatas", "documents", "distances"],
|
include=["metadatas", "documents", "distances"],
|
||||||
)
|
)
|
||||||
|
|
||||||
matched_results = []
|
matched_results = []
|
||||||
for i in range(len(results["documents"][0])):
|
for i in range(len(results["documents"][0])):
|
||||||
matched_results.append(
|
matched_results.append(
|
||||||
{
|
{
|
||||||
"matched_situation": results["documents"][0][i],
|
"matched_situation": results["documents"][0][i],
|
||||||
"recommendation": results["metadatas"][0][i]["recommendation"],
|
"recommendation": results["metadatas"][0][i]["recommendation"],
|
||||||
"similarity_score": 1 - results["distances"][0][i],
|
"similarity_score": 1 - results["distances"][0][i],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return matched_results
|
return matched_results
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Example usage
|
# Example usage
|
||||||
matcher = FinancialSituationMemory()
|
matcher = FinancialSituationMemory()
|
||||||
|
|
||||||
# Example data
|
# Example data
|
||||||
example_data = [
|
example_data = [
|
||||||
(
|
(
|
||||||
"High inflation rate with rising interest rates and declining consumer spending",
|
"High inflation rate with rising interest rates and declining consumer spending",
|
||||||
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"Tech sector showing high volatility with increasing institutional selling pressure",
|
"Tech sector showing high volatility with increasing institutional selling pressure",
|
||||||
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"Strong dollar affecting emerging markets with increasing forex volatility",
|
"Strong dollar affecting emerging markets with increasing forex volatility",
|
||||||
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"Market showing signs of sector rotation with rising yields",
|
"Market showing signs of sector rotation with rising yields",
|
||||||
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add the example situations and recommendations
|
# Add the example situations and recommendations
|
||||||
matcher.add_situations(example_data)
|
matcher.add_situations(example_data)
|
||||||
|
|
||||||
# Example query
|
# Example query
|
||||||
current_situation = """
|
current_situation = """
|
||||||
Market showing increased volatility in tech sector, with institutional investors
|
Market showing increased volatility in tech sector, with institutional investors
|
||||||
reducing positions and rising interest rates affecting growth stock valuations
|
reducing positions and rising interest rates affecting growth stock valuations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
||||||
|
|
||||||
for i, rec in enumerate(recommendations, 1):
|
for i, rec in enumerate(recommendations, 1):
|
||||||
print(f"\nMatch {i}:")
|
print(f"\nMatch {i}:")
|
||||||
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
||||||
print(f"Matched Situation: {rec['matched_situation']}")
|
print(f"Matched Situation: {rec['matched_situation']}")
|
||||||
print(f"Recommendation: {rec['recommendation']}")
|
print(f"Recommendation: {rec['recommendation']}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during recommendation: {str(e)}")
|
print(f"Error during recommendation: {str(e)}")
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,76 +1,76 @@
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SearchProvider(ABC):
|
class SearchProvider(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(self, query: str, ticker: str, curr_date: str) -> str:
|
def search(self, query: str, ticker: str, curr_date: str) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class GoogleSearchProvider(SearchProvider):
|
class GoogleSearchProvider(SearchProvider):
|
||||||
def __init__(self, model: str):
|
def __init__(self, model: str):
|
||||||
self.client = genai.Client()
|
self.client = genai.Client()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def search(self, query: str) -> str:
|
def search(self, query: str) -> str:
|
||||||
google_search_tool = Tool(
|
google_search_tool = Tool(
|
||||||
google_search=GoogleSearch()
|
google_search=GoogleSearch()
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self.client.models.generate_content(
|
response = self.client.models.generate_content(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
contents=query,
|
contents=query,
|
||||||
config=GenerateContentConfig(
|
config=GenerateContentConfig(
|
||||||
tools=[google_search_tool],
|
tools=[google_search_tool],
|
||||||
response_modalities=["TEXT"]
|
response_modalities=["TEXT"]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
result_text = ""
|
result_text = ""
|
||||||
for part in response.candidates[0].content.parts:
|
for part in response.candidates[0].content.parts:
|
||||||
if hasattr(part, 'text'):
|
if hasattr(part, 'text'):
|
||||||
result_text += part.text
|
result_text += part.text
|
||||||
|
|
||||||
return result_text
|
return result_text
|
||||||
|
|
||||||
|
|
||||||
class OpenAISearchProvider(SearchProvider):
|
class OpenAISearchProvider(SearchProvider):
|
||||||
def __init__(self, model: str, backend_url: str):
|
def __init__(self, model: str, backend_url: str):
|
||||||
self.client = OpenAI(base_url=backend_url)
|
self.client = OpenAI(base_url=backend_url)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def search(self, query: str) -> str:
|
def search(self, query: str) -> str:
|
||||||
response = self.client.responses.create(
|
response = self.client.responses.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
input=[
|
input=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "input_text",
|
"type": "input_text",
|
||||||
"text": query
|
"text": query
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
text={"format": {"type": "text"}},
|
text={"format": {"type": "text"}},
|
||||||
reasoning={},
|
reasoning={},
|
||||||
tools=[
|
tools=[
|
||||||
{
|
{
|
||||||
"type": "web_search_preview",
|
"type": "web_search_preview",
|
||||||
"user_location": {"type": "approximate"},
|
"user_location": {"type": "approximate"},
|
||||||
"search_context_size": "low",
|
"search_context_size": "low",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature=1,
|
temperature=1,
|
||||||
max_output_tokens=4096,
|
max_output_tokens=4096,
|
||||||
top_p=1,
|
top_p=1,
|
||||||
store=True,
|
store=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.output[1].content[0].text
|
return response.output[1].content[0].text
|
||||||
|
|
@ -1,47 +1,47 @@
|
||||||
from .search_provider import (
|
from .search_provider import (
|
||||||
SearchProvider,
|
SearchProvider,
|
||||||
GoogleSearchProvider,
|
GoogleSearchProvider,
|
||||||
OpenAISearchProvider
|
OpenAISearchProvider
|
||||||
)
|
)
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
class SearchProviderFactory:
|
class SearchProviderFactory:
|
||||||
_cache = {} # 클래스 레벨 캐시
|
_cache = {} # 클래스 레벨 캐시
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_provider(config: dict[str, any]) -> SearchProvider:
|
def create_provider(config: dict[str, any]) -> SearchProvider:
|
||||||
"""
|
"""
|
||||||
Create a SearchProvider with caching to avoid creating new instances.
|
Create a SearchProvider with caching to avoid creating new instances.
|
||||||
Uses config hash as cache key for efficient reuse.
|
Uses config hash as cache key for efficient reuse.
|
||||||
"""
|
"""
|
||||||
# Create cache key from relevant config values
|
# Create cache key from relevant config values
|
||||||
cache_key_data = {
|
cache_key_data = {
|
||||||
"backend_url": config["backend_url"],
|
"backend_url": config["backend_url"],
|
||||||
"model": config["quick_think_llm"]
|
"model": config["quick_think_llm"]
|
||||||
}
|
}
|
||||||
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
|
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
|
||||||
|
|
||||||
# Return cached instance if exists
|
# Return cached instance if exists
|
||||||
if cache_key in SearchProviderFactory._cache:
|
if cache_key in SearchProviderFactory._cache:
|
||||||
return SearchProviderFactory._cache[cache_key]
|
return SearchProviderFactory._cache[cache_key]
|
||||||
|
|
||||||
# Create new instance
|
# Create new instance
|
||||||
backend_url = config["backend_url"]
|
backend_url = config["backend_url"]
|
||||||
model = config["quick_think_llm"]
|
model = config["quick_think_llm"]
|
||||||
|
|
||||||
if "generativelanguage.googleapis.com" in backend_url:
|
if "generativelanguage.googleapis.com" in backend_url:
|
||||||
provider = GoogleSearchProvider(model)
|
provider = GoogleSearchProvider(model)
|
||||||
else:
|
else:
|
||||||
provider = OpenAISearchProvider(model, backend_url)
|
provider = OpenAISearchProvider(model, backend_url)
|
||||||
|
|
||||||
# Cache and return
|
# Cache and return
|
||||||
SearchProviderFactory._cache[cache_key] = provider
|
SearchProviderFactory._cache[cache_key] = provider
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def clear_cache():
|
def clear_cache():
|
||||||
"""Clear the provider cache (useful for testing or config changes)."""
|
"""Clear the provider cache (useful for testing or config changes)."""
|
||||||
SearchProviderFactory._cache.clear()
|
SearchProviderFactory._cache.clear()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,31 +1,31 @@
|
||||||
# TradingAgents/graph/signal_processing.py
|
# TradingAgents/graph/signal_processing.py
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
class SignalProcessor:
|
class SignalProcessor:
|
||||||
"""Processes trading signals to extract actionable decisions."""
|
"""Processes trading signals to extract actionable decisions."""
|
||||||
|
|
||||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
||||||
"""Initialize with an LLM for processing."""
|
"""Initialize with an LLM for processing."""
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
|
|
||||||
def process_signal(self, full_signal: str) -> str:
|
def process_signal(self, full_signal: str) -> str:
|
||||||
"""
|
"""
|
||||||
Process a full trading signal to extract the core decision.
|
Process a full trading signal to extract the core decision.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
full_signal: Complete trading signal text
|
full_signal: Complete trading signal text
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Extracted decision (BUY, SELL, or HOLD)
|
Extracted decision (BUY, SELL, or HOLD)
|
||||||
"""
|
"""
|
||||||
messages = [
|
messages = [
|
||||||
(
|
(
|
||||||
"system",
|
"system",
|
||||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.",
|
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.",
|
||||||
),
|
),
|
||||||
("human", full_signal),
|
("human", full_signal),
|
||||||
]
|
]
|
||||||
|
|
||||||
return self.quick_thinking_llm.invoke(messages).content
|
return self.quick_thinking_llm.invoke(messages).content
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue