dev/web
This commit is contained in:
parent
fbd96e9c18
commit
ab1b0120c2
|
|
@ -1,10 +1,11 @@
|
|||
env/
|
||||
__pycache__/
|
||||
.DS_Store
|
||||
*.csv
|
||||
src/
|
||||
eval_results/
|
||||
eval_data/
|
||||
*.egg-info/
|
||||
results/
|
||||
.env
|
||||
env/
|
||||
__pycache__/
|
||||
.DS_Store
|
||||
*.csv
|
||||
src/
|
||||
eval_results/
|
||||
eval_data/
|
||||
*.egg-info/
|
||||
results/
|
||||
.env
|
||||
tradingagents/dataflows/data_cache/
|
||||
|
|
@ -1 +1 @@
|
|||
3.10
|
||||
3.10
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: main.py",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/main.py",
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: main.py",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/main.py",
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
}
|
||||
430
README.md
430
README.md
|
|
@ -1,215 +1,215 @@
|
|||
<p align="center">
|
||||
<img src="assets/TauricResearch.png" style="width: 60%; height: auto;">
|
||||
</p>
|
||||
|
||||
<div align="center" style="line-height: 1;">
|
||||
<a href="https://arxiv.org/abs/2412.20138" target="_blank"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2412.20138-B31B1B?logo=arxiv"/></a>
|
||||
<a href="https://discord.com/invite/hk9PGKShPK" target="_blank"><img alt="Discord" src="https://img.shields.io/badge/Discord-TradingResearch-7289da?logo=discord&logoColor=white&color=7289da"/></a>
|
||||
<a href="./assets/wechat.png" target="_blank"><img alt="WeChat" src="https://img.shields.io/badge/WeChat-TauricResearch-brightgreen?logo=wechat&logoColor=white"/></a>
|
||||
<a href="https://x.com/TauricResearch" target="_blank"><img alt="X Follow" src="https://img.shields.io/badge/X-TauricResearch-white?logo=x&logoColor=white"/></a>
|
||||
<br>
|
||||
<a href="https://github.com/TauricResearch/" target="_blank"><img alt="Community" src="https://img.shields.io/badge/Join_GitHub_Community-TauricResearch-14C290?logo=discourse"/></a>
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=de">Deutsch</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=es">Español</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=fr">français</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ja">日本語</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ko">한국어</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=pt">Português</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ru">Русский</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=zh">中文</a>
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
# TradingAgents: Multi-Agents LLM Financial Trading Framework
|
||||
|
||||
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
|
||||
>
|
||||
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
|
||||
|
||||
<div align="center">
|
||||
<a href="https://www.star-history.com/#TauricResearch/TradingAgents&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" />
|
||||
<img alt="TradingAgents Star History" src="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" style="width: 80%; height: auto;" />
|
||||
</picture>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
|
||||
|
||||
</div>
|
||||
|
||||
## TradingAgents Framework
|
||||
|
||||
TradingAgents is a multi-agent trading framework that mirrors the dynamics of real-world trading firms. By deploying specialized LLM-powered agents: from fundamental analysts, sentiment experts, and technical analysts, to trader, risk management team, the platform collaboratively evaluates market conditions and informs trading decisions. Moreover, these agents engage in dynamic discussions to pinpoint the optimal strategy.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/schema.png" style="width: 100%; height: auto;">
|
||||
</p>
|
||||
|
||||
> TradingAgents framework is designed for research purposes. Trading performance may vary based on many factors, including the chosen backbone language models, model temperature, trading periods, the quality of data, and other non-deterministic factors. [It is not intended as financial, investment, or trading advice.](https://tauric.ai/disclaimer/)
|
||||
|
||||
Our framework decomposes complex trading tasks into specialized roles. This ensures the system achieves a robust, scalable approach to market analysis and decision-making.
|
||||
|
||||
### Analyst Team
|
||||
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
|
||||
- Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood.
|
||||
- News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions.
|
||||
- Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/analyst.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
### Researcher Team
|
||||
- Comprises both bullish and bearish researchers who critically assess the insights provided by the Analyst Team. Through structured debates, they balance potential gains against inherent risks.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/researcher.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
### Trader Agent
|
||||
- Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
### Risk Management and Portfolio Manager
|
||||
- Continuously evaluates portfolio risk by assessing market volatility, liquidity, and other risk factors. The risk management team evaluates and adjusts trading strategies, providing assessment reports to the Portfolio Manager for final decision.
|
||||
- The Portfolio Manager approves/rejects the transaction proposal. If approved, the order will be sent to the simulated exchange and executed.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/risk.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
## Installation and CLI
|
||||
|
||||
### Installation
|
||||
|
||||
Clone TradingAgents:
|
||||
```bash
|
||||
git clone https://github.com/TauricResearch/TradingAgents.git
|
||||
cd TradingAgents
|
||||
```
|
||||
|
||||
Create a virtual environment in any of your favorite environment managers:
|
||||
```bash
|
||||
conda create -n tradingagents python=3.13
|
||||
conda activate tradingagents
|
||||
```
|
||||
|
||||
Install dependencies:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Required APIs
|
||||
|
||||
You will also need the FinnHub API for financial data. All of our code is implemented with the free tier.
|
||||
```bash
|
||||
export FINNHUB_API_KEY=$YOUR_FINNHUB_API_KEY
|
||||
```
|
||||
|
||||
You will need the OpenAI API or GEMINI API for all the agents.
|
||||
```bash
|
||||
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
|
||||
export GEMINI_API_KEY=$YOUR_GEMINI_API_KEY
|
||||
export GOOGLE_API_KEY=$YOUR_GEMINI_API_KEY
|
||||
```
|
||||
|
||||
### CLI Usage
|
||||
|
||||
You can also try out the CLI directly by running:
|
||||
```bash
|
||||
python -m cli.main
|
||||
```
|
||||
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
An interface will appear showing results as they load, letting you track the agent's progress as it runs.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/cli/cli_news.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/cli/cli_transaction.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
## TradingAgents Package
|
||||
|
||||
### Implementation Details
|
||||
|
||||
We built TradingAgents with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
|
||||
|
||||
### Python Usage
|
||||
|
||||
To use TradingAgents inside your code, you can import the `tradingagents` module and initialize a `TradingAgentsGraph()` object. The `.propagate()` function will return a decision. You can run `main.py`, here's also a quick example:
|
||||
|
||||
```python
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy())
|
||||
|
||||
# forward propagate
|
||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||
print(decision)
|
||||
```
|
||||
|
||||
You can also adjust the default configuration to set your own choice of LLMs, debate rounds, etc.
|
||||
|
||||
```python
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
# Create a custom config
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
|
||||
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
|
||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||
config["online_tools"] = True # Use online tools or cached data
|
||||
|
||||
# Initialize with custom config
|
||||
ta = TradingAgentsGraph(debug=True, config=config)
|
||||
|
||||
# forward propagate
|
||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||
print(decision)
|
||||
```
|
||||
|
||||
> For `online_tools`, we recommend enabling them for experimentation, as they provide access to real-time data. The agents' offline tools rely on cached data from our **Tauric TradingDB**, a curated dataset we use for backtesting. We're currently in the process of refining this dataset, and we plan to release it soon alongside our upcoming projects. Stay tuned!
|
||||
|
||||
You can view the full list of configurations in `tradingagents/default_config.py`.
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
||||
|
||||
## Citation
|
||||
|
||||
Please reference our work if you find *TradingAgents* provides you with some help :)
|
||||
|
||||
```
|
||||
@misc{xiao2025tradingagentsmultiagentsllmfinancial,
|
||||
title={TradingAgents: Multi-Agents LLM Financial Trading Framework},
|
||||
author={Yijia Xiao and Edward Sun and Di Luo and Wei Wang},
|
||||
year={2025},
|
||||
eprint={2412.20138},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={q-fin.TR},
|
||||
url={https://arxiv.org/abs/2412.20138},
|
||||
}
|
||||
```
|
||||
<p align="center">
|
||||
<img src="assets/TauricResearch.png" style="width: 60%; height: auto;">
|
||||
</p>
|
||||
|
||||
<div align="center" style="line-height: 1;">
|
||||
<a href="https://arxiv.org/abs/2412.20138" target="_blank"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2412.20138-B31B1B?logo=arxiv"/></a>
|
||||
<a href="https://discord.com/invite/hk9PGKShPK" target="_blank"><img alt="Discord" src="https://img.shields.io/badge/Discord-TradingResearch-7289da?logo=discord&logoColor=white&color=7289da"/></a>
|
||||
<a href="./assets/wechat.png" target="_blank"><img alt="WeChat" src="https://img.shields.io/badge/WeChat-TauricResearch-brightgreen?logo=wechat&logoColor=white"/></a>
|
||||
<a href="https://x.com/TauricResearch" target="_blank"><img alt="X Follow" src="https://img.shields.io/badge/X-TauricResearch-white?logo=x&logoColor=white"/></a>
|
||||
<br>
|
||||
<a href="https://github.com/TauricResearch/" target="_blank"><img alt="Community" src="https://img.shields.io/badge/Join_GitHub_Community-TauricResearch-14C290?logo=discourse"/></a>
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=de">Deutsch</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=es">Español</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=fr">français</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ja">日本語</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ko">한국어</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=pt">Português</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ru">Русский</a> |
|
||||
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=zh">中文</a>
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
# TradingAgents: Multi-Agents LLM Financial Trading Framework
|
||||
|
||||
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
|
||||
>
|
||||
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
|
||||
|
||||
<div align="center">
|
||||
<a href="https://www.star-history.com/#TauricResearch/TradingAgents&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" />
|
||||
<img alt="TradingAgents Star History" src="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" style="width: 80%; height: auto;" />
|
||||
</picture>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
|
||||
|
||||
</div>
|
||||
|
||||
## TradingAgents Framework
|
||||
|
||||
TradingAgents is a multi-agent trading framework that mirrors the dynamics of real-world trading firms. By deploying specialized LLM-powered agents: from fundamental analysts, sentiment experts, and technical analysts, to trader, risk management team, the platform collaboratively evaluates market conditions and informs trading decisions. Moreover, these agents engage in dynamic discussions to pinpoint the optimal strategy.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/schema.png" style="width: 100%; height: auto;">
|
||||
</p>
|
||||
|
||||
> TradingAgents framework is designed for research purposes. Trading performance may vary based on many factors, including the chosen backbone language models, model temperature, trading periods, the quality of data, and other non-deterministic factors. [It is not intended as financial, investment, or trading advice.](https://tauric.ai/disclaimer/)
|
||||
|
||||
Our framework decomposes complex trading tasks into specialized roles. This ensures the system achieves a robust, scalable approach to market analysis and decision-making.
|
||||
|
||||
### Analyst Team
|
||||
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
|
||||
- Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood.
|
||||
- News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions.
|
||||
- Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/analyst.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
### Researcher Team
|
||||
- Comprises both bullish and bearish researchers who critically assess the insights provided by the Analyst Team. Through structured debates, they balance potential gains against inherent risks.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/researcher.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
### Trader Agent
|
||||
- Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
### Risk Management and Portfolio Manager
|
||||
- Continuously evaluates portfolio risk by assessing market volatility, liquidity, and other risk factors. The risk management team evaluates and adjusts trading strategies, providing assessment reports to the Portfolio Manager for final decision.
|
||||
- The Portfolio Manager approves/rejects the transaction proposal. If approved, the order will be sent to the simulated exchange and executed.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/risk.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
## Installation and CLI
|
||||
|
||||
### Installation
|
||||
|
||||
Clone TradingAgents:
|
||||
```bash
|
||||
git clone https://github.com/TauricResearch/TradingAgents.git
|
||||
cd TradingAgents
|
||||
```
|
||||
|
||||
Create a virtual environment in any of your favorite environment managers:
|
||||
```bash
|
||||
conda create -n tradingagents python=3.13
|
||||
conda activate tradingagents
|
||||
```
|
||||
|
||||
Install dependencies:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Required APIs
|
||||
|
||||
You will also need the FinnHub API for financial data. All of our code is implemented with the free tier.
|
||||
```bash
|
||||
export FINNHUB_API_KEY=$YOUR_FINNHUB_API_KEY
|
||||
```
|
||||
|
||||
You will need the OpenAI API or GEMINI API for all the agents.
|
||||
```bash
|
||||
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
|
||||
export GEMINI_API_KEY=$YOUR_GEMINI_API_KEY
|
||||
export GOOGLE_API_KEY=$YOUR_GEMINI_API_KEY
|
||||
```
|
||||
|
||||
### CLI Usage
|
||||
|
||||
You can also try out the CLI directly by running:
|
||||
```bash
|
||||
python -m cli.main
|
||||
```
|
||||
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
An interface will appear showing results as they load, letting you track the agent's progress as it runs.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/cli/cli_news.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/cli/cli_transaction.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||
</p>
|
||||
|
||||
## TradingAgents Package
|
||||
|
||||
### Implementation Details
|
||||
|
||||
We built TradingAgents with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
|
||||
|
||||
### Python Usage
|
||||
|
||||
To use TradingAgents inside your code, you can import the `tradingagents` module and initialize a `TradingAgentsGraph()` object. The `.propagate()` function will return a decision. You can run `main.py`, here's also a quick example:
|
||||
|
||||
```python
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy())
|
||||
|
||||
# forward propagate
|
||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||
print(decision)
|
||||
```
|
||||
|
||||
You can also adjust the default configuration to set your own choice of LLMs, debate rounds, etc.
|
||||
|
||||
```python
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
# Create a custom config
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
|
||||
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
|
||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||
config["online_tools"] = True # Use online tools or cached data
|
||||
|
||||
# Initialize with custom config
|
||||
ta = TradingAgentsGraph(debug=True, config=config)
|
||||
|
||||
# forward propagate
|
||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||
print(decision)
|
||||
```
|
||||
|
||||
> For `online_tools`, we recommend enabling them for experimentation, as they provide access to real-time data. The agents' offline tools rely on cached data from our **Tauric TradingDB**, a curated dataset we use for backtesting. We're currently in the process of refining this dataset, and we plan to release it soon alongside our upcoming projects. Stay tuned!
|
||||
|
||||
You can view the full list of configurations in `tradingagents/default_config.py`.
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
||||
|
||||
## Citation
|
||||
|
||||
Please reference our work if you find *TradingAgents* provides you with some help :)
|
||||
|
||||
```
|
||||
@misc{xiao2025tradingagentsmultiagentsllmfinancial,
|
||||
title={TradingAgents: Multi-Agents LLM Financial Trading Framework},
|
||||
author={Yijia Xiao and Edward Sun and Di Luo and Wei Wang},
|
||||
year={2025},
|
||||
eprint={2412.20138},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={q-fin.TR},
|
||||
url={https://arxiv.org/abs/2412.20138},
|
||||
}
|
||||
```
|
||||
|
|
|
|||
134
app/api/deps.py
134
app/api/deps.py
|
|
@ -1,67 +1,67 @@
|
|||
from typing import Generator, Optional
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import jwt, JWTError
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.infrastructure.database import get_db
|
||||
from app.domain.models import User
|
||||
from app.infrastructure.repositories.user import UserRepository
|
||||
from app.core.services.trading_analysis import TradingAnalysisService
|
||||
|
||||
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: Optional[str] = None
|
||||
|
||||
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
|
||||
return UserRepository(db)
|
||||
|
||||
def get_user_from_token(token: str, db: Session) -> Optional[User]:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
token_data = TokenData(username=payload.get("sub"))
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
user_repo = UserRepository(db)
|
||||
user = user_repo.get_by_email(email=token_data.username)
|
||||
return user
|
||||
|
||||
def get_current_user(
|
||||
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
|
||||
) -> User:
|
||||
user = get_user_from_token(token=token, db=db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
def get_current_active_superuser(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> User:
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="The user doesn't have enough privileges"
|
||||
)
|
||||
return current_user
|
||||
|
||||
def get_analysis_service(
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(get_current_active_user)
|
||||
) -> TradingAnalysisService:
|
||||
return TradingAnalysisService(user=user, db=db)
|
||||
from typing import Generator, Optional
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import jwt, JWTError
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.infrastructure.database import get_db
|
||||
from app.domain.models import User
|
||||
from app.infrastructure.repositories.user import UserRepository
|
||||
from app.core.services.trading_analysis import TradingAnalysisService
|
||||
|
||||
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: Optional[str] = None
|
||||
|
||||
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
|
||||
return UserRepository(db)
|
||||
|
||||
def get_user_from_token(token: str, db: Session) -> Optional[User]:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
token_data = TokenData(username=payload.get("sub"))
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
user_repo = UserRepository(db)
|
||||
user = user_repo.get_by_email(email=token_data.username)
|
||||
return user
|
||||
|
||||
def get_current_user(
|
||||
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
|
||||
) -> User:
|
||||
user = get_user_from_token(token=token, db=db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
def get_current_active_superuser(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> User:
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="The user doesn't have enough privileges"
|
||||
)
|
||||
return current_user
|
||||
|
||||
def get_analysis_service(
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(get_current_active_user)
|
||||
) -> TradingAnalysisService:
|
||||
return TradingAnalysisService(user=user, db=db)
|
||||
|
|
|
|||
|
|
@ -1,94 +1,94 @@
|
|||
from typing import Any, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
|
||||
from app.api import deps
|
||||
from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate
|
||||
from app.domain.models import User as UserModel
|
||||
from app.core.services.trading_analysis import TradingAnalysisService
|
||||
from app.core.websocket_manager import WebSocketManager
|
||||
from sqlmodel import Session
|
||||
from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS
|
||||
|
||||
router = APIRouter()
|
||||
manager = WebSocketManager()
|
||||
|
||||
@router.post("/start", response_model=AnalysisSession)
|
||||
def start_analysis(
|
||||
*,
|
||||
analysis_in: AnalysisSessionCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
||||
) -> Any:
|
||||
"""
|
||||
Start a new analysis session.
|
||||
"""
|
||||
session = service.create_session(analysis_in=analysis_in)
|
||||
background_tasks.add_task(service.run_analysis, session_id=session.id)
|
||||
return session
|
||||
|
||||
@router.get("/history", response_model=List[AnalysisSession])
|
||||
def get_analysis_history(
|
||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> Any:
|
||||
"""
|
||||
Get analysis history for the current user.
|
||||
"""
|
||||
return service.get_user_sessions(skip=skip, limit=limit)
|
||||
|
||||
@router.get("/options")
|
||||
def get_analysis_options():
|
||||
"""
|
||||
Get available options for analysis.
|
||||
"""
|
||||
return {
|
||||
'analysts': [
|
||||
{'value': 'market', 'label': 'Market Analyst'},
|
||||
{'value': 'social', 'label': 'Social Analyst'},
|
||||
{'value': 'news', 'label': 'News Analyst'},
|
||||
{'value': 'fundamentals', 'label': 'Fundamentals Analyst'},
|
||||
],
|
||||
'research_depths': [
|
||||
{'value': 1, 'label': 'Shallow'},
|
||||
{'value': 3, 'label': 'Medium'},
|
||||
{'value': 5, 'label': 'Deep'},
|
||||
],
|
||||
'llm_providers': [{'name': p[0], 'url': p[1]} for p in BASE_URLS],
|
||||
'shallow_thinkers': SHALLOW_AGENT_OPTIONS,
|
||||
'deep_thinkers': DEEP_AGENT_OPTIONS,
|
||||
}
|
||||
|
||||
@router.get("/{session_id}", response_model=AnalysisSession)
|
||||
def get_analysis_session(
|
||||
session_id: int,
|
||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
||||
) -> Any:
|
||||
"""
|
||||
Get a specific analysis session by ID.
|
||||
"""
|
||||
session = service.get_session(session_id=session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Analysis session not found")
|
||||
return session
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
token: str,
|
||||
db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time analysis updates.
|
||||
"""
|
||||
user = deps.get_user_from_token(token=token, db=db)
|
||||
if not user or not user.is_active:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
await manager.connect(user.id, websocket)
|
||||
try:
|
||||
while True:
|
||||
# Keep the connection alive
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
from typing import Any, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
|
||||
from app.api import deps
|
||||
from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate
|
||||
from app.domain.models import User as UserModel
|
||||
from app.core.services.trading_analysis import TradingAnalysisService
|
||||
from app.core.websocket_manager import WebSocketManager
|
||||
from sqlmodel import Session
|
||||
from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS
|
||||
|
||||
router = APIRouter()
|
||||
manager = WebSocketManager()
|
||||
|
||||
@router.post("/start", response_model=AnalysisSession)
|
||||
def start_analysis(
|
||||
*,
|
||||
analysis_in: AnalysisSessionCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
||||
) -> Any:
|
||||
"""
|
||||
Start a new analysis session.
|
||||
"""
|
||||
session = service.create_session(analysis_in=analysis_in)
|
||||
background_tasks.add_task(service.run_analysis, session_id=session.id)
|
||||
return session
|
||||
|
||||
@router.get("/history", response_model=List[AnalysisSession])
|
||||
def get_analysis_history(
|
||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> Any:
|
||||
"""
|
||||
Get analysis history for the current user.
|
||||
"""
|
||||
return service.get_user_sessions(skip=skip, limit=limit)
|
||||
|
||||
@router.get("/options")
|
||||
def get_analysis_options():
|
||||
"""
|
||||
Get available options for analysis.
|
||||
"""
|
||||
return {
|
||||
'analysts': [
|
||||
{'value': 'market', 'label': 'Market Analyst'},
|
||||
{'value': 'social', 'label': 'Social Analyst'},
|
||||
{'value': 'news', 'label': 'News Analyst'},
|
||||
{'value': 'fundamentals', 'label': 'Fundamentals Analyst'},
|
||||
],
|
||||
'research_depths': [
|
||||
{'value': 1, 'label': 'Shallow'},
|
||||
{'value': 3, 'label': 'Medium'},
|
||||
{'value': 5, 'label': 'Deep'},
|
||||
],
|
||||
'llm_providers': [{'name': p[0], 'url': p[1]} for p in BASE_URLS],
|
||||
'shallow_thinkers': SHALLOW_AGENT_OPTIONS,
|
||||
'deep_thinkers': DEEP_AGENT_OPTIONS,
|
||||
}
|
||||
|
||||
@router.get("/{session_id}", response_model=AnalysisSession)
|
||||
def get_analysis_session(
|
||||
session_id: int,
|
||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
||||
) -> Any:
|
||||
"""
|
||||
Get a specific analysis session by ID.
|
||||
"""
|
||||
session = service.get_session(session_id=session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Analysis session not found")
|
||||
return session
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
token: str,
|
||||
db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time analysis updates.
|
||||
"""
|
||||
user = deps.get_user_from_token(token=token, db=db)
|
||||
if not user or not user.is_active:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
await manager.connect(user.id, websocket)
|
||||
try:
|
||||
while True:
|
||||
# Keep the connection alive
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(user.id, websocket)
|
||||
|
|
@ -1,35 +1,35 @@
|
|||
from datetime import timedelta
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.api import deps
|
||||
from app.core.config import settings
|
||||
from app.core.schemas.token import Token
|
||||
from app.core import security
|
||||
from app.infrastructure.repositories.user import UserRepository
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/login/access-token", response_model=Token)
|
||||
def login_access_token(
|
||||
db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()
|
||||
):
|
||||
"""
|
||||
OAuth2 compatible token login, get an access token for future requests
|
||||
"""
|
||||
user_repo = UserRepository(db)
|
||||
user = user_repo.get_by_email(email=form_data.username)
|
||||
|
||||
if not user or not security.verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="Incorrect email or password")
|
||||
elif not user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
return {
|
||||
"access_token": security.create_access_token(
|
||||
user.email, expires_delta=access_token_expires
|
||||
),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
from datetime import timedelta
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.api import deps
|
||||
from app.core.config import settings
|
||||
from app.core.schemas.token import Token
|
||||
from app.core import security
|
||||
from app.infrastructure.repositories.user import UserRepository
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/login/access-token", response_model=Token)
|
||||
def login_access_token(
|
||||
db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()
|
||||
):
|
||||
"""
|
||||
OAuth2 compatible token login, get an access token for future requests
|
||||
"""
|
||||
user_repo = UserRepository(db)
|
||||
user = user_repo.get_by_email(email=form_data.username)
|
||||
|
||||
if not user or not security.verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="Incorrect email or password")
|
||||
elif not user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
return {
|
||||
"access_token": security.create_access_token(
|
||||
user.email, expires_delta=access_token_expires
|
||||
),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,89 +1,89 @@
|
|||
from typing import Any, List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from app.api import deps
|
||||
from app.core.schemas.user import User, UserCreate, UserUpdate
|
||||
from app.domain.models import User as UserModel
|
||||
from app.domain.repositories import IUserRepository
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=List[User])
|
||||
def read_users(
|
||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve users.
|
||||
"""
|
||||
users = repo.get_multi(skip=skip, limit=limit)
|
||||
return users
|
||||
|
||||
@router.post("/", response_model=User)
|
||||
def create_user(
|
||||
*,
|
||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||
user_in: UserCreate,
|
||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
Create new user.
|
||||
"""
|
||||
user = repo.get_by_email(email=user_in.email)
|
||||
if user:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="The user with this username already exists in the system.",
|
||||
)
|
||||
user = repo.create(obj_in=user_in)
|
||||
return user
|
||||
|
||||
@router.get("/me", response_model=User)
|
||||
def read_user_me(
|
||||
current_user: UserModel = Depends(deps.get_current_active_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Get current user.
|
||||
"""
|
||||
return current_user
|
||||
|
||||
@router.get("/{user_id}", response_model=User)
|
||||
def read_user_by_id(
|
||||
user_id: int,
|
||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||
current_user: UserModel = Depends(deps.get_current_active_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Get a specific user by id.
|
||||
"""
|
||||
user = repo.get(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if user == current_user:
|
||||
return user
|
||||
if not repo.is_superuser(user=current_user):
|
||||
raise HTTPException(
|
||||
status_code=403, detail="The user doesn't have enough privileges"
|
||||
)
|
||||
return user
|
||||
|
||||
@router.put("/{user_id}", response_model=User)
|
||||
def update_user(
|
||||
*,
|
||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||
user_id: int,
|
||||
user_in: UserUpdate,
|
||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
Update a user.
|
||||
"""
|
||||
user = repo.get(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="The user with this username does not exist in the system",
|
||||
)
|
||||
user = repo.update(db_obj=user, obj_in=user_in)
|
||||
return user
|
||||
from typing import Any, List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from app.api import deps
|
||||
from app.core.schemas.user import User, UserCreate, UserUpdate
|
||||
from app.domain.models import User as UserModel
|
||||
from app.domain.repositories import IUserRepository
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=List[User])
|
||||
def read_users(
|
||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve users.
|
||||
"""
|
||||
users = repo.get_multi(skip=skip, limit=limit)
|
||||
return users
|
||||
|
||||
@router.post("/", response_model=User)
|
||||
def create_user(
|
||||
*,
|
||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||
user_in: UserCreate,
|
||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
Create new user.
|
||||
"""
|
||||
user = repo.get_by_email(email=user_in.email)
|
||||
if user:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="The user with this username already exists in the system.",
|
||||
)
|
||||
user = repo.create(obj_in=user_in)
|
||||
return user
|
||||
|
||||
@router.get("/me", response_model=User)
|
||||
def read_user_me(
|
||||
current_user: UserModel = Depends(deps.get_current_active_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Get current user.
|
||||
"""
|
||||
return current_user
|
||||
|
||||
@router.get("/{user_id}", response_model=User)
|
||||
def read_user_by_id(
|
||||
user_id: int,
|
||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||
current_user: UserModel = Depends(deps.get_current_active_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Get a specific user by id.
|
||||
"""
|
||||
user = repo.get(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if user == current_user:
|
||||
return user
|
||||
if not repo.is_superuser(user=current_user):
|
||||
raise HTTPException(
|
||||
status_code=403, detail="The user doesn't have enough privileges"
|
||||
)
|
||||
return user
|
||||
|
||||
@router.put("/{user_id}", response_model=User)
|
||||
def update_user(
|
||||
*,
|
||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
||||
user_id: int,
|
||||
user_in: UserUpdate,
|
||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
Update a user.
|
||||
"""
|
||||
user = repo.get(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="The user with this username does not exist in the system",
|
||||
)
|
||||
user = repo.update(db_obj=user, obj_in=user_in)
|
||||
return user
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from fastapi import APIRouter
|
||||
from app.api.endpoints import login, users, analysis
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, tags=["login"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"])
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import login, users, analysis
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, tags=["login"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"])
|
||||
|
|
|
|||
|
|
@ -1,26 +1,26 @@
|
|||
import os
|
||||
from pydantic import BaseSettings
|
||||
from typing import List, Optional
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "TradingAgents Backend"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# Security
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
|
||||
ALGORITHM: str = "HS256"
|
||||
|
||||
# Database
|
||||
DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db")
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
|
||||
# CORS
|
||||
CORS_ALLOWED_ORIGINS: List[str] = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(',')
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
|
||||
import os
|
||||
from pydantic import BaseSettings
|
||||
from typing import List, Optional
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "TradingAgents Backend"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# Security
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
|
||||
ALGORITHM: str = "HS256"
|
||||
|
||||
# Database
|
||||
DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db")
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
|
||||
# CORS
|
||||
CORS_ALLOWED_ORIGINS: List[str] = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(',')
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
|
||||
settings = Settings()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from .user import User, UserCreate, UserUpdate
|
||||
from .token import Token, TokenPayload
|
||||
from .profile import Profile, ProfileCreate, ProfileUpdate
|
||||
from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate
|
||||
from .user import User, UserCreate, UserUpdate
|
||||
from .token import Token, TokenPayload
|
||||
from .profile import Profile, ProfileCreate, ProfileUpdate
|
||||
from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate
|
||||
|
|
|
|||
|
|
@ -1,38 +1,38 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from datetime import date, datetime
|
||||
from app.domain.models import AnalysisStatus
|
||||
|
||||
class AnalysisSessionBase(BaseModel):
|
||||
ticker: str
|
||||
analysts_selected: List[str]
|
||||
research_depth: int
|
||||
llm_provider: str
|
||||
backend_url: str
|
||||
shallow_thinker: str
|
||||
deep_thinker: str
|
||||
|
||||
class AnalysisSessionCreate(AnalysisSessionBase):
|
||||
pass
|
||||
|
||||
class AnalysisSessionUpdate(BaseModel):
|
||||
status: Optional[AnalysisStatus] = None
|
||||
final_report: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
class AnalysisSessionInDBBase(AnalysisSessionBase):
|
||||
id: int
|
||||
user_id: int
|
||||
analysis_date: date
|
||||
status: AnalysisStatus
|
||||
final_report: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class AnalysisSession(AnalysisSessionInDBBase):
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from datetime import date, datetime
|
||||
from app.domain.models import AnalysisStatus
|
||||
|
||||
class AnalysisSessionBase(BaseModel):
|
||||
ticker: str
|
||||
analysts_selected: List[str]
|
||||
research_depth: int
|
||||
llm_provider: str
|
||||
backend_url: str
|
||||
shallow_thinker: str
|
||||
deep_thinker: str
|
||||
|
||||
class AnalysisSessionCreate(AnalysisSessionBase):
|
||||
pass
|
||||
|
||||
class AnalysisSessionUpdate(BaseModel):
|
||||
status: Optional[AnalysisStatus] = None
|
||||
final_report: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
class AnalysisSessionInDBBase(AnalysisSessionBase):
|
||||
id: int
|
||||
user_id: int
|
||||
analysis_date: date
|
||||
status: AnalysisStatus
|
||||
final_report: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class AnalysisSession(AnalysisSessionInDBBase):
|
||||
pass
|
||||
|
|
@ -1,20 +1,20 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class ProfileBase(BaseModel):
|
||||
default_ticker: str = "SPY"
|
||||
preferred_research_depth: int = 3
|
||||
preferred_shallow_thinker: str = "gpt-4o-mini"
|
||||
preferred_deep_thinker: str = "gpt-4o"
|
||||
|
||||
class ProfileCreate(ProfileBase):
|
||||
pass
|
||||
|
||||
class ProfileUpdate(ProfileBase):
|
||||
openai_api_key: Optional[str] = None
|
||||
|
||||
class Profile(ProfileBase):
|
||||
has_openai_api_key: bool
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class ProfileBase(BaseModel):
|
||||
default_ticker: str = "SPY"
|
||||
preferred_research_depth: int = 3
|
||||
preferred_shallow_thinker: str = "gpt-4o-mini"
|
||||
preferred_deep_thinker: str = "gpt-4o"
|
||||
|
||||
class ProfileCreate(ProfileBase):
|
||||
pass
|
||||
|
||||
class ProfileUpdate(ProfileBase):
|
||||
openai_api_key: Optional[str] = None
|
||||
|
||||
class Profile(ProfileBase):
|
||||
has_openai_api_key: bool
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: Optional[int] = None
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: Optional[int] = None
|
||||
|
|
|
|||
|
|
@ -1,28 +1,28 @@
|
|||
from pydantic import BaseModel, EmailStr
|
||||
from typing import Optional
|
||||
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
username: str
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
|
||||
class UserUpdate(UserBase):
|
||||
pass
|
||||
|
||||
class UserInDBBase(UserBase):
|
||||
id: int
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class User(UserInDBBase):
|
||||
pass
|
||||
|
||||
class UserInDB(UserInDBBase):
|
||||
hashed_password: str
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from typing import Optional
|
||||
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
username: str
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
|
||||
class UserUpdate(UserBase):
|
||||
pass
|
||||
|
||||
class UserInDBBase(UserBase):
|
||||
id: int
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class User(UserInDBBase):
|
||||
pass
|
||||
|
||||
class UserInDB(UserInDBBase):
|
||||
hashed_password: str
|
||||
|
|
|
|||
|
|
@ -1,23 +1,23 @@
|
|||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union, Optional
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
to_encode = {"exp": expire, "sub": str(subject)}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union, Optional
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
to_encode = {"exp": expire, "sub": str(subject)}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
|
|
|||
|
|
@ -1,128 +1,128 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from sqlmodel import Session, select
|
||||
from app.domain.models import User, AnalysisSession, AnalysisStatus
|
||||
from app.core.schemas.analysis import AnalysisSessionCreate
|
||||
from app.core.config import settings
|
||||
from cli.models import AnalystType
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from app.api.deps import get_db
|
||||
from app.core.websocket_manager import WebSocketManager
|
||||
|
||||
class TradingAnalysisService:
|
||||
def __init__(self, user: User, db: Session):
|
||||
self.user = user
|
||||
self.db = db
|
||||
self.websocket_manager = WebSocketManager()
|
||||
|
||||
async def run_analysis(self, session_id: int):
|
||||
"""분석 실행"""
|
||||
session = self.get_session(session_id=session_id)
|
||||
if not session:
|
||||
return
|
||||
|
||||
try:
|
||||
session.status = AnalysisStatus.RUNNING
|
||||
session.started_at = datetime.datetime.utcnow()
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
|
||||
await self.websocket_manager.send_to_user(
|
||||
self.user.id,
|
||||
{
|
||||
'type': 'analysis_started',
|
||||
'session_id': session.id,
|
||||
'message': '분석을 시작합니다...'
|
||||
}
|
||||
)
|
||||
|
||||
# Prepare config for TradingAgentsGraph
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config.update({
|
||||
'openai_api_key': settings.OPENAI_API_KEY,
|
||||
'llm_provider': session.llm_provider,
|
||||
'backend_url': session.backend_url,
|
||||
'shallow_thinking_model': session.shallow_thinker,
|
||||
'deep_thinking_model': session.deep_thinker,
|
||||
})
|
||||
|
||||
# Progress callback for websocket
|
||||
async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0):
|
||||
progress_percent = int((step / total) * 99) if total > 0 else 0
|
||||
await self.websocket_manager.send_to_user(self.user.id, {
|
||||
'type': 'analysis_progress',
|
||||
'session_id': session.id,
|
||||
'message_type': message_type,
|
||||
'content': content,
|
||||
'agent': agent,
|
||||
'progress': progress_percent,
|
||||
})
|
||||
|
||||
trading_graph = TradingAgentsGraph(
|
||||
config=config,
|
||||
selected_analysts=session.analysts_selected,
|
||||
)
|
||||
|
||||
input_data = {
|
||||
'company_of_interest': session.ticker,
|
||||
'trade_date': session.analysis_date.strftime('%Y-%m-%d'),
|
||||
}
|
||||
|
||||
final_state, result = await asyncio.to_thread(
|
||||
trading_graph.propagate,
|
||||
input_data['company_of_interest'],
|
||||
input_data['trade_date']
|
||||
)
|
||||
|
||||
session.status = AnalysisStatus.COMPLETED
|
||||
session.completed_at = datetime.datetime.utcnow()
|
||||
session.final_report = json.dumps(final_state) # Store full state as JSON
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
|
||||
await self.websocket_manager.send_to_user(
|
||||
self.user.id,
|
||||
{
|
||||
'type': 'analysis_completed',
|
||||
'session_id': session.id,
|
||||
'message': '분석이 완료되었습니다.',
|
||||
'result': result
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
session.status = AnalysisStatus.FAILED
|
||||
session.error_message = str(e)
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
await self.websocket_manager.send_to_user(
|
||||
self.user.id,
|
||||
{
|
||||
'type': 'analysis_failed',
|
||||
'session_id': session.id,
|
||||
'message': f'분석 중 오류가 발생했습니다: {str(e)}'
|
||||
}
|
||||
)
|
||||
|
||||
def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
|
||||
session = AnalysisSession(
|
||||
**analysis_in.dict(),
|
||||
user_id=self.user.id,
|
||||
analysis_date=datetime.date.today()
|
||||
)
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
return session
|
||||
|
||||
def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
|
||||
statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id)
|
||||
return self.db.exec(statement).first()
|
||||
|
||||
def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
|
||||
statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit)
|
||||
return self.db.exec(statement).all()
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from sqlmodel import Session, select
|
||||
from app.domain.models import User, AnalysisSession, AnalysisStatus
|
||||
from app.core.schemas.analysis import AnalysisSessionCreate
|
||||
from app.core.config import settings
|
||||
from cli.models import AnalystType
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from app.api.deps import get_db
|
||||
from app.core.websocket_manager import WebSocketManager
|
||||
|
||||
class TradingAnalysisService:
|
||||
def __init__(self, user: User, db: Session):
|
||||
self.user = user
|
||||
self.db = db
|
||||
self.websocket_manager = WebSocketManager()
|
||||
|
||||
async def run_analysis(self, session_id: int):
|
||||
"""분석 실행"""
|
||||
session = self.get_session(session_id=session_id)
|
||||
if not session:
|
||||
return
|
||||
|
||||
try:
|
||||
session.status = AnalysisStatus.RUNNING
|
||||
session.started_at = datetime.datetime.utcnow()
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
|
||||
await self.websocket_manager.send_to_user(
|
||||
self.user.id,
|
||||
{
|
||||
'type': 'analysis_started',
|
||||
'session_id': session.id,
|
||||
'message': '분석을 시작합니다...'
|
||||
}
|
||||
)
|
||||
|
||||
# Prepare config for TradingAgentsGraph
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config.update({
|
||||
'openai_api_key': settings.OPENAI_API_KEY,
|
||||
'llm_provider': session.llm_provider,
|
||||
'backend_url': session.backend_url,
|
||||
'shallow_thinking_model': session.shallow_thinker,
|
||||
'deep_thinking_model': session.deep_thinker,
|
||||
})
|
||||
|
||||
# Progress callback for websocket
|
||||
async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0):
|
||||
progress_percent = int((step / total) * 99) if total > 0 else 0
|
||||
await self.websocket_manager.send_to_user(self.user.id, {
|
||||
'type': 'analysis_progress',
|
||||
'session_id': session.id,
|
||||
'message_type': message_type,
|
||||
'content': content,
|
||||
'agent': agent,
|
||||
'progress': progress_percent,
|
||||
})
|
||||
|
||||
trading_graph = TradingAgentsGraph(
|
||||
config=config,
|
||||
selected_analysts=session.analysts_selected,
|
||||
)
|
||||
|
||||
input_data = {
|
||||
'company_of_interest': session.ticker,
|
||||
'trade_date': session.analysis_date.strftime('%Y-%m-%d'),
|
||||
}
|
||||
|
||||
final_state, result = await asyncio.to_thread(
|
||||
trading_graph.propagate,
|
||||
input_data['company_of_interest'],
|
||||
input_data['trade_date']
|
||||
)
|
||||
|
||||
session.status = AnalysisStatus.COMPLETED
|
||||
session.completed_at = datetime.datetime.utcnow()
|
||||
session.final_report = json.dumps(final_state) # Store full state as JSON
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
|
||||
await self.websocket_manager.send_to_user(
|
||||
self.user.id,
|
||||
{
|
||||
'type': 'analysis_completed',
|
||||
'session_id': session.id,
|
||||
'message': '분석이 완료되었습니다.',
|
||||
'result': result
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
session.status = AnalysisStatus.FAILED
|
||||
session.error_message = str(e)
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
await self.websocket_manager.send_to_user(
|
||||
self.user.id,
|
||||
{
|
||||
'type': 'analysis_failed',
|
||||
'session_id': session.id,
|
||||
'message': f'분석 중 오류가 발생했습니다: {str(e)}'
|
||||
}
|
||||
)
|
||||
|
||||
def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
|
||||
session = AnalysisSession(
|
||||
**analysis_in.dict(),
|
||||
user_id=self.user.id,
|
||||
analysis_date=datetime.date.today()
|
||||
)
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
return session
|
||||
|
||||
def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
|
||||
statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id)
|
||||
return self.db.exec(statement).first()
|
||||
|
||||
def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
|
||||
statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit)
|
||||
return self.db.exec(statement).all()
|
||||
|
|
|
|||
|
|
@ -1,23 +1,23 @@
|
|||
from typing import Dict, List
|
||||
from fastapi import WebSocket
|
||||
|
||||
class WebSocketManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[int, List[WebSocket]] = {}
|
||||
|
||||
async def connect(self, user_id: int, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
if user_id not in self.active_connections:
|
||||
self.active_connections[user_id] = []
|
||||
self.active_connections[user_id].append(websocket)
|
||||
|
||||
def disconnect(self, user_id: int, websocket: WebSocket):
|
||||
if user_id in self.active_connections:
|
||||
self.active_connections[user_id].remove(websocket)
|
||||
if not self.active_connections[user_id]:
|
||||
del self.active_connections[user_id]
|
||||
|
||||
async def send_to_user(self, user_id: int, message: dict):
|
||||
if user_id in self.active_connections:
|
||||
for connection in self.active_connections[user_id]:
|
||||
await connection.send_json(message)
|
||||
from typing import Dict, List
|
||||
from fastapi import WebSocket
|
||||
|
||||
class WebSocketManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[int, List[WebSocket]] = {}
|
||||
|
||||
async def connect(self, user_id: int, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
if user_id not in self.active_connections:
|
||||
self.active_connections[user_id] = []
|
||||
self.active_connections[user_id].append(websocket)
|
||||
|
||||
def disconnect(self, user_id: int, websocket: WebSocket):
|
||||
if user_id in self.active_connections:
|
||||
self.active_connections[user_id].remove(websocket)
|
||||
if not self.active_connections[user_id]:
|
||||
del self.active_connections[user_id]
|
||||
|
||||
async def send_to_user(self, user_id: int, message: dict):
|
||||
if user_id in self.active_connections:
|
||||
for connection in self.active_connections[user_id]:
|
||||
await connection.send_json(message)
|
||||
|
|
|
|||
|
|
@ -1,56 +1,56 @@
|
|||
from datetime import date, datetime
|
||||
from typing import List, Optional
|
||||
from sqlmodel import Field, SQLModel, JSON, Column
|
||||
import enum
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
email: str = Field(unique=True, index=True)
|
||||
username: str = Field(unique=True, index=True)
|
||||
hashed_password: str
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
is_active: bool = Field(default=True)
|
||||
is_superuser: bool = Field(default=False)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
|
||||
|
||||
|
||||
class UserProfile(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id", unique=True)
|
||||
encrypted_openai_api_key: Optional[str] = None
|
||||
default_ticker: str = Field(default="SPY")
|
||||
preferred_research_depth: int = Field(default=3)
|
||||
preferred_shallow_thinker: str = Field(default="gpt-4o-mini")
|
||||
preferred_deep_thinker: str = Field(default="gpt-4o")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
|
||||
|
||||
|
||||
class AnalysisStatus(str, enum.Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class AnalysisSession(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
ticker: str
|
||||
analysis_date: date
|
||||
analysts_selected: List[str] = Field(sa_column=Column(JSON))
|
||||
research_depth: int
|
||||
llm_provider: str
|
||||
backend_url: str
|
||||
shallow_thinker: str
|
||||
deep_thinker: str
|
||||
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
|
||||
final_report: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
started_at: Optional[datetime] = None
|
||||
from datetime import date, datetime
|
||||
from typing import List, Optional
|
||||
from sqlmodel import Field, SQLModel, JSON, Column
|
||||
import enum
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
email: str = Field(unique=True, index=True)
|
||||
username: str = Field(unique=True, index=True)
|
||||
hashed_password: str
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
is_active: bool = Field(default=True)
|
||||
is_superuser: bool = Field(default=False)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
|
||||
|
||||
|
||||
class UserProfile(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id", unique=True)
|
||||
encrypted_openai_api_key: Optional[str] = None
|
||||
default_ticker: str = Field(default="SPY")
|
||||
preferred_research_depth: int = Field(default=3)
|
||||
preferred_shallow_thinker: str = Field(default="gpt-4o-mini")
|
||||
preferred_deep_thinker: str = Field(default="gpt-4o")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
|
||||
|
||||
|
||||
class AnalysisStatus(str, enum.Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class AnalysisSession(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
ticker: str
|
||||
analysis_date: date
|
||||
analysts_selected: List[str] = Field(sa_column=Column(JSON))
|
||||
research_depth: int
|
||||
llm_provider: str
|
||||
backend_url: str
|
||||
shallow_thinker: str
|
||||
deep_thinker: str
|
||||
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
|
||||
final_report: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
|
@ -1,48 +1,48 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar, Optional, List
|
||||
from sqlmodel import SQLModel
|
||||
from app.core.schemas.user import UserCreate, UserUpdate
|
||||
from app.domain.models import User
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=SQLModel)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel)
|
||||
|
||||
class IRepository(Generic[ModelType], ABC):
|
||||
@abstractmethod
|
||||
def get(self, id: int) -> Optional[ModelType]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove(self, *, id: int) -> ModelType:
|
||||
pass
|
||||
|
||||
|
||||
class IUserRepository(IRepository[User], ABC):
|
||||
@abstractmethod
|
||||
def get_by_email(self, *, email: str) -> Optional[User]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, *, obj_in: UserCreate) -> User:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_superuser(self, *, user: User) -> bool:
|
||||
pass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar, Optional, List
|
||||
from sqlmodel import SQLModel
|
||||
from app.core.schemas.user import UserCreate, UserUpdate
|
||||
from app.domain.models import User
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=SQLModel)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel)
|
||||
|
||||
class IRepository(Generic[ModelType], ABC):
|
||||
@abstractmethod
|
||||
def get(self, id: int) -> Optional[ModelType]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove(self, *, id: int) -> ModelType:
|
||||
pass
|
||||
|
||||
|
||||
class IUserRepository(IRepository[User], ABC):
|
||||
@abstractmethod
|
||||
def get_by_email(self, *, email: str) -> Optional[User]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, *, obj_in: UserCreate) -> User:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_superuser(self, *, user: User) -> bool:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from sqlmodel import create_engine, Session
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
engine = create_engine(settings.DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
|
||||
|
||||
def get_db():
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
from sqlmodel import create_engine, Session
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
engine = create_engine(settings.DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
|
||||
|
||||
def get_db():
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
|
|
|||
|
|
@ -1,53 +1,53 @@
|
|||
from typing import Optional
|
||||
from sqlmodel import Session, select
|
||||
from app.domain.models import User
|
||||
from app.core.schemas.user import UserCreate, UserUpdate
|
||||
from app.domain.repositories import IUserRepository
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
class UserRepository(IUserRepository):
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get(self, id: int) -> Optional[User]:
|
||||
return self.db.get(User, id)
|
||||
|
||||
def get_by_email(self, *, email: str) -> Optional[User]:
|
||||
statement = select(User).where(User.email == email)
|
||||
return self.db.exec(statement).first()
|
||||
|
||||
def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]:
|
||||
statement = select(User).offset(skip).limit(limit)
|
||||
return self.db.exec(statement).all()
|
||||
|
||||
def create(self, *, obj_in: UserCreate) -> User:
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
username=obj_in.username,
|
||||
hashed_password=get_password_hash(obj_in.password),
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
)
|
||||
self.db.add(db_obj)
|
||||
self.db.commit()
|
||||
self.db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
|
||||
update_data = obj_in.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_obj, field, value)
|
||||
|
||||
self.db.add(db_obj)
|
||||
self.db.commit()
|
||||
self.db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def remove(self, *, id: int) -> User:
|
||||
db_obj = self.db.get(User, id)
|
||||
self.db.delete(db_obj)
|
||||
self.db.commit()
|
||||
return db_obj
|
||||
|
||||
def is_superuser(self, *, user: User) -> bool:
|
||||
return user.is_superuser
|
||||
from typing import Optional
|
||||
from sqlmodel import Session, select
|
||||
from app.domain.models import User
|
||||
from app.core.schemas.user import UserCreate, UserUpdate
|
||||
from app.domain.repositories import IUserRepository
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
class UserRepository(IUserRepository):
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get(self, id: int) -> Optional[User]:
|
||||
return self.db.get(User, id)
|
||||
|
||||
def get_by_email(self, *, email: str) -> Optional[User]:
|
||||
statement = select(User).where(User.email == email)
|
||||
return self.db.exec(statement).first()
|
||||
|
||||
def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]:
|
||||
statement = select(User).offset(skip).limit(limit)
|
||||
return self.db.exec(statement).all()
|
||||
|
||||
def create(self, *, obj_in: UserCreate) -> User:
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
username=obj_in.username,
|
||||
hashed_password=get_password_hash(obj_in.password),
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
)
|
||||
self.db.add(db_obj)
|
||||
self.db.commit()
|
||||
self.db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
|
||||
update_data = obj_in.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_obj, field, value)
|
||||
|
||||
self.db.add(db_obj)
|
||||
self.db.commit()
|
||||
self.db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def remove(self, *, id: int) -> User:
|
||||
db_obj = self.db.get(User, id)
|
||||
self.db.delete(db_obj)
|
||||
self.db.commit()
|
||||
return db_obj
|
||||
|
||||
def is_superuser(self, *, user: User) -> bool:
|
||||
return user.is_superuser
|
||||
|
|
|
|||
70
app/main.py
70
app/main.py
|
|
@ -1,36 +1,36 @@
|
|||
import sys
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
# Add project root to path to allow importing tradingagents
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
|
||||
|
||||
from app.api.router import api_router
|
||||
from app.core.config import settings
|
||||
from app.infrastructure.database import engine
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
def create_tables():
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
||||
)
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
create_tables()
|
||||
|
||||
# Set all CORS enabled origins
|
||||
if settings.CORS_ALLOWED_ORIGINS:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
import sys
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
# Add project root to path to allow importing tradingagents
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
|
||||
|
||||
from app.api.router import api_router
|
||||
from app.core.config import settings
|
||||
from app.infrastructure.database import engine
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
def create_tables():
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
||||
)
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
create_tables()
|
||||
|
||||
# Set all CORS enabled origins
|
||||
if settings.CORS_ALLOWED_ORIGINS:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
|
@ -1,2 +1,2 @@
|
|||
.env
|
||||
.env
|
||||
wallet/
|
||||
|
|
@ -1,247 +1,299 @@
|
|||
from sqlmodel import Session
|
||||
from analysis.domain.repository.analysis_repo import IAnalysisRepository
|
||||
from ulid import ULID
|
||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||
from analysis.interface.dto import TradingAnalysisRequest, AnalysisProgressUpdate
|
||||
from fastapi import HTTPException, status, BackgroundTasks
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
|
||||
|
||||
class AnalysisService:
|
||||
def __init__(
|
||||
self,
|
||||
analysis_repo: IAnalysisRepository,
|
||||
session: Session,
|
||||
ulid: ULID
|
||||
):
|
||||
self.analysis_repo = analysis_repo
|
||||
self.session = session
|
||||
self.ulid = ulid
|
||||
|
||||
def get_analysis_list(
|
||||
self,
|
||||
member_id: str
|
||||
) -> list[AnalysisVO]:
|
||||
analyses = self.analysis_repo.find_by_member_id(member_id)
|
||||
if not analyses:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
||||
return analyses
|
||||
|
||||
def get_analysis_by_id(
|
||||
self,
|
||||
analysis_id: str,
|
||||
member_id: str
|
||||
) -> AnalysisVO:
|
||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
||||
if not analysis:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
||||
|
||||
if analysis.member_id != member_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
|
||||
return analysis
|
||||
|
||||
def create_analysis(
|
||||
self,
|
||||
member_id: str,
|
||||
request: TradingAnalysisRequest,
|
||||
background_tasks: BackgroundTasks
|
||||
) -> AnalysisVO:
|
||||
# 분석 요청 생성
|
||||
analysis_id = self.ulid.generate()
|
||||
now = datetime.now()
|
||||
|
||||
analysis_vo = AnalysisVO(
|
||||
id=analysis_id,
|
||||
member_id=member_id,
|
||||
ticker=request.ticker,
|
||||
analysis_date=request.analysis_date,
|
||||
analysts_selected=[analyst.value for analyst in request.analysts],
|
||||
research_depth=request.research_depth,
|
||||
llm_provider=request.llm_provider,
|
||||
backend_url=request.backend_url,
|
||||
shallow_thinker=request.shallow_thinker,
|
||||
deep_thinker=request.deep_thinker,
|
||||
status="pending",
|
||||
created_at=now,
|
||||
updated_at=now
|
||||
)
|
||||
|
||||
saved_analysis = self.analysis_repo.save(analysis_vo)
|
||||
self.session.commit()
|
||||
|
||||
# 백그라운드에서 분석 실행
|
||||
background_tasks.add_task(self._run_analysis, saved_analysis.id)
|
||||
|
||||
return saved_analysis
|
||||
|
||||
async def _run_analysis(self, analysis_id: str):
|
||||
"""백그라운드에서 실제 분석을 실행하는 메서드"""
|
||||
try:
|
||||
# 분석 상태를 RUNNING으로 변경
|
||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
||||
if analysis:
|
||||
analysis.status = "running"
|
||||
analysis.updated_at = datetime.now()
|
||||
self.analysis_repo.update(analysis)
|
||||
self.session.commit()
|
||||
|
||||
# 분석 정보 조회
|
||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
||||
if not analysis:
|
||||
return
|
||||
|
||||
# TradingAgentsGraph 설정 및 실행
|
||||
config = self._create_config(analysis)
|
||||
|
||||
# 분석 실행 (실제 구현)
|
||||
await self._execute_trading_analysis(analysis_id, analysis, config)
|
||||
|
||||
# 분석 완료 상태로 변경
|
||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
||||
if analysis:
|
||||
analysis.status = "completed"
|
||||
analysis.completed_at = datetime.now()
|
||||
analysis.updated_at = datetime.now()
|
||||
self.analysis_repo.update(analysis)
|
||||
self.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
# 에러 발생 시 실패 상태로 변경
|
||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
||||
if analysis:
|
||||
analysis.status = "failed"
|
||||
analysis.error_message = str(e)
|
||||
analysis.completed_at = datetime.now()
|
||||
analysis.updated_at = datetime.now()
|
||||
self.analysis_repo.update(analysis)
|
||||
self.session.commit()
|
||||
|
||||
def _create_config(self, analysis: AnalysisVO) -> dict:
|
||||
"""분석 설정을 생성하는 메서드"""
|
||||
config = DEFAULT_CONFIG.copy() if DEFAULT_CONFIG else {}
|
||||
config.update({
|
||||
"max_debate_rounds": analysis.research_depth,
|
||||
"max_risk_discuss_rounds": analysis.research_depth,
|
||||
"quick_think_llm": analysis.shallow_thinker,
|
||||
"deep_think_llm": analysis.deep_thinker,
|
||||
"backend_url": analysis.backend_url,
|
||||
"llm_provider": analysis.llm_provider.lower(),
|
||||
})
|
||||
return config
|
||||
|
||||
async def _execute_trading_analysis(self, analysis_id: str, analysis: AnalysisVO, config: dict):
|
||||
"""실제 TradingAgentsGraph를 실행하는 메서드"""
|
||||
try:
|
||||
# TradingAgentsGraph 초기화
|
||||
graph = TradingAgentsGraph(
|
||||
analysis.analysts_selected,
|
||||
config=config,
|
||||
debug=True
|
||||
)
|
||||
|
||||
# 초기 상태 생성
|
||||
init_agent_state = graph.propagator.create_initial_state(
|
||||
analysis.ticker,
|
||||
analysis.analysis_date
|
||||
)
|
||||
args = graph.propagator.get_graph_args()
|
||||
|
||||
# 분석 실행 및 결과 처리
|
||||
trace = []
|
||||
async for chunk in graph.graph.astream(init_agent_state, **args):
|
||||
trace.append(chunk)
|
||||
|
||||
# 실시간으로 분석 결과 업데이트
|
||||
await self._process_analysis_chunk(analysis_id, chunk)
|
||||
|
||||
# 최종 결과 처리
|
||||
if trace:
|
||||
final_state = trace[-1]
|
||||
final_decision = graph.process_signal(final_state.get("final_trade_decision", ""))
|
||||
|
||||
# 최종 보고서 생성
|
||||
final_report = self._generate_final_report(final_state)
|
||||
|
||||
# 최종 결과 저장
|
||||
self.analysis_repo.update(analysis_id, {
|
||||
"final_trade_decision": final_decision,
|
||||
"final_report": final_report
|
||||
})
|
||||
self.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Analysis execution failed: {str(e)}")
|
||||
|
||||
async def _process_analysis_chunk(self, analysis_id: str, chunk: dict):
|
||||
"""분석 중간 결과를 처리하고 저장하는 메서드"""
|
||||
updates = {}
|
||||
|
||||
# 개별 분석가 보고서 업데이트
|
||||
if "market_report" in chunk and chunk["market_report"]:
|
||||
updates["market_report"] = chunk["market_report"]
|
||||
|
||||
if "sentiment_report" in chunk and chunk["sentiment_report"]:
|
||||
updates["sentiment_report"] = chunk["sentiment_report"]
|
||||
|
||||
if "news_report" in chunk and chunk["news_report"]:
|
||||
updates["news_report"] = chunk["news_report"]
|
||||
|
||||
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
|
||||
updates["fundamentals_report"] = chunk["fundamentals_report"]
|
||||
|
||||
# 팀별 의사결정 과정 업데이트
|
||||
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
|
||||
updates["investment_debate_state"] = chunk["investment_debate_state"]
|
||||
|
||||
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
|
||||
updates["trader_investment_plan"] = chunk["trader_investment_plan"]
|
||||
|
||||
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
|
||||
updates["risk_debate_state"] = chunk["risk_debate_state"]
|
||||
|
||||
# 업데이트가 있는 경우 저장
|
||||
if updates:
|
||||
self.analysis_repo.update(analysis_id, updates)
|
||||
self.session.commit()
|
||||
|
||||
def _generate_final_report(self, final_state: dict) -> str:
|
||||
"""최종 통합 보고서를 생성하는 메서드"""
|
||||
report_parts = []
|
||||
|
||||
# Analyst Team Reports
|
||||
if any(final_state.get(section) for section in ["market_report", "sentiment_report", "news_report", "fundamentals_report"]):
|
||||
report_parts.append("## Analyst Team Reports")
|
||||
|
||||
if final_state.get("market_report"):
|
||||
report_parts.append(f"### Market Analysis\n{final_state['market_report']}")
|
||||
if final_state.get("sentiment_report"):
|
||||
report_parts.append(f"### Social Sentiment\n{final_state['sentiment_report']}")
|
||||
if final_state.get("news_report"):
|
||||
report_parts.append(f"### News Analysis\n{final_state['news_report']}")
|
||||
if final_state.get("fundamentals_report"):
|
||||
report_parts.append(f"### Fundamentals Analysis\n{final_state['fundamentals_report']}")
|
||||
|
||||
# Research Team Reports
|
||||
if final_state.get("investment_debate_state"):
|
||||
report_parts.append("## Research Team Decision")
|
||||
debate_state = final_state["investment_debate_state"]
|
||||
if debate_state.get("judge_decision"):
|
||||
report_parts.append(f"{debate_state['judge_decision']}")
|
||||
|
||||
# Trading Team Reports
|
||||
if final_state.get("trader_investment_plan"):
|
||||
report_parts.append("## Trading Team Plan")
|
||||
report_parts.append(f"{final_state['trader_investment_plan']}")
|
||||
|
||||
# Portfolio Management Decision
|
||||
if final_state.get("risk_debate_state") and final_state["risk_debate_state"].get("judge_decision"):
|
||||
report_parts.append("## Portfolio Management Decision")
|
||||
report_parts.append(f"{final_state['risk_debate_state']['judge_decision']}")
|
||||
|
||||
return "\n\n".join(report_parts) if report_parts else "No analysis results available."
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))
|
||||
|
||||
import logging
|
||||
from sqlmodel import Session
|
||||
from analysis.domain.repository.analysis_repo import IAnalysisRepository
|
||||
from ulid import ULID
|
||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||
from analysis.interface.dto import TradingAnalysisRequest, AnalysisProgressUpdate
|
||||
from fastapi import HTTPException, status, BackgroundTasks
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from analysis.application.websocket_manager import WebSocketManager
|
||||
from analysis.infra.db_models.analysis import AnalysisStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AnalysisService:
|
||||
def __init__(
|
||||
self,
|
||||
analysis_repo: IAnalysisRepository,
|
||||
session: Session,
|
||||
ulid: ULID,
|
||||
websocket_manager: WebSocketManager
|
||||
):
|
||||
self.analysis_repo = analysis_repo
|
||||
self.session = session
|
||||
self.ulid = ulid
|
||||
self.websocket_manager = websocket_manager
|
||||
|
||||
def get_analysis_list(
|
||||
self,
|
||||
member_id: str
|
||||
) -> list[AnalysisVO]:
|
||||
analyses = self.analysis_repo.find_by_member_id(member_id)
|
||||
if not analyses:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
||||
return analyses
|
||||
|
||||
def get_analysis_by_id(
|
||||
self,
|
||||
analysis_id: str,
|
||||
member_id: str
|
||||
) -> AnalysisVO:
|
||||
analysis = self.analysis_repo.find_by_id(analysis_id)
|
||||
if not analysis:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
||||
|
||||
if analysis.member_id != member_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
|
||||
return analysis
|
||||
|
||||
def get_analysis_sessions_by_member(
|
||||
self,
|
||||
member_id: str
|
||||
) -> list[AnalysisVO]:
|
||||
analyses = self.analysis_repo.find_by_member_id(member_id)
|
||||
if not analyses:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
||||
return analyses
|
||||
|
||||
def create_analysis(
|
||||
self,
|
||||
member_id: str,
|
||||
request: TradingAnalysisRequest,
|
||||
background_tasks: BackgroundTasks
|
||||
) -> AnalysisVO:
|
||||
# 분석 요청 생성
|
||||
analysis_id = self.ulid.generate()
|
||||
now = datetime.now()
|
||||
|
||||
analysis_vo = AnalysisVO(
|
||||
id=analysis_id,
|
||||
member_id=member_id,
|
||||
ticker=request.ticker,
|
||||
analysis_date=request.analysis_date,
|
||||
analysts_selected=[analyst.value for analyst in request.analysts],
|
||||
research_depth=request.research_depth,
|
||||
llm_provider=request.llm_provider,
|
||||
backend_url=request.backend_url,
|
||||
shallow_thinker=request.shallow_thinker,
|
||||
deep_thinker=request.deep_thinker,
|
||||
status=AnalysisStatus.PENDING,
|
||||
created_at=now,
|
||||
updated_at=now
|
||||
)
|
||||
|
||||
saved_analysis = self.analysis_repo.save(analysis_vo)
|
||||
if not saved_analysis:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to save analysis")
|
||||
|
||||
self.session.commit()
|
||||
|
||||
# Register analysis with websocket manager
|
||||
self.websocket_manager.register_analysis(saved_analysis.id, member_id)
|
||||
|
||||
# 백그라운드에서 분석 실행
|
||||
background_tasks.add_task(self._run_analysis, saved_analysis.id)
|
||||
|
||||
return saved_analysis
|
||||
|
||||
async def _run_analysis(self, analysis_id: str):
|
||||
"""백그라운드에서 실제 분석을 실행하는 메서드"""
|
||||
try:
|
||||
analysis = AnalysisVO(
|
||||
id=analysis_id,
|
||||
status=AnalysisStatus.RUNNING,
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
|
||||
analysis = self.analysis_repo.update(analysis)
|
||||
if not analysis:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Analysis not found")
|
||||
|
||||
await self.websocket_manager.send_analysis_update(
|
||||
analysis_id=analysis_id,
|
||||
update_type="status_changed",
|
||||
data={"status": "running", "message": "Analysis started"}
|
||||
)
|
||||
|
||||
|
||||
|
||||
# TradingAgentsGraph 설정 및 실행
|
||||
if analysis:
|
||||
config = self._create_config(analysis)
|
||||
|
||||
# 분석 실행 (실제 구현)
|
||||
await self._execute_trading_analysis(analysis_id, analysis, config)
|
||||
|
||||
# 완료 상태로 업데이트
|
||||
completed_analysis = AnalysisVO(
|
||||
id=analysis_id,
|
||||
status=AnalysisStatus.COMPLETED,
|
||||
completed_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
self.analysis_repo.update(completed_analysis)
|
||||
self.session.commit()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
now = datetime.now()
|
||||
updates = AnalysisVO(
|
||||
status=AnalysisStatus.FAILED,
|
||||
error_message=str(e),
|
||||
completed_at = now,
|
||||
updated_at = now
|
||||
)
|
||||
|
||||
self.analysis_repo.update(updates)
|
||||
self.session.commit()
|
||||
|
||||
|
||||
def _create_config(self, analysis: AnalysisVO) -> dict:
|
||||
"""분석 설정을 생성하는 메서드"""
|
||||
config = {}
|
||||
config.update({
|
||||
"max_debate_rounds": analysis.research_depth,
|
||||
"max_risk_discuss_rounds": analysis.research_depth,
|
||||
"quick_think_llm": analysis.shallow_thinker,
|
||||
"deep_think_llm": analysis.deep_thinker,
|
||||
"backend_url": analysis.backend_url,
|
||||
"llm_provider": analysis.llm_provider.lower(),
|
||||
})
|
||||
return config
|
||||
|
||||
async def _execute_trading_analysis(self, analysis_id: str, analysis: AnalysisVO, config: dict):
|
||||
"""실제 TradingAgentsGraph를 실행하는 메서드"""
|
||||
try:
|
||||
logger.info(f"Starting trading analysis for {analysis_id} with ticker {analysis.ticker}")
|
||||
logger.info(f"Analysts selected: {analysis.analysts_selected}")
|
||||
logger.info(f"Config: {config}")
|
||||
|
||||
# TradingAgentsGraph 초기화
|
||||
graph = TradingAgentsGraph(
|
||||
analysis.analysts_selected,
|
||||
config=config,
|
||||
debug=True
|
||||
)
|
||||
logger.info("TradingAgentsGraph initialized successfully")
|
||||
|
||||
# 초기 상태 생성
|
||||
init_agent_state = graph.propagator.create_initial_state(
|
||||
analysis.ticker,
|
||||
analysis.analysis_date
|
||||
)
|
||||
args = graph.propagator.get_graph_args()
|
||||
|
||||
# 분석 실행 및 결과 처리
|
||||
logger.info("Starting graph execution...")
|
||||
trace = []
|
||||
chunk_count = 0
|
||||
async for chunk in graph.graph.astream(init_agent_state, **args):
|
||||
chunk_count += 1
|
||||
logger.info(f"Processing chunk {chunk_count}: {list(chunk.keys()) if chunk else 'Empty chunk'}")
|
||||
trace.append(chunk)
|
||||
|
||||
# 실시간으로 분석 결과 업데이트
|
||||
await self._process_analysis_chunk(analysis_id, chunk)
|
||||
|
||||
# 최종 결과 처리
|
||||
if trace:
|
||||
final_state = trace[-1]
|
||||
final_decision = graph.process_signal(final_state.get("final_trade_decision", ""))
|
||||
|
||||
# 최종 보고서 생성
|
||||
final_report = self._generate_final_report(final_state)
|
||||
analysis.final_trade_decision = final_decision
|
||||
analysis.final_report = final_report
|
||||
|
||||
# 최종 결과 저장
|
||||
updates = AnalysisVO(
|
||||
id=analysis_id,
|
||||
final_trade_decision=final_decision,
|
||||
final_report=final_report
|
||||
)
|
||||
self.analysis_repo.update(updates)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Analysis execution failed: {str(e)}")
|
||||
|
||||
async def _process_analysis_chunk(self, analysis_id: str, chunk: dict):
|
||||
"""분석 중간 결과를 처리하고 저장하는 메서드"""
|
||||
updates = {}
|
||||
|
||||
# 개별 분석가 보고서 업데이트
|
||||
if "market_report" in chunk and chunk["market_report"]:
|
||||
updates["market_report"] = chunk["market_report"]
|
||||
|
||||
if "sentiment_report" in chunk and chunk["sentiment_report"]:
|
||||
updates["sentiment_report"] = chunk["sentiment_report"]
|
||||
|
||||
if "news_report" in chunk and chunk["news_report"]:
|
||||
updates["news_report"] = chunk["news_report"]
|
||||
|
||||
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
|
||||
updates["fundamentals_report"] = chunk["fundamentals_report"]
|
||||
|
||||
# 팀별 의사결정 과정 업데이트
|
||||
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
|
||||
updates["investment_debate_state"] = chunk["investment_debate_state"]
|
||||
|
||||
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
|
||||
updates["trader_investment_plan"] = chunk["trader_investment_plan"]
|
||||
|
||||
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
|
||||
updates["risk_debate_state"] = chunk["risk_debate_state"]
|
||||
|
||||
# 업데이트가 있는 경우 저장
|
||||
if updates:
|
||||
# analysis_id를 포함한 AnalysisVO 객체 생성
|
||||
updates["id"] = analysis_id
|
||||
updates_vo = AnalysisVO(**updates)
|
||||
self.analysis_repo.update(updates_vo)
|
||||
self.session.commit()
|
||||
|
||||
def _generate_final_report(self, final_state: dict) -> str:
|
||||
"""최종 통합 보고서를 생성하는 메서드"""
|
||||
report_parts = []
|
||||
|
||||
# Analyst Team Reports
|
||||
if any(final_state.get(section) for section in ["market_report", "sentiment_report", "news_report", "fundamentals_report"]):
|
||||
report_parts.append("## Analyst Team Reports")
|
||||
|
||||
if final_state.get("market_report"):
|
||||
report_parts.append(f"### Market Analysis\n{final_state['market_report']}")
|
||||
if final_state.get("sentiment_report"):
|
||||
report_parts.append(f"### Social Sentiment\n{final_state['sentiment_report']}")
|
||||
if final_state.get("news_report"):
|
||||
report_parts.append(f"### News Analysis\n{final_state['news_report']}")
|
||||
if final_state.get("fundamentals_report"):
|
||||
report_parts.append(f"### Fundamentals Analysis\n{final_state['fundamentals_report']}")
|
||||
|
||||
# Research Team Reports
|
||||
if final_state.get("investment_debate_state"):
|
||||
report_parts.append("## Research Team Decision")
|
||||
debate_state = final_state["investment_debate_state"]
|
||||
if debate_state.get("judge_decision"):
|
||||
report_parts.append(f"{debate_state['judge_decision']}")
|
||||
|
||||
# Trading Team Reports
|
||||
if final_state.get("trader_investment_plan"):
|
||||
report_parts.append("## Trading Team Plan")
|
||||
report_parts.append(f"{final_state['trader_investment_plan']}")
|
||||
|
||||
# Portfolio Management Decision
|
||||
if final_state.get("risk_debate_state") and final_state["risk_debate_state"].get("judge_decision"):
|
||||
report_parts.append("## Portfolio Management Decision")
|
||||
report_parts.append(f"{final_state['risk_debate_state']['judge_decision']}")
|
||||
|
||||
return "\n\n".join(report_parts) if report_parts else "No analysis results available."
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
from typing import Dict, Set
|
||||
from fastapi import WebSocket
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
def __init__(self):
|
||||
# Store active connections by member_id
|
||||
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
||||
# Store analysis_id to member_id mapping
|
||||
self.analysis_member_map: Dict[str, str] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, member_id: str):
|
||||
await websocket.accept()
|
||||
if member_id not in self.active_connections:
|
||||
self.active_connections[member_id] = set()
|
||||
self.active_connections[member_id].add(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket, member_id: str):
|
||||
if member_id in self.active_connections:
|
||||
self.active_connections[member_id].discard(websocket)
|
||||
if not self.active_connections[member_id]:
|
||||
del self.active_connections[member_id]
|
||||
|
||||
def register_analysis(self, analysis_id: str, member_id: str):
|
||||
"""Register which member owns which analysis"""
|
||||
self.analysis_member_map[analysis_id] = member_id
|
||||
|
||||
async def send_analysis_update(self, analysis_id: str, update_type: str, data: dict):
|
||||
"""Send analysis update to the member who owns the analysis"""
|
||||
member_id = self.analysis_member_map.get(analysis_id)
|
||||
if not member_id:
|
||||
return
|
||||
|
||||
message = {
|
||||
"type": "analysis_update",
|
||||
"analysis_id": analysis_id,
|
||||
"update_type": update_type,
|
||||
"data": data,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
await self.send_to_member(member_id, message)
|
||||
|
||||
|
||||
|
||||
async def send_to_member(self, member_id: str, message: dict|str):
|
||||
"""Send message to all connections of a specific member"""
|
||||
if member_id not in self.active_connections:
|
||||
return
|
||||
|
||||
dead_connections = set()
|
||||
for connection in self.active_connections[member_id]:
|
||||
try:
|
||||
if isinstance(message, dict):
|
||||
await connection.send_json(message)
|
||||
else:
|
||||
await connection.send_text(message)
|
||||
except Exception:
|
||||
dead_connections.add(connection)
|
||||
|
||||
# Clean up dead connections
|
||||
for connection in dead_connections:
|
||||
self.disconnect(connection, member_id)
|
||||
|
|
@ -1,19 +1,20 @@
|
|||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import List, Dict
|
||||
from pydantic import BaseModel, field_validator
|
||||
from datetime import datetime, date
|
||||
from typing import List, Dict, Union
|
||||
from analysis.infra.db_models.analysis import AnalysisStatus
|
||||
|
||||
class Analysis(BaseModel):
|
||||
id: str | None = None
|
||||
member_id: str
|
||||
ticker: str
|
||||
analysis_date: str
|
||||
analysts_selected: List[str] = []
|
||||
id: str
|
||||
member_id: str | None = None
|
||||
ticker: str | None = None
|
||||
analysis_date: date | None = None
|
||||
analysts_selected: list[str] = []
|
||||
research_depth: int = 3
|
||||
llm_provider: str = "openai"
|
||||
backend_url: str = "https://api.openai.com/v1"
|
||||
shallow_thinker: str = "gpt-4o-mini"
|
||||
deep_thinker: str = "gpt-4o"
|
||||
status: str
|
||||
shallow_thinker: str = "gpt-4o"
|
||||
deep_thinker: str = "o3"
|
||||
status: AnalysisStatus = AnalysisStatus.PENDING
|
||||
|
||||
# 개별 분석가 리포트들
|
||||
market_report: str | None = None
|
||||
|
|
@ -33,5 +34,5 @@ class Analysis(BaseModel):
|
|||
# 실행 결과 정보
|
||||
error_message: str | None = None
|
||||
completed_at: datetime | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
|
|
@ -1,20 +1,20 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||
from analysis.interface.dto import TradingAnalysisRequest
|
||||
|
||||
class IAnalysisRepository(ABC):
|
||||
@abstractmethod
|
||||
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def update(self, analysis: AnalysisVO) -> AnalysisVO | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def save(self, analysis: AnalysisVO) -> AnalysisVO:
|
||||
raise NotImplementedError()
|
||||
from abc import ABC, abstractmethod
|
||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||
from analysis.interface.dto import TradingAnalysisRequest
|
||||
|
||||
class IAnalysisRepository(ABC):
|
||||
@abstractmethod
|
||||
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def update(self, analysis: AnalysisVO) -> AnalysisVO | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def save(self, analysis: AnalysisVO) -> AnalysisVO:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -1,57 +1,57 @@
|
|||
from datetime import datetime,date
|
||||
from typing import TYPE_CHECKING
|
||||
from sqlmodel import SQLModel, Field, JSON, Relationship
|
||||
import enum
|
||||
from sqlalchemy import Column
|
||||
|
||||
# TYPE_CHECKING을 사용해서 circular import 방지
|
||||
if TYPE_CHECKING:
|
||||
from member.infra.db_models.member import Member
|
||||
|
||||
class AnalysisStatus(str, enum.Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class Analysis(SQLModel, table=True):
|
||||
__tablename__ = "analyses"
|
||||
id: str = Field(default=None, max_length=36, primary_key=True)
|
||||
|
||||
# 기본 분석 설정 정보
|
||||
ticker: str
|
||||
analysis_date: date
|
||||
analysts_selected: list[str] = Field(sa_column=Column(JSON))
|
||||
research_depth: int
|
||||
llm_provider: str
|
||||
backend_url: str
|
||||
shallow_thinker: str
|
||||
deep_thinker: str
|
||||
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
|
||||
|
||||
# 개별 분석가 리포트들
|
||||
market_report: str | None = Field(default=None, description="Market Analyst 리포트")
|
||||
sentiment_report: str | None = Field(default=None, description="Social Analyst 리포트")
|
||||
news_report: str | None = Field(default=None, description="News Analyst 리포트")
|
||||
fundamentals_report: str | None = Field(default=None, description="Fundamentals Analyst 리포트")
|
||||
|
||||
# 팀별 의사결정 과정
|
||||
investment_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Research Team 토론 과정")
|
||||
trader_investment_plan: str | None = Field(default=None, description="Trading Team 계획")
|
||||
risk_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Risk Management Team 토론 과정")
|
||||
|
||||
# 최종 결과물
|
||||
final_trade_decision: str | None = Field(default=None, description="최종 거래 결정")
|
||||
final_report: str | None = Field(default=None, description="전체 통합 리포트")
|
||||
|
||||
# 실행 결과 정보
|
||||
error_message: str | None = None
|
||||
completed_at: datetime | None = None
|
||||
created_at : datetime = Field(nullable=False)
|
||||
updated_at : datetime = Field(nullable=False)
|
||||
|
||||
# Foreign Key와 Relationship 설정
|
||||
member_id: str = Field(foreign_key="members.id")
|
||||
from datetime import datetime,date
|
||||
from typing import TYPE_CHECKING
|
||||
from sqlmodel import SQLModel, Field, JSON, Relationship
|
||||
import enum
|
||||
from sqlalchemy import Column, Text
|
||||
|
||||
# TYPE_CHECKING을 사용해서 circular import 방지
|
||||
if TYPE_CHECKING:
|
||||
from member.infra.db_models.member import Member
|
||||
|
||||
class AnalysisStatus(str, enum.Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class Analysis(SQLModel, table=True):
|
||||
__tablename__ = "analyses"
|
||||
id: str = Field(default=None, max_length=36, primary_key=True)
|
||||
|
||||
# 기본 분석 설정 정보
|
||||
ticker: str
|
||||
analysis_date: date
|
||||
analysts_selected: list[str] = Field(sa_column=Column(JSON))
|
||||
research_depth: int
|
||||
llm_provider: str
|
||||
backend_url: str
|
||||
shallow_thinker: str
|
||||
deep_thinker: str
|
||||
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
|
||||
|
||||
# 개별 분석가 리포트들
|
||||
market_report: str | None = Field(default=None, sa_column=Column(Text), description="Market Analyst 리포트")
|
||||
sentiment_report: str | None = Field(default=None, sa_column=Column(Text), description="Social Analyst 리포트")
|
||||
news_report: str | None = Field(default=None, sa_column=Column(Text), description="News Analyst 리포트")
|
||||
fundamentals_report: str | None = Field(default=None, sa_column=Column(Text), description="Fundamentals Analyst 리포트")
|
||||
|
||||
# 팀별 의사결정 과정
|
||||
investment_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Research Team 토론 과정")
|
||||
trader_investment_plan: str | None = Field(default=None, sa_column=Column(Text), description="Trading Team 계획")
|
||||
risk_debate_state: dict | None = Field(default=None, sa_column=Column(JSON), description="Risk Management Team 토론 과정")
|
||||
|
||||
# 최종 결과물
|
||||
final_trade_decision: str | None = Field(default=None, sa_column=Column(Text), description="최종 거래 결정")
|
||||
final_report: str | None = Field(default=None, sa_column=Column(Text), description="전체 통합 리포트")
|
||||
|
||||
# 실행 결과 정보
|
||||
error_message: str | None = Field(default=None, sa_column=Column(Text))
|
||||
completed_at: datetime | None = None
|
||||
created_at : datetime = Field(nullable=False)
|
||||
updated_at : datetime = Field(nullable=False)
|
||||
|
||||
# Foreign Key와 Relationship 설정
|
||||
member_id: str = Field(foreign_key="members.id")
|
||||
member: "Member" = Relationship(back_populates="analyses")
|
||||
|
|
@ -1,80 +1,55 @@
|
|||
from analysis.domain.repository.analysis_repo import IAnalysisRepository
|
||||
from sqlmodel import Session, select
|
||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||
from analysis.infra.db_models.analysis import Analysis, AnalysisStatus
|
||||
from analysis.interface.dto import TradingAnalysisRequest
|
||||
from utils.db_utils import row_to_dict
|
||||
from sqlalchemy.orm import selectinload
|
||||
from datetime import datetime, date
|
||||
|
||||
class AnalysisRepository(IAnalysisRepository):
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
|
||||
query = select(Analysis).where(Analysis.member_id == member_id)
|
||||
analyses = self.session.exec(query).all()
|
||||
|
||||
if not analyses:
|
||||
return None
|
||||
|
||||
return [AnalysisVO(**row_to_dict(analysis)) for analysis in analyses]
|
||||
|
||||
def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
|
||||
analysis = self.session.get(Analysis, analysis_id)
|
||||
if not analysis:
|
||||
return None
|
||||
return AnalysisVO(**row_to_dict(analysis))
|
||||
|
||||
def save(self, analysis: AnalysisVO) -> AnalysisVO:
|
||||
new_analysis = Analysis(
|
||||
id=analysis.id,
|
||||
member_id=analysis.member_id,
|
||||
ticker=analysis.ticker,
|
||||
analysis_date=date.fromisoformat(analysis.analysis_date),
|
||||
analysts_selected=analysis.analysts_selected,
|
||||
research_depth=analysis.research_depth,
|
||||
llm_provider=analysis.llm_provider,
|
||||
backend_url=analysis.backend_url,
|
||||
shallow_thinker=analysis.shallow_thinker,
|
||||
deep_thinker=analysis.deep_thinker,
|
||||
status=analysis.status,
|
||||
market_report=analysis.market_report,
|
||||
sentiment_report=analysis.sentiment_report,
|
||||
news_report=analysis.news_report,
|
||||
fundamentals_report=analysis.fundamentals_report,
|
||||
investment_debate_state=analysis.investment_debate_state,
|
||||
trader_investment_plan=analysis.trader_investment_plan,
|
||||
risk_debate_state=analysis.risk_debate_state,
|
||||
final_trade_decision=analysis.final_trade_decision,
|
||||
final_report=analysis.final_report,
|
||||
error_message=analysis.error_message,
|
||||
completed_at=analysis.completed_at,
|
||||
created_at=analysis.created_at,
|
||||
updated_at=analysis.updated_at
|
||||
)
|
||||
|
||||
self.session.add(new_analysis)
|
||||
self.session.flush()
|
||||
self.session.refresh(new_analysis)
|
||||
|
||||
analysis.id = new_analysis.id
|
||||
return analysis
|
||||
|
||||
def update(self, analysis_vo: AnalysisVO) -> AnalysisVO | None:
|
||||
analysis = self.session.get(Analysis, analysis_vo.id)
|
||||
if not analysis:
|
||||
return None
|
||||
|
||||
# AnalysisVO의 데이터를 SQLModel 객체에 업데이트
|
||||
vo_data = analysis_vo.sqlmodel_dump(exclude_unset=True)
|
||||
for key, value in vo_data.items():
|
||||
if hasattr(analysis, key) and key != 'id': # id는 변경하지 않음
|
||||
setattr(analysis, key, value)
|
||||
|
||||
analysis.updated_at = datetime.now()
|
||||
self.session.add(analysis)
|
||||
self.session.flush()
|
||||
self.session.refresh(analysis)
|
||||
|
||||
return AnalysisVO(**row_to_dict(analysis))
|
||||
from analysis.domain.repository.analysis_repo import IAnalysisRepository
|
||||
from sqlmodel import Session, select
|
||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||
from analysis.infra.db_models.analysis import Analysis, AnalysisStatus
|
||||
from analysis.interface.dto import TradingAnalysisRequest
|
||||
from utils.db_utils import row_to_dict
|
||||
from sqlalchemy.orm import selectinload
|
||||
from datetime import datetime, date
|
||||
|
||||
class AnalysisRepository(IAnalysisRepository):
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def find_by_member_id(self, member_id: str) -> list[AnalysisVO] | None:
|
||||
query = select(Analysis).where(Analysis.member_id == member_id)
|
||||
analyses = self.session.exec(query).all()
|
||||
|
||||
if not analyses:
|
||||
return None
|
||||
|
||||
return [AnalysisVO(**row_to_dict(analysis)) for analysis in analyses]
|
||||
|
||||
def find_by_id(self, analysis_id: str) -> AnalysisVO | None:
|
||||
analysis = self.session.get(Analysis, analysis_id)
|
||||
if not analysis:
|
||||
return None
|
||||
return AnalysisVO(**row_to_dict(analysis))
|
||||
|
||||
def save(self, analysis: AnalysisVO) -> AnalysisVO:
|
||||
new_analysis = Analysis(
|
||||
**analysis.model_dump()
|
||||
)
|
||||
|
||||
self.session.add(new_analysis)
|
||||
self.session.flush()
|
||||
self.session.refresh(new_analysis)
|
||||
|
||||
analysis.id = new_analysis.id
|
||||
return analysis
|
||||
|
||||
def update(self, analysis_vo: AnalysisVO) -> AnalysisVO | None:
|
||||
analysis = self.session.get(Analysis, analysis_vo.id)
|
||||
if not analysis:
|
||||
return None
|
||||
|
||||
# AnalysisVO의 데이터를 SQLModel 객체에 업데이트
|
||||
analysis_data = analysis_vo.model_dump(exclude_unset=True)
|
||||
|
||||
analysis.updated_at = datetime.now()
|
||||
analysis.sqlmodel_update(analysis_data)
|
||||
|
||||
self.session.flush()
|
||||
|
||||
|
||||
return AnalysisVO(**row_to_dict(analysis))
|
||||
|
|
|
|||
|
|
@ -1,108 +1,137 @@
|
|||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException, status
|
||||
from analysis.interface.dto import (
|
||||
AnalysisSessionResponse,
|
||||
TradingAnalysisRequest,
|
||||
AnalysisResultResponse
|
||||
)
|
||||
from utils.auth import get_current_member, CurrentMember
|
||||
from dependency_injector.wiring import inject, Provide
|
||||
from analysis.application.analysis_service import AnalysisService
|
||||
from utils.containers import Container
|
||||
|
||||
router = APIRouter(prefix="/analysis", tags=["analysis"])
|
||||
|
||||
@router.get("/", response_model=list[AnalysisSessionResponse])
|
||||
@inject
|
||||
def get_analysis_list_for_member(
|
||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
||||
):
|
||||
"""
|
||||
현재 로그인한 사용자의 모든 분석 세션 목록을 조회합니다.
|
||||
"""
|
||||
analyses = analysis_service.get_analysis_list(current_member.id)
|
||||
return [
|
||||
AnalysisSessionResponse(
|
||||
id=analysis.id,
|
||||
ticker=analysis.ticker,
|
||||
status=analysis.status
|
||||
) for analysis in analyses
|
||||
]
|
||||
|
||||
@router.post("/start", status_code=201, response_model=AnalysisSessionResponse)
|
||||
@inject
|
||||
def start_analysis_session(
|
||||
request: TradingAnalysisRequest,
|
||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])],
|
||||
background_tasks: BackgroundTasks
|
||||
):
|
||||
"""
|
||||
새로운 분석 세션을 시작합니다.
|
||||
"""
|
||||
try:
|
||||
new_analysis = analysis_service.create_analysis(current_member.id, request, background_tasks)
|
||||
return AnalysisSessionResponse(
|
||||
id=new_analysis.id,
|
||||
ticker=new_analysis.ticker,
|
||||
status=new_analysis.status
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to start analysis: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/{analysis_id}", response_model=AnalysisResultResponse)
|
||||
@inject
|
||||
def get_analysis_result(
|
||||
analysis_id: str,
|
||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
||||
):
|
||||
"""
|
||||
특정 분석 세션의 결과를 조회합니다.
|
||||
"""
|
||||
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
|
||||
|
||||
return AnalysisResultResponse(
|
||||
id=analysis.id,
|
||||
ticker=analysis.ticker,
|
||||
analysis_date=analysis.analysis_date,
|
||||
status=analysis.status,
|
||||
market_report=analysis.market_report,
|
||||
sentiment_report=analysis.sentiment_report,
|
||||
news_report=analysis.news_report,
|
||||
fundamentals_report=analysis.fundamentals_report,
|
||||
investment_debate_state=analysis.investment_debate_state,
|
||||
trader_investment_plan=analysis.trader_investment_plan,
|
||||
risk_debate_state=analysis.risk_debate_state,
|
||||
final_trade_decision=analysis.final_trade_decision,
|
||||
final_report=analysis.final_report,
|
||||
created_at=analysis.created_at.isoformat(),
|
||||
completed_at=analysis.completed_at.isoformat() if analysis.completed_at else None,
|
||||
error_message=analysis.error_message
|
||||
)
|
||||
|
||||
@router.get("/{analysis_id}/status")
|
||||
@inject
|
||||
def get_analysis_status(
|
||||
analysis_id: str,
|
||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
||||
):
|
||||
"""
|
||||
분석 진행 상황을 조회합니다.
|
||||
"""
|
||||
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
|
||||
|
||||
return {
|
||||
"analysis_id": analysis.id,
|
||||
"status": analysis.status,
|
||||
"ticker": analysis.ticker,
|
||||
"analysis_date": analysis.analysis_date,
|
||||
"created_at": analysis.created_at.isoformat(),
|
||||
"updated_at": analysis.updated_at.isoformat(),
|
||||
"error_message": analysis.error_message
|
||||
}
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException, status, WebSocket, WebSocketDisconnect
|
||||
from analysis.interface.dto import (
|
||||
AnalysisSessionResponse,
|
||||
TradingAnalysisRequest,
|
||||
AnalysisResultResponse
|
||||
)
|
||||
from utils.auth import get_current_member, CurrentMember
|
||||
from dependency_injector.wiring import inject, Provide
|
||||
from analysis.application.analysis_service import AnalysisService
|
||||
from utils.containers import Container
|
||||
from analysis.application.websocket_manager import WebSocketManager
|
||||
|
||||
router = APIRouter(prefix="/analysis", tags=["analysis"])
|
||||
|
||||
@router.get("/", response_model=list[AnalysisSessionResponse])
|
||||
@inject
|
||||
def get_analysis_list_for_member(
|
||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
||||
):
|
||||
"""
|
||||
현재 로그인한 사용자의 모든 분석 세션 목록을 조회합니다.
|
||||
"""
|
||||
analyses = analysis_service.get_analysis_list(current_member.id)
|
||||
return [
|
||||
AnalysisSessionResponse(
|
||||
id=analysis.id,
|
||||
ticker=analysis.ticker,
|
||||
status=analysis.status
|
||||
) for analysis in analyses
|
||||
]
|
||||
|
||||
@router.post("/start", status_code=201, response_model=AnalysisSessionResponse)
|
||||
@inject
|
||||
def start_analysis_session(
|
||||
request: TradingAnalysisRequest,
|
||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])],
|
||||
background_tasks: BackgroundTasks
|
||||
):
|
||||
"""
|
||||
새로운 분석 세션을 시작합니다.
|
||||
|
||||
"""
|
||||
try:
|
||||
new_analysis = analysis_service.create_analysis(current_member.id, request, background_tasks)
|
||||
return AnalysisSessionResponse(
|
||||
id=new_analysis.id,
|
||||
ticker=new_analysis.ticker,
|
||||
status=new_analysis.status
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to start analysis: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/{analysis_id}", response_model=AnalysisResultResponse)
|
||||
@inject
|
||||
def get_analysis_result(
|
||||
analysis_id: str,
|
||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
||||
):
|
||||
"""
|
||||
특정 분석 세션의 결과를 조회합니다.
|
||||
"""
|
||||
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
|
||||
|
||||
return AnalysisResultResponse(
|
||||
id=analysis.id,
|
||||
ticker=analysis.ticker,
|
||||
analysis_date=analysis.analysis_date.isoformat() if hasattr(analysis.analysis_date, 'isoformat') else str(analysis.analysis_date),
|
||||
status=analysis.status,
|
||||
market_report=analysis.market_report,
|
||||
sentiment_report=analysis.sentiment_report,
|
||||
news_report=analysis.news_report,
|
||||
fundamentals_report=analysis.fundamentals_report,
|
||||
investment_debate_state=analysis.investment_debate_state,
|
||||
trader_investment_plan=analysis.trader_investment_plan,
|
||||
risk_debate_state=analysis.risk_debate_state,
|
||||
final_trade_decision=analysis.final_trade_decision,
|
||||
final_report=analysis.final_report,
|
||||
created_at=analysis.created_at.isoformat(),
|
||||
completed_at=analysis.completed_at.isoformat() if analysis.completed_at else None,
|
||||
error_message=analysis.error_message
|
||||
)
|
||||
|
||||
@router.get("/{analysis_id}/status")
|
||||
@inject
|
||||
def get_analysis_status(
|
||||
analysis_id: str,
|
||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||
analysis_service: Annotated[AnalysisService, Depends(Provide[Container.analysis_service])]
|
||||
):
|
||||
"""
|
||||
분석 진행 상황을 조회합니다.
|
||||
"""
|
||||
analysis = analysis_service.get_analysis_by_id(analysis_id, current_member.id)
|
||||
|
||||
return {
|
||||
"analysis_id": analysis.id,
|
||||
"status": analysis.status,
|
||||
"ticker": analysis.ticker,
|
||||
"analysis_date": analysis.analysis_date,
|
||||
"created_at": analysis.created_at.isoformat(),
|
||||
"updated_at": analysis.updated_at.isoformat(),
|
||||
"error_message": analysis.error_message
|
||||
}
|
||||
|
||||
@router.websocket("/ws")
|
||||
@inject
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
current_member: Annotated[CurrentMember, Depends(get_current_member)],
|
||||
websocket_manager: Annotated[WebSocketManager, Depends(Provide[Container.websocket_manager])]
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time analysis updates
|
||||
"""
|
||||
try:
|
||||
# Connect the websocket
|
||||
await websocket_manager.connect(websocket, current_member.id)
|
||||
|
||||
try:
|
||||
# Keep connection alive
|
||||
while True:
|
||||
# Wait for messages from client (like ping/pong)
|
||||
data = await websocket.receive_text()
|
||||
# Echo back for heartbeat
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
except WebSocketDisconnect:
|
||||
websocket_manager.disconnect(websocket, current_member.id)
|
||||
except Exception as e:
|
||||
await websocket.close(code=1011, reason=str(e))
|
||||
|
|
|
|||
|
|
@ -1,52 +1,52 @@
|
|||
from pydantic import BaseModel
|
||||
from datetime import date
|
||||
from typing import List
|
||||
from analysis.infra.db_models.analysis import AnalysisStatus
|
||||
from enum import Enum
|
||||
|
||||
class AnalystType(str, Enum):
|
||||
MARKET = "market"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
FUNDAMENTALS = "fundamentals"
|
||||
|
||||
class TradingAnalysisRequest(BaseModel):
|
||||
ticker: str
|
||||
analysis_date: str
|
||||
analysts: List[AnalystType]
|
||||
research_depth: int = 3
|
||||
llm_provider: str = "openai"
|
||||
backend_url: str = "https://api.openai.com/v1"
|
||||
shallow_thinker: str = "gpt-4o-mini"
|
||||
deep_thinker: str = "gpt-4o"
|
||||
|
||||
class AnalysisSessionResponse(BaseModel):
|
||||
id : str
|
||||
ticker : str
|
||||
status : AnalysisStatus
|
||||
|
||||
class AnalysisProgressUpdate(BaseModel):
|
||||
analysis_id: str
|
||||
current_agent: str
|
||||
status: str
|
||||
progress_percentage: float
|
||||
current_report_section: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
class AnalysisResultResponse(BaseModel):
|
||||
id: str
|
||||
ticker: str
|
||||
analysis_date: str
|
||||
status: AnalysisStatus
|
||||
market_report: str | None = None
|
||||
sentiment_report: str | None = None
|
||||
news_report: str | None = None
|
||||
fundamentals_report: str | None = None
|
||||
investment_debate_state: dict | None = None
|
||||
trader_investment_plan: str | None = None
|
||||
risk_debate_state: dict | None = None
|
||||
final_trade_decision: str | None = None
|
||||
final_report: str | None = None
|
||||
created_at: str
|
||||
completed_at: str | None = None
|
||||
from pydantic import BaseModel
|
||||
from datetime import date
|
||||
from typing import List
|
||||
from analysis.infra.db_models.analysis import AnalysisStatus
|
||||
from enum import Enum
|
||||
|
||||
class AnalystType(str, Enum):
|
||||
MARKET = "market"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
FUNDAMENTALS = "fundamentals"
|
||||
|
||||
class TradingAnalysisRequest(BaseModel):
|
||||
ticker: str = "NVDA"
|
||||
analysis_date: str = "2025-07-07"
|
||||
analysts: List[AnalystType] = [AnalystType.MARKET, AnalystType.SOCIAL, AnalystType.NEWS, AnalystType.FUNDAMENTALS]
|
||||
research_depth: int = 3
|
||||
llm_provider: str = "openai"
|
||||
backend_url: str = "https://api.openai.com/v1"
|
||||
shallow_thinker: str = "gpt-4o-mini"
|
||||
deep_thinker: str = "gpt-4o-mini"
|
||||
|
||||
class AnalysisSessionResponse(BaseModel):
|
||||
id : str
|
||||
ticker : str
|
||||
status : AnalysisStatus
|
||||
|
||||
class AnalysisProgressUpdate(BaseModel):
|
||||
analysis_id: str
|
||||
current_agent: str
|
||||
status: str
|
||||
progress_percentage: float
|
||||
current_report_section: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
class AnalysisResultResponse(BaseModel):
|
||||
id: str
|
||||
ticker: str
|
||||
analysis_date: str
|
||||
status: AnalysisStatus
|
||||
market_report: str | None = None
|
||||
sentiment_report: str | None = None
|
||||
news_report: str | None = None
|
||||
fundamentals_report: str | None = None
|
||||
investment_debate_state: dict | None = None
|
||||
trader_investment_plan: str | None = None
|
||||
risk_debate_state: dict | None = None
|
||||
final_trade_decision: str | None = None
|
||||
final_report: str | None = None
|
||||
created_at: str
|
||||
completed_at: str | None = None
|
||||
error_message: str | None = None
|
||||
|
|
@ -1,20 +1,20 @@
|
|||
from functools import lru_cache
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
)
|
||||
|
||||
# MySQL 데이터베이스 설정
|
||||
DB_HOST: str
|
||||
DB_PORT: int
|
||||
DB_USER: str
|
||||
DB_PASSWORD: str
|
||||
DB_NAME: str
|
||||
SECRET_KEY: str
|
||||
|
||||
@lru_cache
|
||||
def get_settings():
|
||||
from functools import lru_cache
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
)
|
||||
|
||||
# MySQL 데이터베이스 설정
|
||||
DB_HOST: str
|
||||
DB_PORT: int
|
||||
DB_USER: str
|
||||
DB_PASSWORD: str
|
||||
DB_NAME: str
|
||||
SECRET_KEY: str
|
||||
|
||||
@lru_cache
|
||||
def get_settings():
|
||||
return Settings()
|
||||
|
|
@ -1,20 +1,30 @@
|
|||
from fastapi import FastAPI
|
||||
from utils.database import create_db_and_tables
|
||||
from utils.containers import Container
|
||||
|
||||
|
||||
from analysis.interface.controller.analysis_controller import router as analysis_router
|
||||
from member.interface.controller.member_controller import router as member_router
|
||||
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.container = Container()
|
||||
|
||||
app.include_router(analysis_router)
|
||||
app.include_router(member_router)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup_db_client():
|
||||
from fastapi import FastAPI
|
||||
from utils.database import create_db_and_tables
|
||||
from utils.containers import Container
|
||||
|
||||
|
||||
from analysis.interface.controller.analysis_controller import router as analysis_router
|
||||
from member.interface.controller.member_controller import router as member_router
|
||||
import logging
|
||||
|
||||
# 로깅 설정
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(), # 콘솔 출력
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.container = Container()
|
||||
|
||||
app.include_router(analysis_router)
|
||||
app.include_router(member_router)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup_db_client():
|
||||
create_db_and_tables()
|
||||
|
|
@ -1,97 +1,97 @@
|
|||
from sqlmodel import Session
|
||||
from utils.crypto import Crypto
|
||||
from member.domain.repository.member_repo import IMemberRepository
|
||||
from utils.auth import Role
|
||||
from member.domain.member import Member as MemberVO
|
||||
from fastapi import HTTPException, status
|
||||
from datetime import datetime
|
||||
from utils.auth import create_access_token
|
||||
from ulid import ULID
|
||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||
|
||||
class MemberService:
|
||||
def __init__(
|
||||
self,
|
||||
member_repo: IMemberRepository,
|
||||
crypto: Crypto,
|
||||
db_session: Session,
|
||||
ulid: ULID
|
||||
):
|
||||
self.member_repo = member_repo
|
||||
self.crypto = crypto
|
||||
self.db_session = db_session
|
||||
self.ulid = ulid
|
||||
|
||||
def create_member(
|
||||
self,
|
||||
name: str,
|
||||
email: str,
|
||||
password: str,
|
||||
role: Role
|
||||
):
|
||||
try:
|
||||
if self.member_repo.find_by_email(email):
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already exists")
|
||||
except Exception as e:
|
||||
self.db_session.rollback()
|
||||
raise e
|
||||
|
||||
now = datetime.now()
|
||||
member_vo = MemberVO(
|
||||
id=self.ulid.generate(),
|
||||
name=name,
|
||||
email=email,
|
||||
password=self.crypto.encrypt(password),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
role=role
|
||||
)
|
||||
|
||||
saved_member = self.member_repo.save(member_vo)
|
||||
self.db_session.commit()
|
||||
|
||||
return saved_member
|
||||
|
||||
|
||||
def get_members(
|
||||
self,
|
||||
page: int,
|
||||
items_per_page: int
|
||||
)->tuple[int, list[MemberVO]] :
|
||||
return self.member_repo.get_members(page, items_per_page)
|
||||
|
||||
def get_member(
|
||||
self,
|
||||
id: str
|
||||
)->MemberVO | None:
|
||||
member = self.member_repo.find_by_id(id)
|
||||
if not member:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
|
||||
return member
|
||||
|
||||
def login(
|
||||
self,
|
||||
email: str,
|
||||
password: str
|
||||
):
|
||||
member = self.member_repo.find_by_email(email)
|
||||
if not member:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||
|
||||
if not self.crypto.verify(password, member.password):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||
|
||||
access_token = create_access_token(
|
||||
payload={"member_id": member.id, "role": member.role},
|
||||
role=member.role,
|
||||
)
|
||||
|
||||
return access_token
|
||||
|
||||
def get_analysis_sessions_by_member(
|
||||
self,
|
||||
member_id: str
|
||||
)->list[AnalysisVO]:
|
||||
analysis_sessions = self.member_repo.find_analysis_sessions_by_member(member_id)
|
||||
return analysis_sessions
|
||||
from sqlmodel import Session
|
||||
from utils.crypto import Crypto
|
||||
from member.domain.repository.member_repo import IMemberRepository
|
||||
from utils.auth import Role
|
||||
from member.domain.member import Member as MemberVO
|
||||
from fastapi import HTTPException, status
|
||||
from datetime import datetime
|
||||
from utils.auth import create_access_token
|
||||
from ulid import ULID
|
||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||
|
||||
class MemberService:
|
||||
def __init__(
|
||||
self,
|
||||
member_repo: IMemberRepository,
|
||||
crypto: Crypto,
|
||||
session: Session,
|
||||
ulid: ULID
|
||||
):
|
||||
self.member_repo = member_repo
|
||||
self.crypto = crypto
|
||||
self.db_session = session
|
||||
self.ulid = ulid
|
||||
|
||||
def create_member(
|
||||
self,
|
||||
name: str,
|
||||
email: str,
|
||||
password: str,
|
||||
role: Role
|
||||
):
|
||||
try:
|
||||
if self.member_repo.find_by_email(email):
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already exists")
|
||||
except Exception as e:
|
||||
self.db_session.rollback()
|
||||
raise e
|
||||
|
||||
now = datetime.now()
|
||||
member_vo = MemberVO(
|
||||
id=self.ulid.generate(),
|
||||
name=name,
|
||||
email=email,
|
||||
password=self.crypto.encrypt(password),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
role=role
|
||||
)
|
||||
|
||||
saved_member = self.member_repo.save(member_vo)
|
||||
self.db_session.commit()
|
||||
|
||||
return saved_member
|
||||
|
||||
|
||||
def get_members(
|
||||
self,
|
||||
page: int,
|
||||
items_per_page: int
|
||||
)->tuple[int, list[MemberVO]] :
|
||||
return self.member_repo.get_members(page, items_per_page)
|
||||
|
||||
def get_member(
|
||||
self,
|
||||
id: str
|
||||
)->MemberVO | None:
|
||||
member = self.member_repo.find_by_id(id)
|
||||
if not member:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
|
||||
return member
|
||||
|
||||
def login(
|
||||
self,
|
||||
email: str,
|
||||
password: str
|
||||
):
|
||||
member = self.member_repo.find_by_email(email)
|
||||
if not member:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||
|
||||
if not self.crypto.verify(password, member.password):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||
|
||||
access_token = create_access_token(
|
||||
payload={"member_id": member.id, "role": member.role},
|
||||
role=member.role,
|
||||
)
|
||||
|
||||
return access_token
|
||||
|
||||
def get_analysis_sessions_by_member(
|
||||
self,
|
||||
member_id: str
|
||||
)->list[AnalysisVO]:
|
||||
analysis_sessions = self.member_repo.find_analysis_sessions_by_member(member_id)
|
||||
return analysis_sessions
|
||||
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
from pydantic import BaseModel
|
||||
from utils.auth import Role
|
||||
from datetime import datetime
|
||||
|
||||
class Member(BaseModel):
|
||||
id: str | None = None
|
||||
name: str
|
||||
email: str
|
||||
password: str
|
||||
role: Role
|
||||
created_at: datetime
|
||||
from pydantic import BaseModel
|
||||
from utils.auth import Role
|
||||
from datetime import datetime
|
||||
|
||||
class Member(BaseModel):
|
||||
id: str | None = None
|
||||
name: str
|
||||
email: str
|
||||
password: str
|
||||
role: Role
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
class IMemberRepository(ABC):
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class IMemberRepository(ABC):
|
||||
pass
|
||||
|
|
@ -1,24 +1,24 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from member.domain.member import Member as MemberVO
|
||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||
|
||||
class IMemberRepository(ABC):
|
||||
@abstractmethod
|
||||
def find_by_email(self, email: str) -> MemberVO | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def save(self, member: MemberVO) -> MemberVO:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def find_by_id(self, id: str) -> MemberVO | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def find_analysis_sessions_by_member(self, member_id: str) -> list[AnalysisVO]:
|
||||
from abc import ABC, abstractmethod
|
||||
from member.domain.member import Member as MemberVO
|
||||
from analysis.domain.analysis import Analysis as AnalysisVO
|
||||
|
||||
class IMemberRepository(ABC):
|
||||
@abstractmethod
|
||||
def find_by_email(self, email: str) -> MemberVO | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def save(self, member: MemberVO) -> MemberVO:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def find_by_id(self, id: str) -> MemberVO | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def find_analysis_sessions_by_member(self, member_id: str) -> list[AnalysisVO]:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -1,66 +1,66 @@
|
|||
from member.domain.repository import IMemberRepository
|
||||
from sqlmodel import Session, select
|
||||
from member.domain.member import Member as MemberVO
|
||||
from member.infra.db_models.member import Member
|
||||
from utils.db_utils import row_to_dict
|
||||
from sqlalchemy import func
|
||||
|
||||
class MemberRepository(IMemberRepository):
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def find_by_email(self, email: str) -> MemberVO | None:
|
||||
query = select(Member).where(Member.email == email)
|
||||
member = self.session.exec(query).first()
|
||||
|
||||
if not member:
|
||||
return None
|
||||
|
||||
return MemberVO(**row_to_dict(member))
|
||||
|
||||
def save(self, member: MemberVO) -> MemberVO:
|
||||
new_member = Member(
|
||||
id=member.id,
|
||||
email=member.email,
|
||||
name=member.name,
|
||||
password=member.password,
|
||||
role=member.role,
|
||||
created_at=member.created_at,
|
||||
updated_at=member.updated_at
|
||||
)
|
||||
|
||||
self.session.add(new_member)
|
||||
self.session.flush()
|
||||
self.session.refresh(new_member)
|
||||
|
||||
member.id = new_member.id
|
||||
return member
|
||||
|
||||
|
||||
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
|
||||
offset = (page - 1) * items_per_page
|
||||
total_count_query = select(func.count(Member.id))
|
||||
total_count = self.session.exec(total_count_query).one()
|
||||
|
||||
if total_count == 0:
|
||||
return 0, []
|
||||
|
||||
query = (
|
||||
select(Member)
|
||||
.order_by(Member.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(items_per_page)
|
||||
)
|
||||
|
||||
members = self.session.exec(query).all()
|
||||
|
||||
return total_count, [MemberVO(**row_to_dict(member)) for member in members]
|
||||
|
||||
def find_by_id(self, id: str) -> MemberVO | None:
|
||||
query = select(Member).where(Member.id == id)
|
||||
member = self.session.exec(query).first()
|
||||
|
||||
if not member:
|
||||
return None
|
||||
|
||||
return MemberVO(**row_to_dict(member))
|
||||
from member.domain.repository import IMemberRepository
|
||||
from sqlmodel import Session, select
|
||||
from member.domain.member import Member as MemberVO
|
||||
from member.infra.db_models.member import Member
|
||||
from utils.db_utils import row_to_dict
|
||||
from sqlalchemy import func
|
||||
|
||||
class MemberRepository(IMemberRepository):
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def find_by_email(self, email: str) -> MemberVO | None:
|
||||
query = select(Member).where(Member.email == email)
|
||||
member = self.session.exec(query).first()
|
||||
|
||||
if not member:
|
||||
return None
|
||||
|
||||
return MemberVO(**row_to_dict(member))
|
||||
|
||||
def save(self, member: MemberVO) -> MemberVO:
|
||||
new_member = Member(
|
||||
id=member.id,
|
||||
email=member.email,
|
||||
name=member.name,
|
||||
password=member.password,
|
||||
role=member.role,
|
||||
created_at=member.created_at,
|
||||
updated_at=member.updated_at
|
||||
)
|
||||
|
||||
self.session.add(new_member)
|
||||
self.session.flush()
|
||||
self.session.refresh(new_member)
|
||||
|
||||
member.id = new_member.id
|
||||
return member
|
||||
|
||||
|
||||
def get_members(self, page: int, items_per_page: int) -> tuple[int, list[MemberVO]]:
|
||||
offset = (page - 1) * items_per_page
|
||||
total_count_query = select(func.count(Member.id))
|
||||
total_count = self.session.exec(total_count_query).one()
|
||||
|
||||
if total_count == 0:
|
||||
return 0, []
|
||||
|
||||
query = (
|
||||
select(Member)
|
||||
.order_by(Member.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(items_per_page)
|
||||
)
|
||||
|
||||
members = self.session.exec(query).all()
|
||||
|
||||
return total_count, [MemberVO(**row_to_dict(member)) for member in members]
|
||||
|
||||
def find_by_id(self, id: str) -> MemberVO | None:
|
||||
query = select(Member).where(Member.id == id)
|
||||
member = self.session.exec(query).first()
|
||||
|
||||
if not member:
|
||||
return None
|
||||
|
||||
return MemberVO(**row_to_dict(member))
|
||||
|
|
|
|||
|
|
@ -1,78 +1,71 @@
|
|||
from fastapi import APIRouter, status, Depends,HTTPException
|
||||
from member.interface.dto import CreateUserBody, MemberResponse
|
||||
from member.application.member_service import MemberService
|
||||
from typing import Annotated
|
||||
from utils.containers import Container
|
||||
from dependency_injector.wiring import inject, Provide
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from utils.auth import get_current_member, CurrentMember, get_admin_member
|
||||
|
||||
router = APIRouter(prefix="/members", tags=["members"])
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED, response_model=MemberResponse)
|
||||
@inject
|
||||
async def create_user(
|
||||
member: CreateUserBody,
|
||||
member_service: MemberService = Depends(Provide[Container.member_service])
|
||||
):
|
||||
created_member = member_service.create_member(
|
||||
member.name,
|
||||
member.email,
|
||||
member.password,
|
||||
member.role
|
||||
)
|
||||
|
||||
return created_member
|
||||
|
||||
@router.post("/login")
|
||||
@inject
|
||||
def login(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
member_service: MemberService = Depends(Provide[Container.member_service])
|
||||
):
|
||||
access_token = member_service.login(
|
||||
email=form_data.username,
|
||||
password=form_data.password
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token" : access_token,
|
||||
"token_type" : "Bearer"
|
||||
}
|
||||
|
||||
@router.get("/me", response_model=dict)
|
||||
def get_current_user_info(
|
||||
current_user: CurrentMember = Depends(get_current_member)
|
||||
):
|
||||
"""
|
||||
현재 로그인한 사용자 정보를 조회합니다.
|
||||
이 엔드포인트는 JWT 토큰이 필요하며, Swagger UI에서 Authorize 버튼을 활성화합니다.
|
||||
"""
|
||||
return {
|
||||
"user_id": current_user.id,
|
||||
"role": current_user.role,
|
||||
"message": "Successfully authenticated"
|
||||
}
|
||||
|
||||
@router.get("/{member_id}", response_model=MemberResponse)
|
||||
@inject
|
||||
def get_member(
|
||||
member_id: str,
|
||||
current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
|
||||
member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
|
||||
):
|
||||
|
||||
member = member_service.get_member(member_id)
|
||||
if not member:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
|
||||
return member
|
||||
|
||||
# @router.get("/analysis-sessions", response_model=list[AnalysisSessionResponse])
|
||||
# @inject
|
||||
# def get_member_analysis_sessions(
|
||||
# current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
|
||||
# member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
|
||||
# ):
|
||||
|
||||
# result = member_service.get_analysis_sessions_by_member(current_member.id)
|
||||
# return result
|
||||
from fastapi import APIRouter, status, Depends,HTTPException
|
||||
from member.interface.dto import CreateUserBody, MemberResponse
|
||||
from member.application.member_service import MemberService
|
||||
from typing import Annotated
|
||||
from utils.containers import Container
|
||||
from dependency_injector.wiring import inject, Provide
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from utils.auth import get_current_member, CurrentMember, get_admin_member
|
||||
from analysis.interface.dto import AnalysisSessionResponse
|
||||
from analysis.application.analysis_service import AnalysisService
|
||||
|
||||
router = APIRouter(prefix="/members", tags=["members"])
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED, response_model=MemberResponse)
|
||||
@inject
|
||||
async def create_user(
|
||||
member: CreateUserBody,
|
||||
member_service: MemberService = Depends(Provide[Container.member_service])
|
||||
):
|
||||
created_member = member_service.create_member(
|
||||
member.name,
|
||||
member.email,
|
||||
member.password,
|
||||
member.role
|
||||
)
|
||||
|
||||
return created_member
|
||||
|
||||
@router.post("/login")
|
||||
@inject
|
||||
def login(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
member_service: MemberService = Depends(Provide[Container.member_service])
|
||||
):
|
||||
access_token = member_service.login(
|
||||
email=form_data.username,
|
||||
password=form_data.password
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token" : access_token,
|
||||
"token_type" : "Bearer"
|
||||
}
|
||||
|
||||
@router.get("/me", response_model=dict)
|
||||
def get_current_user_info(
|
||||
current_user: CurrentMember = Depends(get_current_member)
|
||||
):
|
||||
"""
|
||||
현재 로그인한 사용자 정보를 조회합니다.
|
||||
이 엔드포인트는 JWT 토큰이 필요하며, Swagger UI에서 Authorize 버튼을 활성화합니다.
|
||||
"""
|
||||
return {
|
||||
"user_id": current_user.id,
|
||||
"role": current_user.role,
|
||||
"message": "Successfully authenticated"
|
||||
}
|
||||
|
||||
@router.get("/{member_id}", response_model=MemberResponse)
|
||||
@inject
|
||||
def get_member(
|
||||
member_id: str,
|
||||
current_member: Annotated[CurrentMember | None, Depends(get_current_member)] = None,
|
||||
member_service: Annotated[MemberService | None, Depends(Provide[Container.member_service])] = None
|
||||
):
|
||||
|
||||
member = member_service.get_member(member_id)
|
||||
if not member:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found")
|
||||
return member
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
from typing import Annotated
|
||||
from pydantic import BaseModel, Field, EmailStr
|
||||
from utils.auth import Role
|
||||
from datetime import datetime
|
||||
|
||||
class CreateUserBody(BaseModel):
|
||||
name : Annotated[str, Field(min_length=1, max_length=32)]
|
||||
email : Annotated[EmailStr, Field(max_length=32)]
|
||||
password : Annotated[str, Field(max_length=32)]
|
||||
role : Annotated[Role, Field(default=Role.USER)]
|
||||
|
||||
class MemberResponse(BaseModel):
|
||||
id : str
|
||||
name : str | None = None
|
||||
email : str
|
||||
created_at : datetime
|
||||
updated_at : datetime
|
||||
role : Role
|
||||
from typing import Annotated
|
||||
from pydantic import BaseModel, Field, EmailStr
|
||||
from utils.auth import Role
|
||||
from datetime import datetime
|
||||
|
||||
class CreateUserBody(BaseModel):
|
||||
name : Annotated[str, Field(min_length=1, max_length=32)]
|
||||
email : Annotated[EmailStr, Field(max_length=32)]
|
||||
password : Annotated[str, Field(max_length=32)]
|
||||
role : Annotated[Role, Field(default=Role.USER)]
|
||||
|
||||
class MemberResponse(BaseModel):
|
||||
id : str
|
||||
name : str | None = None
|
||||
email : str
|
||||
created_at : datetime
|
||||
updated_at : datetime
|
||||
role : Role
|
||||
|
|
|
|||
|
|
@ -1,69 +1,69 @@
|
|||
from datetime import datetime, timedelta
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from jose import jwt, JWTError
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from enum import StrEnum
|
||||
from pydantic import BaseModel
|
||||
from typing import Annotated
|
||||
|
||||
from config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
class Role(StrEnum):
|
||||
ADMIN = "ADMIN"
|
||||
USER = "USER"
|
||||
|
||||
class CurrentMember(BaseModel):
|
||||
id : str
|
||||
role : Role
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.id}({self.role})"
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/members/login")
|
||||
|
||||
|
||||
|
||||
def create_access_token(
|
||||
payload: dict,
|
||||
role: Role,
|
||||
expires_delta: timedelta = timedelta(hours=6)
|
||||
):
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
payload.update({"exp": expire, "role": role})
|
||||
encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
def decode_access_token(token: str):
|
||||
try:
|
||||
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
|
||||
|
||||
# ✅ 수정된 부분: Annotated 올바른 사용법
|
||||
def get_current_member(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
payload = decode_access_token(token)
|
||||
member_id = payload.get("member_id")
|
||||
role = payload.get("role")
|
||||
if not member_id or not role:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
||||
|
||||
return CurrentMember(id=member_id, role=Role(role))
|
||||
|
||||
def get_admin_member(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
payload = decode_access_token(token)
|
||||
member_id = payload.get("member_id")
|
||||
role = payload.get("role")
|
||||
|
||||
if not role or role != Role.ADMIN:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from jose import jwt, JWTError
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from enum import StrEnum
|
||||
from pydantic import BaseModel
|
||||
from typing import Annotated
|
||||
|
||||
from config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
class Role(StrEnum):
|
||||
ADMIN = "ADMIN"
|
||||
USER = "USER"
|
||||
|
||||
class CurrentMember(BaseModel):
|
||||
id : str
|
||||
role : Role
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.id}({self.role})"
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/members/login")
|
||||
|
||||
|
||||
|
||||
def create_access_token(
|
||||
payload: dict,
|
||||
role: Role,
|
||||
expires_delta: timedelta = timedelta(hours=6)
|
||||
):
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
payload.update({"exp": expire, "role": role})
|
||||
encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
def decode_access_token(token: str):
|
||||
try:
|
||||
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
|
||||
|
||||
# ✅ 수정된 부분: Annotated 올바른 사용법
|
||||
def get_current_member(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
payload = decode_access_token(token)
|
||||
member_id = payload.get("member_id")
|
||||
role = payload.get("role")
|
||||
if not member_id or not role:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
||||
|
||||
return CurrentMember(id=member_id, role=Role(role))
|
||||
|
||||
def get_admin_member(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
payload = decode_access_token(token)
|
||||
member_id = payload.get("member_id")
|
||||
role = payload.get("role")
|
||||
|
||||
if not role or role != Role.ADMIN:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
||||
|
||||
return CurrentMember(id=member_id, role=Role(role))
|
||||
|
|
@ -1,43 +1,49 @@
|
|||
from dependency_injector import containers, providers
|
||||
from utils.database import get_session
|
||||
from utils.crypto import Crypto
|
||||
from member.infra.repository.member_repo import MemberRepository
|
||||
from member.application.member_service import MemberService
|
||||
from analysis.application.analysis_service import AnalysisService
|
||||
from analysis.infra.repository.analysis_repo import AnalysisRepository
|
||||
from ulid import ULID
|
||||
|
||||
class Container(containers.DeclarativeContainer):
|
||||
wiring_config = containers.WiringConfiguration(
|
||||
packages=["member", "analysis"]
|
||||
)
|
||||
|
||||
db_session = providers.Resource(get_session)
|
||||
crypto = providers.Factory(Crypto)
|
||||
ulid = providers.Factory(ULID)
|
||||
|
||||
member_repo = providers.Factory(
|
||||
MemberRepository,
|
||||
session=db_session
|
||||
)
|
||||
|
||||
member_service = providers.Factory(
|
||||
MemberService,
|
||||
member_repo=member_repo,
|
||||
crypto=crypto,
|
||||
db_session=db_session,
|
||||
ulid=ulid
|
||||
)
|
||||
|
||||
analysis_repo = providers.Factory(
|
||||
AnalysisRepository,
|
||||
session=db_session
|
||||
)
|
||||
|
||||
analysis_service = providers.Factory(
|
||||
AnalysisService,
|
||||
analysis_repo=analysis_repo,
|
||||
db_session=db_session,
|
||||
ulid=ulid
|
||||
)
|
||||
|
||||
from dependency_injector import containers, providers
|
||||
from utils.database import get_session
|
||||
from utils.crypto import Crypto
|
||||
from member.infra.repository.member_repo import MemberRepository
|
||||
from member.application.member_service import MemberService
|
||||
from analysis.application.analysis_service import AnalysisService
|
||||
from analysis.infra.repository.analysis_repo import AnalysisRepository
|
||||
from analysis.application.websocket_manager import WebSocketManager
|
||||
from ulid import ULID
|
||||
|
||||
class Container(containers.DeclarativeContainer):
|
||||
wiring_config = containers.WiringConfiguration(
|
||||
packages=["member", "analysis"]
|
||||
)
|
||||
|
||||
session = providers.Resource(get_session)
|
||||
crypto = providers.Factory(Crypto)
|
||||
ulid = providers.Factory(ULID)
|
||||
|
||||
member_repo = providers.Factory(
|
||||
MemberRepository,
|
||||
session=session
|
||||
)
|
||||
|
||||
member_service = providers.Factory(
|
||||
MemberService,
|
||||
member_repo=member_repo,
|
||||
crypto=crypto,
|
||||
session=session,
|
||||
ulid=ulid
|
||||
)
|
||||
|
||||
analysis_repo = providers.Factory(
|
||||
AnalysisRepository,
|
||||
session=session
|
||||
)
|
||||
|
||||
websocket_manager = providers.Singleton(
|
||||
WebSocketManager
|
||||
)
|
||||
|
||||
analysis_service = providers.Factory(
|
||||
AnalysisService,
|
||||
analysis_repo=analysis_repo,
|
||||
session=session,
|
||||
ulid=ulid,
|
||||
websocket_manager=websocket_manager
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from passlib.context import CryptContext
|
||||
|
||||
class Crypto:
|
||||
def __init__(self):
|
||||
self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def encrypt(self, secret):
|
||||
return self.pwd_context.hash(secret)
|
||||
|
||||
def verify(self, secret, hash):
|
||||
return self.pwd_context.verify(secret, hash)
|
||||
from passlib.context import CryptContext
|
||||
|
||||
class Crypto:
|
||||
def __init__(self):
|
||||
self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def encrypt(self, secret):
|
||||
return self.pwd_context.hash(secret)
|
||||
|
||||
def verify(self, secret, hash):
|
||||
return self.pwd_context.verify(secret, hash)
|
||||
|
||||
|
|
@ -1,32 +1,32 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from sqlmodel import SQLModel, create_engine, Session
|
||||
from config.config import get_settings
|
||||
from member.infra.db_models.member import Member
|
||||
from analysis.infra.db_models.analysis import Analysis
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
|
||||
# MySQL 데이터베이스 URL 구성
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset=utf8mb4"
|
||||
|
||||
# MySQL 엔진 생성
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
echo=True
|
||||
)
|
||||
|
||||
def get_session():
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
def create_db_and_tables():
|
||||
# 테이블 생성
|
||||
# SQLModel.metadata.drop_all(engine)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_db_and_tables()
|
||||
import os
|
||||
from pathlib import Path
|
||||
from sqlmodel import SQLModel, create_engine, Session
|
||||
from config.config import get_settings
|
||||
from member.infra.db_models.member import Member
|
||||
from analysis.infra.db_models.analysis import Analysis
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
|
||||
# MySQL 데이터베이스 URL 구성
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset=utf8mb4"
|
||||
|
||||
# MySQL 엔진 생성
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
echo=True
|
||||
)
|
||||
|
||||
def get_session():
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
def create_db_and_tables():
|
||||
# 테이블 생성
|
||||
# SQLModel.metadata.drop_all(engine)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_db_and_tables()
|
||||
print(DATABASE_URL)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from sqlalchemy import inspect
|
||||
|
||||
def row_to_dict(row)->dict:
|
||||
from sqlalchemy import inspect
|
||||
|
||||
def row_to_dict(row)->dict:
|
||||
return {key : getattr(row, key) for key in inspect(row).attrs.keys()}
|
||||
|
|
@ -1,59 +1,59 @@
|
|||
version: '3.8'
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: mysql:8.0
|
||||
container_name: tradingagents_mysql
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-password}
|
||||
MYSQL_DATABASE: ${DB_NAME:-tradingagents_db}
|
||||
MYSQL_USER: ${DB_USER:-tradinguser}
|
||||
MYSQL_PASSWORD: ${DB_PASSWORD:-password}
|
||||
ports:
|
||||
- "3306:3306"
|
||||
volumes:
|
||||
- /home/hskim/mysql_data:/var/lib/mysql
|
||||
- /home/hskim/docker/mysql/init:/docker-entrypoint-initdb.d
|
||||
command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
networks:
|
||||
- tradingagents_network
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: tradingagents_redis
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
command: redis-server --appendonly yes
|
||||
networks:
|
||||
- tradingagents_network
|
||||
|
||||
# 개발용 phpMyAdmin (선택사항)
|
||||
# phpmyadmin:
|
||||
# image: phpmyadmin/phpmyadmin
|
||||
# container_name: tradingagents_phpmyadmin
|
||||
# restart: unless-stopped
|
||||
# environment:
|
||||
# PMA_HOST: mysql
|
||||
# PMA_PORT: 3306
|
||||
# PMA_USER: root
|
||||
# PMA_PASSWORD: ${DB_PASSWORD:-password}
|
||||
# ports:
|
||||
# - "8080:80"
|
||||
# depends_on:
|
||||
# - mysql
|
||||
# networks:
|
||||
# - tradingagents_network
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
driver: local
|
||||
redis_data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
tradingagents_network:
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: mysql:8.0
|
||||
container_name: tradingagents_mysql
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-password}
|
||||
MYSQL_DATABASE: ${DB_NAME:-tradingagents_db}
|
||||
MYSQL_USER: ${DB_USER:-tradinguser}
|
||||
MYSQL_PASSWORD: ${DB_PASSWORD:-password}
|
||||
ports:
|
||||
- "3306:3306"
|
||||
volumes:
|
||||
- /home/hskim/mysql_data:/var/lib/mysql
|
||||
- /home/hskim/docker/mysql/init:/docker-entrypoint-initdb.d
|
||||
command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
networks:
|
||||
- tradingagents_network
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: tradingagents_redis
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
command: redis-server --appendonly yes
|
||||
networks:
|
||||
- tradingagents_network
|
||||
|
||||
# 개발용 phpMyAdmin (선택사항)
|
||||
# phpmyadmin:
|
||||
# image: phpmyadmin/phpmyadmin
|
||||
# container_name: tradingagents_phpmyadmin
|
||||
# restart: unless-stopped
|
||||
# environment:
|
||||
# PMA_HOST: mysql
|
||||
# PMA_PORT: 3306
|
||||
# PMA_USER: root
|
||||
# PMA_PASSWORD: ${DB_PASSWORD:-password}
|
||||
# ports:
|
||||
# - "8080:80"
|
||||
# depends_on:
|
||||
# - mysql
|
||||
# networks:
|
||||
# - tradingagents_network
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
driver: local
|
||||
redis_data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
tradingagents_network:
|
||||
driver: bridge
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,48 +1,48 @@
|
|||
{
|
||||
"name": "tradingagents-web-frontend",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"dependencies": {
|
||||
"@ant-design/icons": "^5.2.6",
|
||||
"@testing-library/jest-dom": "^5.16.4",
|
||||
"@testing-library/react": "^13.3.0",
|
||||
"@testing-library/user-event": "^13.5.0",
|
||||
"antd": "^5.10.0",
|
||||
"axios": "^1.5.0",
|
||||
"dayjs": "^1.11.9",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-markdown": "^8.0.7",
|
||||
"react-router-dom": "^6.4.0",
|
||||
"react-scripts": "5.0.1",
|
||||
"recharts": "^2.8.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"styled-components": "^6.0.8",
|
||||
"websocket": "^1.0.34"
|
||||
},
|
||||
"scripts": {
|
||||
"start": "react-scripts start",
|
||||
"build": "react-scripts build",
|
||||
"test": "react-scripts test",
|
||||
"eject": "react-scripts eject"
|
||||
},
|
||||
"eslintConfig": {
|
||||
"extends": [
|
||||
"react-app",
|
||||
"react-app/jest"
|
||||
]
|
||||
},
|
||||
"browserslist": {
|
||||
"production": [
|
||||
">0.2%",
|
||||
"not dead",
|
||||
"not op_mini all"
|
||||
],
|
||||
"development": [
|
||||
"last 1 chrome version",
|
||||
"last 1 firefox version",
|
||||
"last 1 safari version"
|
||||
]
|
||||
},
|
||||
"proxy": "http://localhost:8000"
|
||||
}
|
||||
{
|
||||
"name": "tradingagents-web-frontend",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"dependencies": {
|
||||
"@ant-design/icons": "^5.2.6",
|
||||
"@testing-library/jest-dom": "^5.16.4",
|
||||
"@testing-library/react": "^13.3.0",
|
||||
"@testing-library/user-event": "^13.5.0",
|
||||
"antd": "^5.10.0",
|
||||
"axios": "^1.5.0",
|
||||
"dayjs": "^1.11.9",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-markdown": "^8.0.7",
|
||||
"react-router-dom": "^6.4.0",
|
||||
"react-scripts": "5.0.1",
|
||||
"recharts": "^2.8.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"styled-components": "^6.0.8",
|
||||
"websocket": "^1.0.34"
|
||||
},
|
||||
"scripts": {
|
||||
"start": "react-scripts start",
|
||||
"build": "react-scripts build",
|
||||
"test": "react-scripts test",
|
||||
"eject": "react-scripts eject"
|
||||
},
|
||||
"eslintConfig": {
|
||||
"extends": [
|
||||
"react-app",
|
||||
"react-app/jest"
|
||||
]
|
||||
},
|
||||
"browserslist": {
|
||||
"production": [
|
||||
">0.2%",
|
||||
"not dead",
|
||||
"not op_mini all"
|
||||
],
|
||||
"development": [
|
||||
"last 1 chrome version",
|
||||
"last 1 firefox version",
|
||||
"last 1 safari version"
|
||||
]
|
||||
},
|
||||
"proxy": "http://localhost:8000"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="ko">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<meta name="theme-color" content="#000000" />
|
||||
<meta
|
||||
name="description"
|
||||
content="TradingAgents - Multi-Agents LLM Financial Trading Framework"
|
||||
/>
|
||||
<link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
|
||||
<link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
|
||||
<title>TradingAgents - AI 거래 분석 플랫폼</title>
|
||||
</head>
|
||||
<body>
|
||||
<noscript>JavaScript를 활성화해야 이 앱을 실행할 수 있습니다.</noscript>
|
||||
<div id="root"></div>
|
||||
</body>
|
||||
<!DOCTYPE html>
|
||||
<html lang="ko">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<meta name="theme-color" content="#000000" />
|
||||
<meta
|
||||
name="description"
|
||||
content="TradingAgents - Multi-Agents LLM Financial Trading Framework"
|
||||
/>
|
||||
<link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
|
||||
<link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
|
||||
<title>TradingAgents - AI 거래 분석 플랫폼</title>
|
||||
</head>
|
||||
<body>
|
||||
<noscript>JavaScript를 활성화해야 이 앱을 실행할 수 있습니다.</noscript>
|
||||
<div id="root"></div>
|
||||
</body>
|
||||
</html>
|
||||
42
main.py
42
main.py
|
|
@ -1,21 +1,21 @@
|
|||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
# Create a custom config
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["llm_provider"] = "google" # Use a different model
|
||||
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
|
||||
config["deep_think_llm"] = "gemini-2.5-pro" # Use a different model
|
||||
config["quick_think_llm"] = "gemini-2.5-flash-lite-preview-06-17" # Use a different model
|
||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||
config["online_tools"] = True # Increase debate rounds
|
||||
|
||||
# Initialize with custom config
|
||||
ta = TradingAgentsGraph(debug=True, config=config)
|
||||
|
||||
# forward propagate
|
||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||
print(decision)
|
||||
|
||||
# Memorize mistakes and reflect
|
||||
# ta.reflect_and_remember(1000) # parameter is the position returns
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
# Create a custom config
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["llm_provider"] = "google" # Use a different model
|
||||
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
|
||||
config["deep_think_llm"] = "gemini-2.5-pro" # Use a different model
|
||||
config["quick_think_llm"] = "gemini-2.5-flash-lite-preview-06-17" # Use a different model
|
||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||
config["online_tools"] = True # Increase debate rounds
|
||||
|
||||
# Initialize with custom config
|
||||
ta = TradingAgentsGraph(debug=True, config=config)
|
||||
|
||||
# forward propagate
|
||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||
print(decision)
|
||||
|
||||
# Memorize mistakes and reflect
|
||||
# ta.reflect_and_remember(1000) # parameter is the position returns
|
||||
|
|
|
|||
|
|
@ -1,34 +1,34 @@
|
|||
[project]
|
||||
name = "tradingagents"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"akshare>=1.16.98",
|
||||
"backtrader>=1.9.78.123",
|
||||
"chainlit>=2.5.5",
|
||||
"chromadb>=1.0.12",
|
||||
"eodhd>=1.0.32",
|
||||
"feedparser>=6.0.11",
|
||||
"finnhub-python>=2.4.23",
|
||||
"langchain-anthropic>=0.3.15",
|
||||
"langchain-experimental>=0.3.4",
|
||||
"langchain-google-genai>=2.1.5",
|
||||
"langchain-openai>=0.3.23",
|
||||
"langgraph>=0.4.8",
|
||||
"pandas>=2.3.0",
|
||||
"parsel>=1.10.0",
|
||||
"praw>=7.8.1",
|
||||
"pytz>=2025.2",
|
||||
"questionary>=2.1.0",
|
||||
"redis>=6.2.0",
|
||||
"requests>=2.32.4",
|
||||
"rich>=14.0.0",
|
||||
"setuptools>=80.9.0",
|
||||
"stockstats>=0.6.5",
|
||||
"tqdm>=4.67.1",
|
||||
"tushare>=1.4.21",
|
||||
"typing-extensions>=4.14.0",
|
||||
"yfinance>=0.2.63",
|
||||
]
|
||||
[project]
|
||||
name = "tradingagents"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"akshare>=1.16.98",
|
||||
"backtrader>=1.9.78.123",
|
||||
"chainlit>=2.5.5",
|
||||
"chromadb>=1.0.12",
|
||||
"eodhd>=1.0.32",
|
||||
"feedparser>=6.0.11",
|
||||
"finnhub-python>=2.4.23",
|
||||
"langchain-anthropic>=0.3.15",
|
||||
"langchain-experimental>=0.3.4",
|
||||
"langchain-google-genai>=2.1.5",
|
||||
"langchain-openai>=0.3.23",
|
||||
"langgraph>=0.4.8",
|
||||
"pandas>=2.3.0",
|
||||
"parsel>=1.10.0",
|
||||
"praw>=7.8.1",
|
||||
"pytz>=2025.2",
|
||||
"questionary>=2.1.0",
|
||||
"redis>=6.2.0",
|
||||
"requests>=2.32.4",
|
||||
"rich>=14.0.0",
|
||||
"setuptools>=80.9.0",
|
||||
"stockstats>=0.6.5",
|
||||
"tqdm>=4.67.1",
|
||||
"tushare>=1.4.21",
|
||||
"typing-extensions>=4.14.0",
|
||||
"yfinance>=0.2.63",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,60 +1,60 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_news_analyst(llm, toolkit):
|
||||
def news_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [toolkit.get_global_news, toolkit.get_google_news]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_finnhub_news,
|
||||
toolkit.get_reddit_news,
|
||||
toolkit.get_google_news,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"news_report": report,
|
||||
}
|
||||
|
||||
return news_analyst_node
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_news_analyst(llm, toolkit):
|
||||
def news_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [toolkit.get_global_news, toolkit.get_google_news]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_finnhub_news,
|
||||
toolkit.get_reddit_news,
|
||||
toolkit.get_google_news,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"news_report": report,
|
||||
}
|
||||
|
||||
return news_analyst_node
|
||||
|
|
|
|||
|
|
@ -1,60 +1,60 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_social_media_analyst(llm, toolkit):
|
||||
def social_media_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [toolkit.get_stock_news]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_reddit_stock_info,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"sentiment_report": report,
|
||||
}
|
||||
|
||||
return social_media_analyst_node
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_social_media_analyst(llm, toolkit):
|
||||
def social_media_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [toolkit.get_stock_news]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_reddit_stock_info,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"sentiment_report": report,
|
||||
}
|
||||
|
||||
return social_media_analyst_node
|
||||
|
|
|
|||
|
|
@ -1,57 +1,57 @@
|
|||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_research_manager(llm, memory):
|
||||
def research_manager_node(state) -> dict:
|
||||
history = state["investment_debate_state"].get("history", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
||||
|
||||
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
|
||||
|
||||
Additionally, develop a detailed investment plan for the trader. This should include:
|
||||
|
||||
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
||||
Rationale: An explanation of why these arguments lead to your conclusion.
|
||||
Strategic Actions: Concrete steps for implementing the recommendation.
|
||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
||||
|
||||
Here are your past reflections on mistakes:
|
||||
\"{past_memory_str}\"
|
||||
|
||||
Here is the debate:
|
||||
Debate History:
|
||||
{history}"""
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
new_investment_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": investment_debate_state.get("history", ""),
|
||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||
"current_response": response.content,
|
||||
"count": investment_debate_state["count"],
|
||||
}
|
||||
|
||||
return {
|
||||
"investment_debate_state": new_investment_debate_state,
|
||||
"investment_plan": response.content,
|
||||
}
|
||||
|
||||
return research_manager_node
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_research_manager(llm, memory):
|
||||
def research_manager_node(state) -> dict:
|
||||
history = state["investment_debate_state"].get("history", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
||||
|
||||
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
|
||||
|
||||
Additionally, develop a detailed investment plan for the trader. This should include:
|
||||
|
||||
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
||||
Rationale: An explanation of why these arguments lead to your conclusion.
|
||||
Strategic Actions: Concrete steps for implementing the recommendation.
|
||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
||||
|
||||
Here are your past reflections on mistakes:
|
||||
\"{past_memory_str}\"
|
||||
|
||||
Here is the debate:
|
||||
Debate History:
|
||||
{history}"""
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
new_investment_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": investment_debate_state.get("history", ""),
|
||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||
"current_response": response.content,
|
||||
"count": investment_debate_state["count"],
|
||||
}
|
||||
|
||||
return {
|
||||
"investment_debate_state": new_investment_debate_state,
|
||||
"investment_plan": response.content,
|
||||
}
|
||||
|
||||
return research_manager_node
|
||||
|
|
|
|||
|
|
@ -1,68 +1,68 @@
|
|||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_risk_manager(llm, memory):
|
||||
def risk_manager_node(state) -> dict:
|
||||
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
history = state["risk_debate_state"]["history"]
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
market_research_report = state["market_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["news_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
trader_plan = state["investment_plan"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
|
||||
|
||||
Guidelines for Decision-Making:
|
||||
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
|
||||
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
|
||||
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
|
||||
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
|
||||
|
||||
Deliverables:
|
||||
- A clear and actionable recommendation: Buy, Sell, or Hold.
|
||||
- Detailed reasoning anchored in the debate and past reflections.
|
||||
|
||||
---
|
||||
|
||||
**Analysts Debate History:**
|
||||
{history}
|
||||
|
||||
---
|
||||
|
||||
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
new_risk_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": risk_debate_state["history"],
|
||||
"risky_history": risk_debate_state["risky_history"],
|
||||
"safe_history": risk_debate_state["safe_history"],
|
||||
"neutral_history": risk_debate_state["neutral_history"],
|
||||
"latest_speaker": "Judge",
|
||||
"current_risky_response": risk_debate_state["current_risky_response"],
|
||||
"current_safe_response": risk_debate_state["current_safe_response"],
|
||||
"current_neutral_response": risk_debate_state["current_neutral_response"],
|
||||
"count": risk_debate_state["count"],
|
||||
}
|
||||
|
||||
return {
|
||||
"risk_debate_state": new_risk_debate_state,
|
||||
"final_trade_decision": response.content,
|
||||
}
|
||||
|
||||
return risk_manager_node
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_risk_manager(llm, memory):
|
||||
def risk_manager_node(state) -> dict:
|
||||
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
history = state["risk_debate_state"]["history"]
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
market_research_report = state["market_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["news_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
trader_plan = state["investment_plan"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
|
||||
|
||||
Guidelines for Decision-Making:
|
||||
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
|
||||
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
|
||||
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
|
||||
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
|
||||
|
||||
Deliverables:
|
||||
- A clear and actionable recommendation: Buy, Sell, or Hold.
|
||||
- Detailed reasoning anchored in the debate and past reflections.
|
||||
|
||||
---
|
||||
|
||||
**Analysts Debate History:**
|
||||
{history}
|
||||
|
||||
---
|
||||
|
||||
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
new_risk_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": risk_debate_state["history"],
|
||||
"risky_history": risk_debate_state["risky_history"],
|
||||
"safe_history": risk_debate_state["safe_history"],
|
||||
"neutral_history": risk_debate_state["neutral_history"],
|
||||
"latest_speaker": "Judge",
|
||||
"current_risky_response": risk_debate_state["current_risky_response"],
|
||||
"current_safe_response": risk_debate_state["current_safe_response"],
|
||||
"current_neutral_response": risk_debate_state["current_neutral_response"],
|
||||
"count": risk_debate_state["count"],
|
||||
}
|
||||
|
||||
return {
|
||||
"risk_debate_state": new_risk_debate_state,
|
||||
"final_trade_decision": response.content,
|
||||
}
|
||||
|
||||
return risk_manager_node
|
||||
|
|
|
|||
|
|
@ -1,63 +1,63 @@
|
|||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_bear_researcher(llm, memory):
|
||||
def bear_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
history = investment_debate_state.get("history", "")
|
||||
bear_history = investment_debate_state.get("bear_history", "")
|
||||
|
||||
current_response = investment_debate_state.get("current_response", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
||||
|
||||
Key points to focus on:
|
||||
|
||||
- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance.
|
||||
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors.
|
||||
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
|
||||
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions.
|
||||
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts.
|
||||
|
||||
Resources available:
|
||||
|
||||
Market research report: {market_research_report}
|
||||
Social media sentiment report: {sentiment_report}
|
||||
Latest world affairs news: {news_report}
|
||||
Company fundamentals report: {fundamentals_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bull argument: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Bear Analyst: {response.content}"
|
||||
|
||||
new_investment_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"bear_history": bear_history + "\n" + argument,
|
||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||
"current_response": argument,
|
||||
"count": investment_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"investment_debate_state": new_investment_debate_state}
|
||||
|
||||
return bear_node
|
||||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_bear_researcher(llm, memory):
|
||||
def bear_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
history = investment_debate_state.get("history", "")
|
||||
bear_history = investment_debate_state.get("bear_history", "")
|
||||
|
||||
current_response = investment_debate_state.get("current_response", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
||||
|
||||
Key points to focus on:
|
||||
|
||||
- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance.
|
||||
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors.
|
||||
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
|
||||
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions.
|
||||
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts.
|
||||
|
||||
Resources available:
|
||||
|
||||
Market research report: {market_research_report}
|
||||
Social media sentiment report: {sentiment_report}
|
||||
Latest world affairs news: {news_report}
|
||||
Company fundamentals report: {fundamentals_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bull argument: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Bear Analyst: {response.content}"
|
||||
|
||||
new_investment_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"bear_history": bear_history + "\n" + argument,
|
||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||
"current_response": argument,
|
||||
"count": investment_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"investment_debate_state": new_investment_debate_state}
|
||||
|
||||
return bear_node
|
||||
|
|
|
|||
|
|
@ -1,61 +1,61 @@
|
|||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_bull_researcher(llm, memory):
|
||||
def bull_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
history = investment_debate_state.get("history", "")
|
||||
bull_history = investment_debate_state.get("bull_history", "")
|
||||
|
||||
current_response = investment_debate_state.get("current_response", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
||||
|
||||
Key points to focus on:
|
||||
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
|
||||
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
|
||||
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
|
||||
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit.
|
||||
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
|
||||
|
||||
Resources available:
|
||||
Market research report: {market_research_report}
|
||||
Social media sentiment report: {sentiment_report}
|
||||
Latest world affairs news: {news_report}
|
||||
Company fundamentals report: {fundamentals_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bear argument: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Bull Analyst: {response.content}"
|
||||
|
||||
new_investment_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"bull_history": bull_history + "\n" + argument,
|
||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||
"current_response": argument,
|
||||
"count": investment_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"investment_debate_state": new_investment_debate_state}
|
||||
|
||||
return bull_node
|
||||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_bull_researcher(llm, memory):
|
||||
def bull_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
history = investment_debate_state.get("history", "")
|
||||
bull_history = investment_debate_state.get("bull_history", "")
|
||||
|
||||
current_response = investment_debate_state.get("current_response", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
||||
|
||||
Key points to focus on:
|
||||
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
|
||||
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
|
||||
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
|
||||
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit.
|
||||
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
|
||||
|
||||
Resources available:
|
||||
Market research report: {market_research_report}
|
||||
Social media sentiment report: {sentiment_report}
|
||||
Latest world affairs news: {news_report}
|
||||
Company fundamentals report: {fundamentals_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bear argument: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Bull Analyst: {response.content}"
|
||||
|
||||
new_investment_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"bull_history": bull_history + "\n" + argument,
|
||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||
"current_response": argument,
|
||||
"count": investment_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"investment_debate_state": new_investment_debate_state}
|
||||
|
||||
return bull_node
|
||||
|
|
|
|||
|
|
@ -1,57 +1,57 @@
|
|||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_risky_debator(llm):
|
||||
def risky_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
risky_history = risk_debate_state.get("risky_history", "")
|
||||
|
||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Risky Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risky_history + "\n" + argument,
|
||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
"latest_speaker": "Risky",
|
||||
"current_risky_response": argument,
|
||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", ""
|
||||
),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return risky_node
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_risky_debator(llm):
|
||||
def risky_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
risky_history = risk_debate_state.get("risky_history", "")
|
||||
|
||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Risky Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risky_history + "\n" + argument,
|
||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
"latest_speaker": "Risky",
|
||||
"current_risky_response": argument,
|
||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", ""
|
||||
),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return risky_node
|
||||
|
|
|
|||
|
|
@ -1,60 +1,60 @@
|
|||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_safe_debator(llm):
|
||||
def safe_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
safe_history = risk_debate_state.get("safe_history", "")
|
||||
|
||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Safe Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
||||
"safe_history": safe_history + "\n" + argument,
|
||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
"latest_speaker": "Safe",
|
||||
"current_risky_response": risk_debate_state.get(
|
||||
"current_risky_response", ""
|
||||
),
|
||||
"current_safe_response": argument,
|
||||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", ""
|
||||
),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return safe_node
|
||||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_safe_debator(llm):
|
||||
def safe_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
safe_history = risk_debate_state.get("safe_history", "")
|
||||
|
||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Safe Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
||||
"safe_history": safe_history + "\n" + argument,
|
||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
"latest_speaker": "Safe",
|
||||
"current_risky_response": risk_debate_state.get(
|
||||
"current_risky_response", ""
|
||||
),
|
||||
"current_safe_response": argument,
|
||||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", ""
|
||||
),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return safe_node
|
||||
|
|
|
|||
|
|
@ -1,57 +1,57 @@
|
|||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_neutral_debator(llm):
|
||||
def neutral_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
neutral_history = risk_debate_state.get("neutral_history", "")
|
||||
|
||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Neutral Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
||||
"neutral_history": neutral_history + "\n" + argument,
|
||||
"latest_speaker": "Neutral",
|
||||
"current_risky_response": risk_debate_state.get(
|
||||
"current_risky_response", ""
|
||||
),
|
||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||
"current_neutral_response": argument,
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return neutral_node
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_neutral_debator(llm):
|
||||
def neutral_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
neutral_history = risk_debate_state.get("neutral_history", "")
|
||||
|
||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Neutral Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
||||
"neutral_history": neutral_history + "\n" + argument,
|
||||
"latest_speaker": "Neutral",
|
||||
"current_risky_response": risk_debate_state.get(
|
||||
"current_risky_response", ""
|
||||
),
|
||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||
"current_neutral_response": argument,
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return neutral_node
|
||||
|
|
|
|||
|
|
@ -1,45 +1,45 @@
|
|||
import functools
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_trader(llm, memory):
|
||||
def trader_node(state, name):
|
||||
company_name = state["company_of_interest"]
|
||||
investment_plan = state["investment_plan"]
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
context = {
|
||||
"role": "user",
|
||||
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""",
|
||||
},
|
||||
context,
|
||||
]
|
||||
|
||||
result = llm.invoke(messages)
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"trader_investment_plan": result.content,
|
||||
"sender": name,
|
||||
}
|
||||
|
||||
return functools.partial(trader_node, name="Trader")
|
||||
import functools
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_trader(llm, memory):
|
||||
def trader_node(state, name):
|
||||
company_name = state["company_of_interest"]
|
||||
investment_plan = state["investment_plan"]
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
context = {
|
||||
"role": "user",
|
||||
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""",
|
||||
},
|
||||
context,
|
||||
]
|
||||
|
||||
result = llm.invoke(messages)
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"trader_investment_plan": result.content,
|
||||
"sender": name,
|
||||
}
|
||||
|
||||
return functools.partial(trader_node, name="Trader")
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
from .embedding_providers import (
|
||||
EmbeddingProvider,
|
||||
OpenAIEmbeddingProvider,
|
||||
GeminiEmbeddingProvider,
|
||||
OllamaEmbeddingProvider
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
class EmbeddingProviderFactory:
|
||||
@staticmethod
|
||||
def create_provider(config : dict[str, Any])->EmbeddingProvider:
|
||||
backend_url = config["backend_url"]
|
||||
|
||||
if "generativelanguage.googleapis.com" in backend_url:
|
||||
return GeminiEmbeddingProvider(backend_url)
|
||||
elif "localhost:11434" in backend_url:
|
||||
return OllamaEmbeddingProvider(backend_url)
|
||||
else:
|
||||
return OpenAIEmbeddingProvider(backend_url)
|
||||
from .embedding_providers import (
|
||||
EmbeddingProvider,
|
||||
OpenAIEmbeddingProvider,
|
||||
GeminiEmbeddingProvider,
|
||||
OllamaEmbeddingProvider
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
class EmbeddingProviderFactory:
|
||||
@staticmethod
|
||||
def create_provider(config : dict[str, Any])->EmbeddingProvider:
|
||||
backend_url = config["backend_url"]
|
||||
|
||||
if "generativelanguage.googleapis.com" in backend_url:
|
||||
return GeminiEmbeddingProvider(backend_url)
|
||||
elif "localhost:11434" in backend_url:
|
||||
return OllamaEmbeddingProvider(backend_url)
|
||||
else:
|
||||
return OpenAIEmbeddingProvider(backend_url)
|
||||
|
||||
|
|
@ -1,66 +1,66 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from openai import OpenAI
|
||||
from google import genai
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
@abstractmethod
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_name(self)->str:
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.embeddings.create(
|
||||
model=self._embedding_model,
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
|
||||
|
||||
class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
|
||||
self.client = genai.Client()
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.models.embed_content(
|
||||
model=self._embedding_model,
|
||||
contents=text
|
||||
)
|
||||
return response.embeddings[0].values
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
|
||||
class OllamaEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.embeddings.create(
|
||||
model=self._embedding_model,
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
from abc import ABC, abstractmethod
|
||||
from openai import OpenAI
|
||||
from google import genai
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
@abstractmethod
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_name(self)->str:
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.embeddings.create(
|
||||
model=self._embedding_model,
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
|
||||
|
||||
class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
|
||||
self.client = genai.Client()
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.models.embed_content(
|
||||
model=self._embedding_model,
|
||||
contents=text
|
||||
)
|
||||
return response.embeddings[0].values
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
|
||||
class OllamaEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.embeddings.create(
|
||||
model=self._embedding_model,
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
|
||||
|
|
@ -1,112 +1,112 @@
|
|||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
import os
|
||||
from .embedding_provider_factory import EmbeddingProviderFactory
|
||||
from google import genai
|
||||
|
||||
class FinancialSituationMemory:
|
||||
def __init__(self, name, config):
|
||||
self.config = config
|
||||
self.backend_url = config["backend_url"]
|
||||
|
||||
self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
|
||||
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
|
||||
def get_embedding(self, text):
|
||||
"""Get embedding for a text using the appropriate API"""
|
||||
|
||||
return self.embedding_provider.get_embedding(text)
|
||||
|
||||
def add_situations(self, situations_and_advice):
|
||||
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
||||
|
||||
situations = []
|
||||
advice = []
|
||||
ids = []
|
||||
embeddings = []
|
||||
|
||||
offset = self.situation_collection.count()
|
||||
|
||||
for i, (situation, recommendation) in enumerate(situations_and_advice):
|
||||
situations.append(situation)
|
||||
advice.append(recommendation)
|
||||
ids.append(str(offset + i))
|
||||
embeddings.append(self.get_embedding(situation))
|
||||
|
||||
self.situation_collection.add(
|
||||
documents=situations,
|
||||
metadatas=[{"recommendation": rec} for rec in advice],
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
)
|
||||
|
||||
def get_memories(self, current_situation, n_matches=1):
|
||||
"""Find matching recommendations using embeddings"""
|
||||
query_embedding = self.get_embedding(current_situation)
|
||||
|
||||
results = self.situation_collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=n_matches,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
)
|
||||
|
||||
matched_results = []
|
||||
for i in range(len(results["documents"][0])):
|
||||
matched_results.append(
|
||||
{
|
||||
"matched_situation": results["documents"][0][i],
|
||||
"recommendation": results["metadatas"][0][i]["recommendation"],
|
||||
"similarity_score": 1 - results["distances"][0][i],
|
||||
}
|
||||
)
|
||||
|
||||
return matched_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
matcher = FinancialSituationMemory()
|
||||
|
||||
# Example data
|
||||
example_data = [
|
||||
(
|
||||
"High inflation rate with rising interest rates and declining consumer spending",
|
||||
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
||||
),
|
||||
(
|
||||
"Tech sector showing high volatility with increasing institutional selling pressure",
|
||||
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
||||
),
|
||||
(
|
||||
"Strong dollar affecting emerging markets with increasing forex volatility",
|
||||
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
||||
),
|
||||
(
|
||||
"Market showing signs of sector rotation with rising yields",
|
||||
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
||||
),
|
||||
]
|
||||
|
||||
# Add the example situations and recommendations
|
||||
matcher.add_situations(example_data)
|
||||
|
||||
# Example query
|
||||
current_situation = """
|
||||
Market showing increased volatility in tech sector, with institutional investors
|
||||
reducing positions and rising interest rates affecting growth stock valuations
|
||||
"""
|
||||
|
||||
try:
|
||||
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
||||
|
||||
for i, rec in enumerate(recommendations, 1):
|
||||
print(f"\nMatch {i}:")
|
||||
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
||||
print(f"Matched Situation: {rec['matched_situation']}")
|
||||
print(f"Recommendation: {rec['recommendation']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during recommendation: {str(e)}")
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
import os
|
||||
from .embedding_provider_factory import EmbeddingProviderFactory
|
||||
from google import genai
|
||||
|
||||
class FinancialSituationMemory:
|
||||
def __init__(self, name, config):
|
||||
self.config = config
|
||||
self.backend_url = config["backend_url"]
|
||||
|
||||
self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
|
||||
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
|
||||
def get_embedding(self, text):
|
||||
"""Get embedding for a text using the appropriate API"""
|
||||
|
||||
return self.embedding_provider.get_embedding(text)
|
||||
|
||||
def add_situations(self, situations_and_advice):
|
||||
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
||||
|
||||
situations = []
|
||||
advice = []
|
||||
ids = []
|
||||
embeddings = []
|
||||
|
||||
offset = self.situation_collection.count()
|
||||
|
||||
for i, (situation, recommendation) in enumerate(situations_and_advice):
|
||||
situations.append(situation)
|
||||
advice.append(recommendation)
|
||||
ids.append(str(offset + i))
|
||||
embeddings.append(self.get_embedding(situation))
|
||||
|
||||
self.situation_collection.add(
|
||||
documents=situations,
|
||||
metadatas=[{"recommendation": rec} for rec in advice],
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
)
|
||||
|
||||
def get_memories(self, current_situation, n_matches=1):
|
||||
"""Find matching recommendations using embeddings"""
|
||||
query_embedding = self.get_embedding(current_situation)
|
||||
|
||||
results = self.situation_collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=n_matches,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
)
|
||||
|
||||
matched_results = []
|
||||
for i in range(len(results["documents"][0])):
|
||||
matched_results.append(
|
||||
{
|
||||
"matched_situation": results["documents"][0][i],
|
||||
"recommendation": results["metadatas"][0][i]["recommendation"],
|
||||
"similarity_score": 1 - results["distances"][0][i],
|
||||
}
|
||||
)
|
||||
|
||||
return matched_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
matcher = FinancialSituationMemory()
|
||||
|
||||
# Example data
|
||||
example_data = [
|
||||
(
|
||||
"High inflation rate with rising interest rates and declining consumer spending",
|
||||
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
||||
),
|
||||
(
|
||||
"Tech sector showing high volatility with increasing institutional selling pressure",
|
||||
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
||||
),
|
||||
(
|
||||
"Strong dollar affecting emerging markets with increasing forex volatility",
|
||||
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
||||
),
|
||||
(
|
||||
"Market showing signs of sector rotation with rising yields",
|
||||
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
||||
),
|
||||
]
|
||||
|
||||
# Add the example situations and recommendations
|
||||
matcher.add_situations(example_data)
|
||||
|
||||
# Example query
|
||||
current_situation = """
|
||||
Market showing increased volatility in tech sector, with institutional investors
|
||||
reducing positions and rising interest rates affecting growth stock valuations
|
||||
"""
|
||||
|
||||
try:
|
||||
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
||||
|
||||
for i, rec in enumerate(recommendations, 1):
|
||||
print(f"\nMatch {i}:")
|
||||
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
||||
print(f"Matched Situation: {rec['matched_situation']}")
|
||||
print(f"Recommendation: {rec['recommendation']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during recommendation: {str(e)}")
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,76 +1,76 @@
|
|||
from google import genai
|
||||
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
||||
from openai import OpenAI
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
|
||||
class SearchProvider(ABC):
|
||||
@abstractmethod
|
||||
def search(self, query: str, ticker: str, curr_date: str) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class GoogleSearchProvider(SearchProvider):
|
||||
def __init__(self, model: str):
|
||||
self.client = genai.Client()
|
||||
self.model = model
|
||||
|
||||
def search(self, query: str) -> str:
|
||||
google_search_tool = Tool(
|
||||
google_search=GoogleSearch()
|
||||
)
|
||||
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=query,
|
||||
config=GenerateContentConfig(
|
||||
tools=[google_search_tool],
|
||||
response_modalities=["TEXT"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
result_text = ""
|
||||
for part in response.candidates[0].content.parts:
|
||||
if hasattr(part, 'text'):
|
||||
result_text += part.text
|
||||
|
||||
return result_text
|
||||
|
||||
|
||||
class OpenAISearchProvider(SearchProvider):
|
||||
def __init__(self, model: str, backend_url: str):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self.model = model
|
||||
|
||||
def search(self, query: str) -> str:
|
||||
response = self.client.responses.create(
|
||||
model=self.model,
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": query
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
from google import genai
|
||||
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
||||
from openai import OpenAI
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
|
||||
class SearchProvider(ABC):
|
||||
@abstractmethod
|
||||
def search(self, query: str, ticker: str, curr_date: str) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class GoogleSearchProvider(SearchProvider):
|
||||
def __init__(self, model: str):
|
||||
self.client = genai.Client()
|
||||
self.model = model
|
||||
|
||||
def search(self, query: str) -> str:
|
||||
google_search_tool = Tool(
|
||||
google_search=GoogleSearch()
|
||||
)
|
||||
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=query,
|
||||
config=GenerateContentConfig(
|
||||
tools=[google_search_tool],
|
||||
response_modalities=["TEXT"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
result_text = ""
|
||||
for part in response.candidates[0].content.parts:
|
||||
if hasattr(part, 'text'):
|
||||
result_text += part.text
|
||||
|
||||
return result_text
|
||||
|
||||
|
||||
class OpenAISearchProvider(SearchProvider):
|
||||
def __init__(self, model: str, backend_url: str):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self.model = model
|
||||
|
||||
def search(self, query: str) -> str:
|
||||
response = self.client.responses.create(
|
||||
model=self.model,
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": query
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
|
|
@ -1,47 +1,47 @@
|
|||
from .search_provider import (
|
||||
SearchProvider,
|
||||
GoogleSearchProvider,
|
||||
OpenAISearchProvider
|
||||
)
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
|
||||
class SearchProviderFactory:
|
||||
_cache = {} # 클래스 레벨 캐시
|
||||
|
||||
@staticmethod
|
||||
def create_provider(config: dict[str, any]) -> SearchProvider:
|
||||
"""
|
||||
Create a SearchProvider with caching to avoid creating new instances.
|
||||
Uses config hash as cache key for efficient reuse.
|
||||
"""
|
||||
# Create cache key from relevant config values
|
||||
cache_key_data = {
|
||||
"backend_url": config["backend_url"],
|
||||
"model": config["quick_think_llm"]
|
||||
}
|
||||
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
# Return cached instance if exists
|
||||
if cache_key in SearchProviderFactory._cache:
|
||||
return SearchProviderFactory._cache[cache_key]
|
||||
|
||||
# Create new instance
|
||||
backend_url = config["backend_url"]
|
||||
model = config["quick_think_llm"]
|
||||
|
||||
if "generativelanguage.googleapis.com" in backend_url:
|
||||
provider = GoogleSearchProvider(model)
|
||||
else:
|
||||
provider = OpenAISearchProvider(model, backend_url)
|
||||
|
||||
# Cache and return
|
||||
SearchProviderFactory._cache[cache_key] = provider
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def clear_cache():
|
||||
"""Clear the provider cache (useful for testing or config changes)."""
|
||||
SearchProviderFactory._cache.clear()
|
||||
|
||||
from .search_provider import (
|
||||
SearchProvider,
|
||||
GoogleSearchProvider,
|
||||
OpenAISearchProvider
|
||||
)
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
|
||||
class SearchProviderFactory:
|
||||
_cache = {} # 클래스 레벨 캐시
|
||||
|
||||
@staticmethod
|
||||
def create_provider(config: dict[str, any]) -> SearchProvider:
|
||||
"""
|
||||
Create a SearchProvider with caching to avoid creating new instances.
|
||||
Uses config hash as cache key for efficient reuse.
|
||||
"""
|
||||
# Create cache key from relevant config values
|
||||
cache_key_data = {
|
||||
"backend_url": config["backend_url"],
|
||||
"model": config["quick_think_llm"]
|
||||
}
|
||||
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
# Return cached instance if exists
|
||||
if cache_key in SearchProviderFactory._cache:
|
||||
return SearchProviderFactory._cache[cache_key]
|
||||
|
||||
# Create new instance
|
||||
backend_url = config["backend_url"]
|
||||
model = config["quick_think_llm"]
|
||||
|
||||
if "generativelanguage.googleapis.com" in backend_url:
|
||||
provider = GoogleSearchProvider(model)
|
||||
else:
|
||||
provider = OpenAISearchProvider(model, backend_url)
|
||||
|
||||
# Cache and return
|
||||
SearchProviderFactory._cache[cache_key] = provider
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def clear_cache():
|
||||
"""Clear the provider cache (useful for testing or config changes)."""
|
||||
SearchProviderFactory._cache.clear()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,31 +1,31 @@
|
|||
# TradingAgents/graph/signal_processing.py
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class SignalProcessor:
|
||||
"""Processes trading signals to extract actionable decisions."""
|
||||
|
||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
||||
"""Initialize with an LLM for processing."""
|
||||
self.quick_thinking_llm = quick_thinking_llm
|
||||
|
||||
def process_signal(self, full_signal: str) -> str:
|
||||
"""
|
||||
Process a full trading signal to extract the core decision.
|
||||
|
||||
Args:
|
||||
full_signal: Complete trading signal text
|
||||
|
||||
Returns:
|
||||
Extracted decision (BUY, SELL, or HOLD)
|
||||
"""
|
||||
messages = [
|
||||
(
|
||||
"system",
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.",
|
||||
),
|
||||
("human", full_signal),
|
||||
]
|
||||
|
||||
return self.quick_thinking_llm.invoke(messages).content
|
||||
# TradingAgents/graph/signal_processing.py
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class SignalProcessor:
|
||||
"""Processes trading signals to extract actionable decisions."""
|
||||
|
||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
||||
"""Initialize with an LLM for processing."""
|
||||
self.quick_thinking_llm = quick_thinking_llm
|
||||
|
||||
def process_signal(self, full_signal: str) -> str:
|
||||
"""
|
||||
Process a full trading signal to extract the core decision.
|
||||
|
||||
Args:
|
||||
full_signal: Complete trading signal text
|
||||
|
||||
Returns:
|
||||
Extracted decision (BUY, SELL, or HOLD)
|
||||
"""
|
||||
messages = [
|
||||
(
|
||||
"system",
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.",
|
||||
),
|
||||
("human", full_signal),
|
||||
]
|
||||
|
||||
return self.quick_thinking_llm.invoke(messages).content
|
||||
|
|
|
|||
Loading…
Reference in New Issue