[add] remove app
This commit is contained in:
parent
d35b62e999
commit
c3e609730b
|
|
@ -8,4 +8,5 @@ eval_data/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
results/
|
results/
|
||||||
.env
|
.env
|
||||||
tradingagents/dataflows/data_cache/
|
tradingagents/dataflows/data_cache/
|
||||||
|
CLAUDE.md
|
||||||
|
|
|
||||||
|
|
@ -1,67 +0,0 @@
|
||||||
from typing import Generator, Optional
|
|
||||||
from fastapi import Depends, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from jose import jwt, JWTError
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlmodel import Session
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.infrastructure.database import get_db
|
|
||||||
from app.domain.models import User
|
|
||||||
from app.infrastructure.repositories.user import UserRepository
|
|
||||||
from app.core.services.trading_analysis import TradingAnalysisService
|
|
||||||
|
|
||||||
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
|
|
||||||
|
|
||||||
class TokenData(BaseModel):
|
|
||||||
username: Optional[str] = None
|
|
||||||
|
|
||||||
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
|
|
||||||
return UserRepository(db)
|
|
||||||
|
|
||||||
def get_user_from_token(token: str, db: Session) -> Optional[User]:
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
|
||||||
)
|
|
||||||
token_data = TokenData(username=payload.get("sub"))
|
|
||||||
except JWTError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
user_repo = UserRepository(db)
|
|
||||||
user = user_repo.get_by_email(email=token_data.username)
|
|
||||||
return user
|
|
||||||
|
|
||||||
def get_current_user(
|
|
||||||
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
|
|
||||||
) -> User:
|
|
||||||
user = get_user_from_token(token=token, db=db)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Could not validate credentials",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
return user
|
|
||||||
|
|
||||||
def get_current_active_user(
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
) -> User:
|
|
||||||
if not current_user.is_active:
|
|
||||||
raise HTTPException(status_code=400, detail="Inactive user")
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
def get_current_active_superuser(
|
|
||||||
current_user: User = Depends(get_current_active_user),
|
|
||||||
) -> User:
|
|
||||||
if not current_user.is_superuser:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="The user doesn't have enough privileges"
|
|
||||||
)
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
def get_analysis_service(
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
user: User = Depends(get_current_active_user)
|
|
||||||
) -> TradingAnalysisService:
|
|
||||||
return TradingAnalysisService(user=user, db=db)
|
|
||||||
|
|
@ -1,94 +0,0 @@
|
||||||
from typing import Any, List
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
|
|
||||||
from app.api import deps
|
|
||||||
from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate
|
|
||||||
from app.domain.models import User as UserModel
|
|
||||||
from app.core.services.trading_analysis import TradingAnalysisService
|
|
||||||
from app.core.websocket_manager import WebSocketManager
|
|
||||||
from sqlmodel import Session
|
|
||||||
from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
manager = WebSocketManager()
|
|
||||||
|
|
||||||
@router.post("/start", response_model=AnalysisSession)
|
|
||||||
def start_analysis(
|
|
||||||
*,
|
|
||||||
analysis_in: AnalysisSessionCreate,
|
|
||||||
background_tasks: BackgroundTasks,
|
|
||||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Start a new analysis session.
|
|
||||||
"""
|
|
||||||
session = service.create_session(analysis_in=analysis_in)
|
|
||||||
background_tasks.add_task(service.run_analysis, session_id=session.id)
|
|
||||||
return session
|
|
||||||
|
|
||||||
@router.get("/history", response_model=List[AnalysisSession])
|
|
||||||
def get_analysis_history(
|
|
||||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Get analysis history for the current user.
|
|
||||||
"""
|
|
||||||
return service.get_user_sessions(skip=skip, limit=limit)
|
|
||||||
|
|
||||||
@router.get("/options")
|
|
||||||
def get_analysis_options():
|
|
||||||
"""
|
|
||||||
Get available options for analysis.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
'analysts': [
|
|
||||||
{'value': 'market', 'label': 'Market Analyst'},
|
|
||||||
{'value': 'social', 'label': 'Social Analyst'},
|
|
||||||
{'value': 'news', 'label': 'News Analyst'},
|
|
||||||
{'value': 'fundamentals', 'label': 'Fundamentals Analyst'},
|
|
||||||
],
|
|
||||||
'research_depths': [
|
|
||||||
{'value': 1, 'label': 'Shallow'},
|
|
||||||
{'value': 3, 'label': 'Medium'},
|
|
||||||
{'value': 5, 'label': 'Deep'},
|
|
||||||
],
|
|
||||||
'llm_providers': [{'name': p[0], 'url': p[1]} for p in BASE_URLS],
|
|
||||||
'shallow_thinkers': SHALLOW_AGENT_OPTIONS,
|
|
||||||
'deep_thinkers': DEEP_AGENT_OPTIONS,
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.get("/{session_id}", response_model=AnalysisSession)
|
|
||||||
def get_analysis_session(
|
|
||||||
session_id: int,
|
|
||||||
service: TradingAnalysisService = Depends(deps.get_analysis_service),
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Get a specific analysis session by ID.
|
|
||||||
"""
|
|
||||||
session = service.get_session(session_id=session_id)
|
|
||||||
if not session:
|
|
||||||
raise HTTPException(status_code=404, detail="Analysis session not found")
|
|
||||||
return session
|
|
||||||
|
|
||||||
@router.websocket("/ws")
|
|
||||||
async def websocket_endpoint(
|
|
||||||
websocket: WebSocket,
|
|
||||||
token: str,
|
|
||||||
db: Session = Depends(deps.get_db)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
WebSocket endpoint for real-time analysis updates.
|
|
||||||
"""
|
|
||||||
user = deps.get_user_from_token(token=token, db=db)
|
|
||||||
if not user or not user.is_active:
|
|
||||||
await websocket.close(code=1008)
|
|
||||||
return
|
|
||||||
|
|
||||||
await manager.connect(user.id, websocket)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
# Keep the connection alive
|
|
||||||
await websocket.receive_text()
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
manager.disconnect(user.id, websocket)
|
|
||||||
|
|
@ -1,35 +0,0 @@
|
||||||
from datetime import timedelta
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
|
||||||
from sqlmodel import Session
|
|
||||||
|
|
||||||
from app.api import deps
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.schemas.token import Token
|
|
||||||
from app.core import security
|
|
||||||
from app.infrastructure.repositories.user import UserRepository
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
@router.post("/login/access-token", response_model=Token)
|
|
||||||
def login_access_token(
|
|
||||||
db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
OAuth2 compatible token login, get an access token for future requests
|
|
||||||
"""
|
|
||||||
user_repo = UserRepository(db)
|
|
||||||
user = user_repo.get_by_email(email=form_data.username)
|
|
||||||
|
|
||||||
if not user or not security.verify_password(form_data.password, user.hashed_password):
|
|
||||||
raise HTTPException(status_code=400, detail="Incorrect email or password")
|
|
||||||
elif not user.is_active:
|
|
||||||
raise HTTPException(status_code=400, detail="Inactive user")
|
|
||||||
|
|
||||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
||||||
return {
|
|
||||||
"access_token": security.create_access_token(
|
|
||||||
user.email, expires_delta=access_token_expires
|
|
||||||
),
|
|
||||||
"token_type": "bearer",
|
|
||||||
}
|
|
||||||
|
|
@ -1,89 +0,0 @@
|
||||||
from typing import Any, List
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from app.api import deps
|
|
||||||
from app.core.schemas.user import User, UserCreate, UserUpdate
|
|
||||||
from app.domain.models import User as UserModel
|
|
||||||
from app.domain.repositories import IUserRepository
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[User])
|
|
||||||
def read_users(
|
|
||||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Retrieve users.
|
|
||||||
"""
|
|
||||||
users = repo.get_multi(skip=skip, limit=limit)
|
|
||||||
return users
|
|
||||||
|
|
||||||
@router.post("/", response_model=User)
|
|
||||||
def create_user(
|
|
||||||
*,
|
|
||||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
|
||||||
user_in: UserCreate,
|
|
||||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Create new user.
|
|
||||||
"""
|
|
||||||
user = repo.get_by_email(email=user_in.email)
|
|
||||||
if user:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="The user with this username already exists in the system.",
|
|
||||||
)
|
|
||||||
user = repo.create(obj_in=user_in)
|
|
||||||
return user
|
|
||||||
|
|
||||||
@router.get("/me", response_model=User)
|
|
||||||
def read_user_me(
|
|
||||||
current_user: UserModel = Depends(deps.get_current_active_user),
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Get current user.
|
|
||||||
"""
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=User)
|
|
||||||
def read_user_by_id(
|
|
||||||
user_id: int,
|
|
||||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
|
||||||
current_user: UserModel = Depends(deps.get_current_active_user),
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Get a specific user by id.
|
|
||||||
"""
|
|
||||||
user = repo.get(id=user_id)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
|
||||||
if user == current_user:
|
|
||||||
return user
|
|
||||||
if not repo.is_superuser(user=current_user):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="The user doesn't have enough privileges"
|
|
||||||
)
|
|
||||||
return user
|
|
||||||
|
|
||||||
@router.put("/{user_id}", response_model=User)
|
|
||||||
def update_user(
|
|
||||||
*,
|
|
||||||
repo: IUserRepository = Depends(deps.get_user_repository),
|
|
||||||
user_id: int,
|
|
||||||
user_in: UserUpdate,
|
|
||||||
current_user: UserModel = Depends(deps.get_current_active_superuser),
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Update a user.
|
|
||||||
"""
|
|
||||||
user = repo.get(id=user_id)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail="The user with this username does not exist in the system",
|
|
||||||
)
|
|
||||||
user = repo.update(db_obj=user, obj_in=user_in)
|
|
||||||
return user
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
from fastapi import APIRouter
|
|
||||||
from app.api.endpoints import login, users, analysis
|
|
||||||
|
|
||||||
api_router = APIRouter()
|
|
||||||
api_router.include_router(login.router, tags=["login"])
|
|
||||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
|
||||||
api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"])
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
import os
|
|
||||||
from pydantic import BaseSettings
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
PROJECT_NAME: str = "TradingAgents Backend"
|
|
||||||
API_V1_STR: str = "/api/v1"
|
|
||||||
|
|
||||||
# Security
|
|
||||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
|
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
|
|
||||||
ALGORITHM: str = "HS256"
|
|
||||||
|
|
||||||
# Database
|
|
||||||
DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db")
|
|
||||||
|
|
||||||
# OpenAI
|
|
||||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
|
||||||
|
|
||||||
# CORS
|
|
||||||
CORS_ALLOWED_ORIGINS: List[str] = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(',')
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
case_sensitive = True
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
|
|
@ -1,4 +0,0 @@
|
||||||
from .user import User, UserCreate, UserUpdate
|
|
||||||
from .token import Token, TokenPayload
|
|
||||||
from .profile import Profile, ProfileCreate, ProfileUpdate
|
|
||||||
from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import List, Optional
|
|
||||||
from datetime import date, datetime
|
|
||||||
from app.domain.models import AnalysisStatus
|
|
||||||
|
|
||||||
class AnalysisSessionBase(BaseModel):
|
|
||||||
ticker: str
|
|
||||||
analysts_selected: List[str]
|
|
||||||
research_depth: int
|
|
||||||
llm_provider: str
|
|
||||||
backend_url: str
|
|
||||||
shallow_thinker: str
|
|
||||||
deep_thinker: str
|
|
||||||
|
|
||||||
class AnalysisSessionCreate(AnalysisSessionBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class AnalysisSessionUpdate(BaseModel):
|
|
||||||
status: Optional[AnalysisStatus] = None
|
|
||||||
final_report: Optional[str] = None
|
|
||||||
error_message: Optional[str] = None
|
|
||||||
|
|
||||||
class AnalysisSessionInDBBase(AnalysisSessionBase):
|
|
||||||
id: int
|
|
||||||
user_id: int
|
|
||||||
analysis_date: date
|
|
||||||
status: AnalysisStatus
|
|
||||||
final_report: Optional[str] = None
|
|
||||||
error_message: Optional[str] = None
|
|
||||||
created_at: datetime
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
orm_mode = True
|
|
||||||
|
|
||||||
class AnalysisSession(AnalysisSessionInDBBase):
|
|
||||||
pass
|
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
class ProfileBase(BaseModel):
|
|
||||||
default_ticker: str = "SPY"
|
|
||||||
preferred_research_depth: int = 3
|
|
||||||
preferred_shallow_thinker: str = "gpt-4o-mini"
|
|
||||||
preferred_deep_thinker: str = "gpt-4o"
|
|
||||||
|
|
||||||
class ProfileCreate(ProfileBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class ProfileUpdate(ProfileBase):
|
|
||||||
openai_api_key: Optional[str] = None
|
|
||||||
|
|
||||||
class Profile(ProfileBase):
|
|
||||||
has_openai_api_key: bool
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
orm_mode = True
|
|
||||||
|
|
@ -1,9 +0,0 @@
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
|
||||||
access_token: str
|
|
||||||
token_type: str
|
|
||||||
|
|
||||||
class TokenPayload(BaseModel):
|
|
||||||
sub: Optional[int] = None
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
from pydantic import BaseModel, EmailStr
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
class UserBase(BaseModel):
|
|
||||||
email: EmailStr
|
|
||||||
username: str
|
|
||||||
first_name: Optional[str] = None
|
|
||||||
last_name: Optional[str] = None
|
|
||||||
|
|
||||||
class UserCreate(UserBase):
|
|
||||||
password: str
|
|
||||||
|
|
||||||
class UserUpdate(UserBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class UserInDBBase(UserBase):
|
|
||||||
id: int
|
|
||||||
is_active: bool
|
|
||||||
is_superuser: bool
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
orm_mode = True
|
|
||||||
|
|
||||||
class User(UserInDBBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class UserInDB(UserInDBBase):
|
|
||||||
hashed_password: str
|
|
||||||
|
|
@ -1,23 +0,0 @@
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Any, Union, Optional
|
|
||||||
from jose import jwt
|
|
||||||
from passlib.context import CryptContext
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
||||||
|
|
||||||
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
|
||||||
if expires_delta:
|
|
||||||
expire = datetime.utcnow() + expires_delta
|
|
||||||
else:
|
|
||||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
||||||
|
|
||||||
to_encode = {"exp": expire, "sub": str(subject)}
|
|
||||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
||||||
return encoded_jwt
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
|
||||||
return pwd_context.verify(plain_password, hashed_password)
|
|
||||||
|
|
||||||
def get_password_hash(password: str) -> str:
|
|
||||||
return pwd_context.hash(password)
|
|
||||||
|
|
@ -1,128 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
from app.domain.models import User, AnalysisSession, AnalysisStatus
|
|
||||||
from app.core.schemas.analysis import AnalysisSessionCreate
|
|
||||||
from app.core.config import settings
|
|
||||||
from cli.models import AnalystType
|
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
|
||||||
from app.api.deps import get_db
|
|
||||||
from app.core.websocket_manager import WebSocketManager
|
|
||||||
|
|
||||||
class TradingAnalysisService:
|
|
||||||
def __init__(self, user: User, db: Session):
|
|
||||||
self.user = user
|
|
||||||
self.db = db
|
|
||||||
self.websocket_manager = WebSocketManager()
|
|
||||||
|
|
||||||
async def run_analysis(self, session_id: int):
|
|
||||||
"""분석 실행"""
|
|
||||||
session = self.get_session(session_id=session_id)
|
|
||||||
if not session:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
session.status = AnalysisStatus.RUNNING
|
|
||||||
session.started_at = datetime.datetime.utcnow()
|
|
||||||
self.db.add(session)
|
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(session)
|
|
||||||
|
|
||||||
await self.websocket_manager.send_to_user(
|
|
||||||
self.user.id,
|
|
||||||
{
|
|
||||||
'type': 'analysis_started',
|
|
||||||
'session_id': session.id,
|
|
||||||
'message': '분석을 시작합니다...'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare config for TradingAgentsGraph
|
|
||||||
config = DEFAULT_CONFIG.copy()
|
|
||||||
config.update({
|
|
||||||
'openai_api_key': settings.OPENAI_API_KEY,
|
|
||||||
'llm_provider': session.llm_provider,
|
|
||||||
'backend_url': session.backend_url,
|
|
||||||
'shallow_thinking_model': session.shallow_thinker,
|
|
||||||
'deep_thinking_model': session.deep_thinker,
|
|
||||||
})
|
|
||||||
|
|
||||||
# Progress callback for websocket
|
|
||||||
async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0):
|
|
||||||
progress_percent = int((step / total) * 99) if total > 0 else 0
|
|
||||||
await self.websocket_manager.send_to_user(self.user.id, {
|
|
||||||
'type': 'analysis_progress',
|
|
||||||
'session_id': session.id,
|
|
||||||
'message_type': message_type,
|
|
||||||
'content': content,
|
|
||||||
'agent': agent,
|
|
||||||
'progress': progress_percent,
|
|
||||||
})
|
|
||||||
|
|
||||||
trading_graph = TradingAgentsGraph(
|
|
||||||
config=config,
|
|
||||||
selected_analysts=session.analysts_selected,
|
|
||||||
)
|
|
||||||
|
|
||||||
input_data = {
|
|
||||||
'company_of_interest': session.ticker,
|
|
||||||
'trade_date': session.analysis_date.strftime('%Y-%m-%d'),
|
|
||||||
}
|
|
||||||
|
|
||||||
final_state, result = await asyncio.to_thread(
|
|
||||||
trading_graph.propagate,
|
|
||||||
input_data['company_of_interest'],
|
|
||||||
input_data['trade_date']
|
|
||||||
)
|
|
||||||
|
|
||||||
session.status = AnalysisStatus.COMPLETED
|
|
||||||
session.completed_at = datetime.datetime.utcnow()
|
|
||||||
session.final_report = json.dumps(final_state) # Store full state as JSON
|
|
||||||
self.db.add(session)
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
await self.websocket_manager.send_to_user(
|
|
||||||
self.user.id,
|
|
||||||
{
|
|
||||||
'type': 'analysis_completed',
|
|
||||||
'session_id': session.id,
|
|
||||||
'message': '분석이 완료되었습니다.',
|
|
||||||
'result': result
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
session.status = AnalysisStatus.FAILED
|
|
||||||
session.error_message = str(e)
|
|
||||||
self.db.add(session)
|
|
||||||
self.db.commit()
|
|
||||||
await self.websocket_manager.send_to_user(
|
|
||||||
self.user.id,
|
|
||||||
{
|
|
||||||
'type': 'analysis_failed',
|
|
||||||
'session_id': session.id,
|
|
||||||
'message': f'분석 중 오류가 발생했습니다: {str(e)}'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
|
|
||||||
session = AnalysisSession(
|
|
||||||
**analysis_in.dict(),
|
|
||||||
user_id=self.user.id,
|
|
||||||
analysis_date=datetime.date.today()
|
|
||||||
)
|
|
||||||
self.db.add(session)
|
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(session)
|
|
||||||
return session
|
|
||||||
|
|
||||||
def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
|
|
||||||
statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id)
|
|
||||||
return self.db.exec(statement).first()
|
|
||||||
|
|
||||||
def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
|
|
||||||
statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit)
|
|
||||||
return self.db.exec(statement).all()
|
|
||||||
|
|
@ -1,23 +0,0 @@
|
||||||
from typing import Dict, List
|
|
||||||
from fastapi import WebSocket
|
|
||||||
|
|
||||||
class WebSocketManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.active_connections: Dict[int, List[WebSocket]] = {}
|
|
||||||
|
|
||||||
async def connect(self, user_id: int, websocket: WebSocket):
|
|
||||||
await websocket.accept()
|
|
||||||
if user_id not in self.active_connections:
|
|
||||||
self.active_connections[user_id] = []
|
|
||||||
self.active_connections[user_id].append(websocket)
|
|
||||||
|
|
||||||
def disconnect(self, user_id: int, websocket: WebSocket):
|
|
||||||
if user_id in self.active_connections:
|
|
||||||
self.active_connections[user_id].remove(websocket)
|
|
||||||
if not self.active_connections[user_id]:
|
|
||||||
del self.active_connections[user_id]
|
|
||||||
|
|
||||||
async def send_to_user(self, user_id: int, message: dict):
|
|
||||||
if user_id in self.active_connections:
|
|
||||||
for connection in self.active_connections[user_id]:
|
|
||||||
await connection.send_json(message)
|
|
||||||
|
|
@ -1,56 +0,0 @@
|
||||||
from datetime import date, datetime
|
|
||||||
from typing import List, Optional
|
|
||||||
from sqlmodel import Field, SQLModel, JSON, Column
|
|
||||||
import enum
|
|
||||||
|
|
||||||
|
|
||||||
class User(SQLModel, table=True):
|
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
|
||||||
email: str = Field(unique=True, index=True)
|
|
||||||
username: str = Field(unique=True, index=True)
|
|
||||||
hashed_password: str
|
|
||||||
first_name: Optional[str] = None
|
|
||||||
last_name: Optional[str] = None
|
|
||||||
is_active: bool = Field(default=True)
|
|
||||||
is_superuser: bool = Field(default=False)
|
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
|
|
||||||
|
|
||||||
|
|
||||||
class UserProfile(SQLModel, table=True):
|
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
|
||||||
user_id: int = Field(foreign_key="user.id", unique=True)
|
|
||||||
encrypted_openai_api_key: Optional[str] = None
|
|
||||||
default_ticker: str = Field(default="SPY")
|
|
||||||
preferred_research_depth: int = Field(default=3)
|
|
||||||
preferred_shallow_thinker: str = Field(default="gpt-4o-mini")
|
|
||||||
preferred_deep_thinker: str = Field(default="gpt-4o")
|
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})
|
|
||||||
|
|
||||||
|
|
||||||
class AnalysisStatus(str, enum.Enum):
|
|
||||||
PENDING = "pending"
|
|
||||||
RUNNING = "running"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
CANCELLED = "cancelled"
|
|
||||||
|
|
||||||
|
|
||||||
class AnalysisSession(SQLModel, table=True):
|
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
|
||||||
user_id: int = Field(foreign_key="user.id")
|
|
||||||
ticker: str
|
|
||||||
analysis_date: date
|
|
||||||
analysts_selected: List[str] = Field(sa_column=Column(JSON))
|
|
||||||
research_depth: int
|
|
||||||
llm_provider: str
|
|
||||||
backend_url: str
|
|
||||||
shallow_thinker: str
|
|
||||||
deep_thinker: str
|
|
||||||
status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
|
|
||||||
final_report: Optional[str] = None
|
|
||||||
error_message: Optional[str] = None
|
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
|
|
@ -1,48 +0,0 @@
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Generic, TypeVar, Optional, List
|
|
||||||
from sqlmodel import SQLModel
|
|
||||||
from app.core.schemas.user import UserCreate, UserUpdate
|
|
||||||
from app.domain.models import User
|
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=SQLModel)
|
|
||||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel)
|
|
||||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel)
|
|
||||||
|
|
||||||
class IRepository(Generic[ModelType], ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, id: int) -> Optional[ModelType]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create(self, *, obj_in: CreateSchemaType) -> ModelType:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def remove(self, *, id: int) -> ModelType:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class IUserRepository(IRepository[User], ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def get_by_email(self, *, email: str) -> Optional[User]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create(self, *, obj_in: UserCreate) -> User:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def is_superuser(self, *, user: User) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
@ -1,9 +0,0 @@
|
||||||
from sqlmodel import create_engine, Session
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
engine = create_engine(settings.DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
|
|
||||||
|
|
||||||
def get_db():
|
|
||||||
with Session(engine) as session:
|
|
||||||
yield session
|
|
||||||
|
|
@ -1,53 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
from app.domain.models import User
|
|
||||||
from app.core.schemas.user import UserCreate, UserUpdate
|
|
||||||
from app.domain.repositories import IUserRepository
|
|
||||||
from app.core.security import get_password_hash
|
|
||||||
|
|
||||||
class UserRepository(IUserRepository):
|
|
||||||
def __init__(self, db: Session):
|
|
||||||
self.db = db
|
|
||||||
|
|
||||||
def get(self, id: int) -> Optional[User]:
|
|
||||||
return self.db.get(User, id)
|
|
||||||
|
|
||||||
def get_by_email(self, *, email: str) -> Optional[User]:
|
|
||||||
statement = select(User).where(User.email == email)
|
|
||||||
return self.db.exec(statement).first()
|
|
||||||
|
|
||||||
def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]:
|
|
||||||
statement = select(User).offset(skip).limit(limit)
|
|
||||||
return self.db.exec(statement).all()
|
|
||||||
|
|
||||||
def create(self, *, obj_in: UserCreate) -> User:
|
|
||||||
db_obj = User(
|
|
||||||
email=obj_in.email,
|
|
||||||
username=obj_in.username,
|
|
||||||
hashed_password=get_password_hash(obj_in.password),
|
|
||||||
first_name=obj_in.first_name,
|
|
||||||
last_name=obj_in.last_name,
|
|
||||||
)
|
|
||||||
self.db.add(db_obj)
|
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(db_obj)
|
|
||||||
return db_obj
|
|
||||||
|
|
||||||
def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
|
|
||||||
update_data = obj_in.dict(exclude_unset=True)
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(db_obj, field, value)
|
|
||||||
|
|
||||||
self.db.add(db_obj)
|
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(db_obj)
|
|
||||||
return db_obj
|
|
||||||
|
|
||||||
def remove(self, *, id: int) -> User:
|
|
||||||
db_obj = self.db.get(User, id)
|
|
||||||
self.db.delete(db_obj)
|
|
||||||
self.db.commit()
|
|
||||||
return db_obj
|
|
||||||
|
|
||||||
def is_superuser(self, *, user: User) -> bool:
|
|
||||||
return user.is_superuser
|
|
||||||
36
app/main.py
36
app/main.py
|
|
@ -1,36 +0,0 @@
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
# Add project root to path to allow importing tradingagents
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
|
|
||||||
|
|
||||||
from app.api.router import api_router
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.infrastructure.database import engine
|
|
||||||
from sqlmodel import SQLModel
|
|
||||||
|
|
||||||
def create_tables():
|
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
|
|
||||||
app = FastAPI(
|
|
||||||
title=settings.PROJECT_NAME,
|
|
||||||
openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
|
||||||
def on_startup():
|
|
||||||
create_tables()
|
|
||||||
|
|
||||||
# Set all CORS enabled origins
|
|
||||||
if settings.CORS_ALLOWED_ORIGINS:
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
from datetime import datetime, date
|
from datetime import datetime, date
|
||||||
from typing import List, Dict, Union
|
from typing import Dict
|
||||||
from analysis.infra.db_models.analysis import AnalysisStatus
|
from analysis.infra.db_models.analysis import AnalysisStatus
|
||||||
|
|
||||||
class Analysis(BaseModel):
|
class Analysis(BaseModel):
|
||||||
id: str | None = None
|
id: str | None = None
|
||||||
member_id: str | None = None
|
member_id: str
|
||||||
ticker: str | None = None
|
ticker: str
|
||||||
analysis_date: date | None = None
|
analysis_date: date
|
||||||
analysts_selected: list[str] = []
|
analysts_selected: list[str] = []
|
||||||
research_depth: int = 1
|
research_depth: int = 1
|
||||||
llm_provider: str = "google"
|
llm_provider: str = "google"
|
||||||
|
|
@ -34,5 +34,5 @@ class Analysis(BaseModel):
|
||||||
# 실행 결과 정보
|
# 실행 결과 정보
|
||||||
error_message: str | None = None
|
error_message: str | None = None
|
||||||
completed_at: datetime | None = None
|
completed_at: datetime | None = None
|
||||||
created_at: datetime | None = None
|
created_at: datetime
|
||||||
updated_at: datetime | None = None
|
updated_at: datetime
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from pydantic import validator, Field
|
from pydantic import field_validator, Field
|
||||||
import secrets
|
import secrets
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
@ -42,21 +42,24 @@ class Settings(BaseSettings):
|
||||||
ENVIRONMENT: str = Field(default="development", description="Environment (development/staging/production)")
|
ENVIRONMENT: str = Field(default="development", description="Environment (development/staging/production)")
|
||||||
DEBUG: bool = Field(default=True, description="Debug mode")
|
DEBUG: bool = Field(default=True, description="Debug mode")
|
||||||
|
|
||||||
@validator('ENVIRONMENT')
|
@field_validator('ENVIRONMENT')
|
||||||
|
@classmethod
|
||||||
def validate_environment(cls, v):
|
def validate_environment(cls, v):
|
||||||
allowed_envs = ['development', 'staging', 'production']
|
allowed_envs = ['development', 'staging', 'production']
|
||||||
if v not in allowed_envs:
|
if v not in allowed_envs:
|
||||||
raise ValueError(f'Environment must be one of {allowed_envs}')
|
raise ValueError(f'Environment must be one of {allowed_envs}')
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator('LOG_LEVEL')
|
@field_validator('LOG_LEVEL')
|
||||||
|
@classmethod
|
||||||
def validate_log_level(cls, v):
|
def validate_log_level(cls, v):
|
||||||
allowed_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
|
allowed_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
|
||||||
if v.upper() not in allowed_levels:
|
if v.upper() not in allowed_levels:
|
||||||
raise ValueError(f'Log level must be one of {allowed_levels}')
|
raise ValueError(f'Log level must be one of {allowed_levels}')
|
||||||
return v.upper()
|
return v.upper()
|
||||||
|
|
||||||
@validator('SECRET_KEY')
|
@field_validator('SECRET_KEY')
|
||||||
|
@classmethod
|
||||||
def validate_secret_key(cls, v):
|
def validate_secret_key(cls, v):
|
||||||
if len(v) < 32:
|
if len(v) < 32:
|
||||||
raise ValueError('SECRET_KEY must be at least 32 characters long')
|
raise ValueError('SECRET_KEY must be at least 32 characters long')
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from utils.containers import Container
|
||||||
from utils.middlewares import RateLimitMiddleware, LoggingMiddleware, SecurityHeadersMiddleware
|
from utils.middlewares import RateLimitMiddleware, LoggingMiddleware, SecurityHeadersMiddleware
|
||||||
from utils.exceptions import BaseAPIException
|
from utils.exceptions import BaseAPIException
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from analysis.interface.controller.analysis_controller import router as analysis_router
|
from analysis.interface.controller.analysis_controller import router as analysis_router
|
||||||
from member.interface.controller.member_controller import router as member_router
|
from member.interface.controller.member_controller import router as member_router
|
||||||
|
|
@ -114,7 +115,7 @@ async def health_check():
|
||||||
return {
|
return {
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"environment": settings.ENVIRONMENT,
|
"environment": settings.ENVIRONMENT,
|
||||||
"timestamp": "2024-01-01T00:00:00Z"
|
"timestamp": datetime.utcnow().isoformat() + "Z"
|
||||||
}
|
}
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from tqdm import tqdm
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from .config import get_config, set_config, DATA_DIR
|
from .config import get_config, set_config, DATA_DIR
|
||||||
from .search_provider_factory import SearchProviderFactory
|
from .search_provider_factory import SearchProviderFactory, create_search_provider_factory
|
||||||
|
|
||||||
|
|
||||||
def parse_date_range(curr_date: str, look_back_days: int) -> Tuple[str, str]:
|
def parse_date_range(curr_date: str, look_back_days: int) -> Tuple[str, str]:
|
||||||
|
|
@ -709,9 +709,13 @@ def get_YFin_data(
|
||||||
return filtered_data
|
return filtered_data
|
||||||
|
|
||||||
|
|
||||||
|
# Enhanced search provider factory instance (singleton)
|
||||||
|
_search_factory = create_search_provider_factory()
|
||||||
|
|
||||||
|
|
||||||
def get_stock_news(ticker, curr_date):
|
def get_stock_news(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
search_provider = SearchProviderFactory.create_provider(config)
|
search_provider = _search_factory.create_provider(config)
|
||||||
query = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period."
|
query = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period."
|
||||||
return search_provider.search(query)
|
return search_provider.search(query)
|
||||||
|
|
||||||
|
|
@ -719,7 +723,7 @@ def get_stock_news(ticker, curr_date):
|
||||||
|
|
||||||
def get_global_news(curr_date):
|
def get_global_news(curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
search_provider = SearchProviderFactory.create_provider(config)
|
search_provider = _search_factory.create_provider(config)
|
||||||
query = f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions."
|
query = f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions."
|
||||||
return search_provider.search(query)
|
return search_provider.search(query)
|
||||||
|
|
||||||
|
|
@ -727,7 +731,7 @@ def get_global_news(curr_date):
|
||||||
|
|
||||||
def get_fundamentals(ticker, curr_date):
|
def get_fundamentals(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
search_provider = SearchProviderFactory.create_provider(config)
|
search_provider = _search_factory.create_provider(config)
|
||||||
query = f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format."
|
query = f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format."
|
||||||
return search_provider.search(query)
|
return search_provider.search(query)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,47 +1,133 @@
|
||||||
from .search_provider import (
|
from .search_provider import SearchProvider
|
||||||
SearchProvider,
|
|
||||||
GoogleSearchProvider,
|
|
||||||
OpenAISearchProvider
|
|
||||||
)
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
from typing import Dict, Callable, Any
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
class SearchProviderFactory:
|
class ProviderSelector(ABC):
|
||||||
_cache = {} # 클래스 레벨 캐시
|
"""Abstract base class for provider selection strategies."""
|
||||||
|
|
||||||
@staticmethod
|
@abstractmethod
|
||||||
def create_provider(config: dict[str, any]) -> SearchProvider:
|
def select_provider_type(self, config: Dict[str, Any]) -> str:
|
||||||
|
"""Select provider type based on configuration."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MappingBasedProviderSelector(ProviderSelector):
|
||||||
|
"""Selects provider based on URL pattern mapping table."""
|
||||||
|
|
||||||
|
def __init__(self, mappings: Dict[str, str], default_provider: str = "openai"):
|
||||||
|
self._mappings = mappings
|
||||||
|
self._default_provider = default_provider
|
||||||
|
|
||||||
|
def select_provider_type(self, config: Dict[str, Any]) -> str:
|
||||||
|
backend_url = config.get("backend_url", "")
|
||||||
|
for pattern, provider_type in self._mappings.items():
|
||||||
|
if pattern in backend_url:
|
||||||
|
return provider_type
|
||||||
|
return self._default_provider
|
||||||
|
|
||||||
|
|
||||||
|
class SearchProviderRegistry:
|
||||||
|
"""Registry for search provider creation functions."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._providers: Dict[str, Callable[[Dict[str, Any]], SearchProvider]] = {}
|
||||||
|
|
||||||
|
def register(self, provider_type: str, creator: Callable[[Dict[str, Any]], SearchProvider]):
|
||||||
|
"""Register a provider creator function."""
|
||||||
|
self._providers[provider_type] = creator
|
||||||
|
|
||||||
|
def create(self, provider_type: str, config: Dict[str, Any]) -> SearchProvider:
|
||||||
|
"""Create a provider instance using registered creator."""
|
||||||
|
if provider_type not in self._providers:
|
||||||
|
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||||
|
return self._providers[provider_type](config)
|
||||||
|
|
||||||
|
def get_available_types(self) -> list[str]:
|
||||||
|
"""Get list of available provider types."""
|
||||||
|
return list(self._providers.keys())
|
||||||
|
|
||||||
|
|
||||||
|
class SearchProviderFactoryImpl:
|
||||||
|
"""Enhanced factory for creating SearchProvider instances with caching and extensibility."""
|
||||||
|
|
||||||
|
def __init__(self, registry: SearchProviderRegistry, selector: ProviderSelector):
|
||||||
|
self._registry = registry
|
||||||
|
self._selector = selector
|
||||||
|
self._cache: Dict[str, SearchProvider] = {}
|
||||||
|
|
||||||
|
def create_provider(self, config: Dict[str, Any]) -> SearchProvider:
|
||||||
"""
|
"""
|
||||||
Create a SearchProvider with caching to avoid creating new instances.
|
Create a SearchProvider with caching to avoid creating new instances.
|
||||||
Uses config hash as cache key for efficient reuse.
|
Uses config hash as cache key for efficient reuse.
|
||||||
"""
|
"""
|
||||||
# Create cache key from relevant config values
|
# Create cache key from relevant config values
|
||||||
cache_key_data = {
|
cache_key_data = {
|
||||||
"backend_url": config["backend_url"],
|
"backend_url": config.get("backend_url", ""),
|
||||||
"model": config["quick_think_llm"]
|
"model": config.get("quick_think_llm", "")
|
||||||
}
|
}
|
||||||
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
|
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
|
||||||
|
|
||||||
# Return cached instance if exists
|
# Return cached instance if exists
|
||||||
if cache_key in SearchProviderFactory._cache:
|
if cache_key in self._cache:
|
||||||
return SearchProviderFactory._cache[cache_key]
|
return self._cache[cache_key]
|
||||||
|
|
||||||
# Create new instance
|
# Select and create provider
|
||||||
backend_url = config["backend_url"]
|
provider_type = self._selector.select_provider_type(config)
|
||||||
model = config["quick_think_llm"]
|
provider = self._registry.create(provider_type, config)
|
||||||
|
|
||||||
if "generativelanguage.googleapis.com" in backend_url:
|
|
||||||
provider = GoogleSearchProvider(model)
|
|
||||||
else:
|
|
||||||
provider = OpenAISearchProvider(model, backend_url)
|
|
||||||
|
|
||||||
# Cache and return
|
# Cache and return
|
||||||
SearchProviderFactory._cache[cache_key] = provider
|
self._cache[cache_key] = provider
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
def clear_cache(self):
|
||||||
|
"""Clear the provider cache (useful for testing or config changes)."""
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
def get_available_provider_types(self) -> list[str]:
|
||||||
|
"""Get list of available provider types."""
|
||||||
|
return self._registry.get_available_types()
|
||||||
|
|
||||||
|
|
||||||
|
def create_search_provider_factory() -> SearchProviderFactoryImpl:
|
||||||
|
"""Create a configured SearchProviderFactory with default providers."""
|
||||||
|
registry = SearchProviderRegistry()
|
||||||
|
|
||||||
|
# Register default providers
|
||||||
|
def create_google_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||||
|
from .search_provider import GoogleSearchProvider
|
||||||
|
return GoogleSearchProvider(config["quick_think_llm"])
|
||||||
|
|
||||||
|
def create_openai_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||||
|
from .search_provider import OpenAISearchProvider
|
||||||
|
return OpenAISearchProvider(config["quick_think_llm"], config["backend_url"])
|
||||||
|
|
||||||
|
registry.register("google", create_google_provider)
|
||||||
|
registry.register("openai", create_openai_provider)
|
||||||
|
|
||||||
|
# Create URL pattern mappings (easily extensible)
|
||||||
|
url_mappings = {
|
||||||
|
"generativelanguage.googleapis.com": "google",
|
||||||
|
"api.openai.com": "openai",
|
||||||
|
}
|
||||||
|
|
||||||
|
selector = MappingBasedProviderSelector(url_mappings, default_provider="openai")
|
||||||
|
return SearchProviderFactoryImpl(registry, selector)
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility - singleton instance
|
||||||
|
_default_factory = create_search_provider_factory()
|
||||||
|
|
||||||
|
|
||||||
|
class SearchProviderFactory:
|
||||||
|
"""Backward compatibility wrapper for the old static factory."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||||
|
return _default_factory.create_provider(config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def clear_cache():
|
def clear_cache():
|
||||||
"""Clear the provider cache (useful for testing or config changes)."""
|
_default_factory.clear_cache()
|
||||||
SearchProviderFactory._cache.clear()
|
|
||||||
|
|
||||||
|
|
@ -3,7 +3,7 @@ import os
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||||
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
"data_dir": os.getenv("TRADINGAGENTS_DATA_DIR", "./data"),
|
||||||
"data_cache_dir": os.path.join(
|
"data_cache_dir": os.path.join(
|
||||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"dataflows/data_cache",
|
"dataflows/data_cache",
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,7 @@ class TradingAgentsGraph:
|
||||||
# online tools
|
# online tools
|
||||||
self.toolkit.get_stock_news,
|
self.toolkit.get_stock_news,
|
||||||
# offline tools
|
# offline tools
|
||||||
self.toolkit.get_reddit_stock_info,
|
# self.toolkit.get_reddit_stock_info,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
"news": ToolNode(
|
"news": ToolNode(
|
||||||
|
|
@ -136,8 +136,8 @@ class TradingAgentsGraph:
|
||||||
self.toolkit.get_global_news,
|
self.toolkit.get_global_news,
|
||||||
self.toolkit.get_google_news,
|
self.toolkit.get_google_news,
|
||||||
# offline tools
|
# offline tools
|
||||||
self.toolkit.get_finnhub_news,
|
# self.toolkit.get_finnhub_news,
|
||||||
self.toolkit.get_reddit_news,
|
# self.toolkit.get_reddit_news,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
"fundamentals": ToolNode(
|
"fundamentals": ToolNode(
|
||||||
|
|
@ -145,11 +145,11 @@ class TradingAgentsGraph:
|
||||||
# online tools
|
# online tools
|
||||||
self.toolkit.get_fundamentals,
|
self.toolkit.get_fundamentals,
|
||||||
# offline tools
|
# offline tools
|
||||||
self.toolkit.get_finnhub_company_insider_sentiment,
|
# self.toolkit.get_finnhub_company_insider_sentiment,
|
||||||
self.toolkit.get_finnhub_company_insider_transactions,
|
# self.toolkit.get_finnhub_company_insider_transactions,
|
||||||
self.toolkit.get_simfin_balance_sheet,
|
# self.toolkit.get_simfin_balance_sheet,
|
||||||
self.toolkit.get_simfin_cashflow,
|
# self.toolkit.get_simfin_cashflow,
|
||||||
self.toolkit.get_simfin_income_stmt,
|
# self.toolkit.get_simfin_income_stmt,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue