From c3e609730b7df83a54b1e4f3111411b17a068598 Mon Sep 17 00:00:00 2001 From: kimheesu Date: Tue, 8 Jul 2025 14:27:34 +0900 Subject: [PATCH] [add] remove app --- .gitignore | 3 +- app/__init__.py | 0 app/api/__init__.py | 0 app/api/deps.py | 67 --------- app/api/endpoints/__init__.py | 0 app/api/endpoints/analysis.py | 94 ------------ app/api/endpoints/login.py | 35 ----- app/api/endpoints/users.py | 89 ------------ app/api/router.py | 7 - app/core/__init__.py | 0 app/core/config.py | 26 ---- app/core/schemas/__init__.py | 4 - app/core/schemas/analysis.py | 38 ----- app/core/schemas/profile.py | 20 --- app/core/schemas/token.py | 9 -- app/core/schemas/user.py | 28 ---- app/core/security.py | 23 --- app/core/services/__init__.py | 0 app/core/services/trading_analysis.py | 128 ----------------- app/core/websocket_manager.py | 23 --- app/domain/__init__.py | 0 app/domain/models.py | 56 -------- app/domain/repositories.py | 48 ------- app/infrastructure/__init__.py | 0 app/infrastructure/database.py | 9 -- app/infrastructure/repositories/__init__.py | 0 app/infrastructure/repositories/user.py | 53 ------- app/main.py | 36 ----- backend/analysis/domain/analysis.py | 12 +- backend/config/config.py | 11 +- backend/main.py | 3 +- backend/tradingagents/dataflows/interface.py | 12 +- .../dataflows/search_provider_factory.py | 136 ++++++++++++++---- backend/tradingagents/default_config.py | 2 +- backend/tradingagents/graph/trading_graph.py | 16 +-- 35 files changed, 145 insertions(+), 843 deletions(-) delete mode 100644 app/__init__.py delete mode 100644 app/api/__init__.py delete mode 100644 app/api/deps.py delete mode 100644 app/api/endpoints/__init__.py delete mode 100644 app/api/endpoints/analysis.py delete mode 100644 app/api/endpoints/login.py delete mode 100644 app/api/endpoints/users.py delete mode 100644 app/api/router.py delete mode 100644 app/core/__init__.py delete mode 100644 app/core/config.py delete mode 100644 app/core/schemas/__init__.py delete mode 100644 app/core/schemas/analysis.py delete mode 100644 app/core/schemas/profile.py delete mode 100644 app/core/schemas/token.py delete mode 100644 app/core/schemas/user.py delete mode 100644 app/core/security.py delete mode 100644 app/core/services/__init__.py delete mode 100644 app/core/services/trading_analysis.py delete mode 100644 app/core/websocket_manager.py delete mode 100644 app/domain/__init__.py delete mode 100644 app/domain/models.py delete mode 100644 app/domain/repositories.py delete mode 100644 app/infrastructure/__init__.py delete mode 100644 app/infrastructure/database.py delete mode 100644 app/infrastructure/repositories/__init__.py delete mode 100644 app/infrastructure/repositories/user.py delete mode 100644 app/main.py diff --git a/.gitignore b/.gitignore index c4641246..4dae4be2 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ eval_data/ *.egg-info/ results/ .env -tradingagents/dataflows/data_cache/ \ No newline at end of file +tradingagents/dataflows/data_cache/ +CLAUDE.md diff --git a/app/__init__.py b/app/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/api/__init__.py b/app/api/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/api/deps.py b/app/api/deps.py deleted file mode 100644 index 7e17ea41..00000000 --- a/app/api/deps.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Generator, Optional -from fastapi import Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer -from jose import jwt, JWTError -from pydantic import BaseModel -from sqlmodel import Session - -from app.core.config import settings -from app.infrastructure.database import get_db -from app.domain.models import User -from app.infrastructure.repositories.user import UserRepository -from app.core.services.trading_analysis import TradingAnalysisService - -reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token") - -class TokenData(BaseModel): - username: Optional[str] = None - -def get_user_repository(db: Session = Depends(get_db)) -> UserRepository: - return UserRepository(db) - -def get_user_from_token(token: str, db: Session) -> Optional[User]: - try: - payload = jwt.decode( - token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] - ) - token_data = TokenData(username=payload.get("sub")) - except JWTError: - return None - - user_repo = UserRepository(db) - user = user_repo.get_by_email(email=token_data.username) - return user - -def get_current_user( - db: Session = Depends(get_db), token: str = Depends(reusable_oauth2) -) -> User: - user = get_user_from_token(token=token, db=db) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - return user - -def get_current_active_user( - current_user: User = Depends(get_current_user), -) -> User: - if not current_user.is_active: - raise HTTPException(status_code=400, detail="Inactive user") - return current_user - -def get_current_active_superuser( - current_user: User = Depends(get_current_active_user), -) -> User: - if not current_user.is_superuser: - raise HTTPException( - status_code=403, detail="The user doesn't have enough privileges" - ) - return current_user - -def get_analysis_service( - db: Session = Depends(get_db), - user: User = Depends(get_current_active_user) -) -> TradingAnalysisService: - return TradingAnalysisService(user=user, db=db) diff --git a/app/api/endpoints/__init__.py b/app/api/endpoints/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/api/endpoints/analysis.py b/app/api/endpoints/analysis.py deleted file mode 100644 index ba3a3694..00000000 --- a/app/api/endpoints/analysis.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Any, List -from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks -from app.api import deps -from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate -from app.domain.models import User as UserModel -from app.core.services.trading_analysis import TradingAnalysisService -from app.core.websocket_manager import WebSocketManager -from sqlmodel import Session -from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS - -router = APIRouter() -manager = WebSocketManager() - -@router.post("/start", response_model=AnalysisSession) -def start_analysis( - *, - analysis_in: AnalysisSessionCreate, - background_tasks: BackgroundTasks, - service: TradingAnalysisService = Depends(deps.get_analysis_service), -) -> Any: - """ - Start a new analysis session. - """ - session = service.create_session(analysis_in=analysis_in) - background_tasks.add_task(service.run_analysis, session_id=session.id) - return session - -@router.get("/history", response_model=List[AnalysisSession]) -def get_analysis_history( - service: TradingAnalysisService = Depends(deps.get_analysis_service), - skip: int = 0, - limit: int = 100, -) -> Any: - """ - Get analysis history for the current user. - """ - return service.get_user_sessions(skip=skip, limit=limit) - -@router.get("/options") -def get_analysis_options(): - """ - Get available options for analysis. - """ - return { - 'analysts': [ - {'value': 'market', 'label': 'Market Analyst'}, - {'value': 'social', 'label': 'Social Analyst'}, - {'value': 'news', 'label': 'News Analyst'}, - {'value': 'fundamentals', 'label': 'Fundamentals Analyst'}, - ], - 'research_depths': [ - {'value': 1, 'label': 'Shallow'}, - {'value': 3, 'label': 'Medium'}, - {'value': 5, 'label': 'Deep'}, - ], - 'llm_providers': [{'name': p[0], 'url': p[1]} for p in BASE_URLS], - 'shallow_thinkers': SHALLOW_AGENT_OPTIONS, - 'deep_thinkers': DEEP_AGENT_OPTIONS, - } - -@router.get("/{session_id}", response_model=AnalysisSession) -def get_analysis_session( - session_id: int, - service: TradingAnalysisService = Depends(deps.get_analysis_service), -) -> Any: - """ - Get a specific analysis session by ID. - """ - session = service.get_session(session_id=session_id) - if not session: - raise HTTPException(status_code=404, detail="Analysis session not found") - return session - -@router.websocket("/ws") -async def websocket_endpoint( - websocket: WebSocket, - token: str, - db: Session = Depends(deps.get_db) -): - """ - WebSocket endpoint for real-time analysis updates. - """ - user = deps.get_user_from_token(token=token, db=db) - if not user or not user.is_active: - await websocket.close(code=1008) - return - - await manager.connect(user.id, websocket) - try: - while True: - # Keep the connection alive - await websocket.receive_text() - except WebSocketDisconnect: - manager.disconnect(user.id, websocket) \ No newline at end of file diff --git a/app/api/endpoints/login.py b/app/api/endpoints/login.py deleted file mode 100644 index 700ffd74..00000000 --- a/app/api/endpoints/login.py +++ /dev/null @@ -1,35 +0,0 @@ -from datetime import timedelta -from fastapi import APIRouter, Depends, HTTPException -from fastapi.security import OAuth2PasswordRequestForm -from sqlmodel import Session - -from app.api import deps -from app.core.config import settings -from app.core.schemas.token import Token -from app.core import security -from app.infrastructure.repositories.user import UserRepository - -router = APIRouter() - -@router.post("/login/access-token", response_model=Token) -def login_access_token( - db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends() -): - """ - OAuth2 compatible token login, get an access token for future requests - """ - user_repo = UserRepository(db) - user = user_repo.get_by_email(email=form_data.username) - - if not user or not security.verify_password(form_data.password, user.hashed_password): - raise HTTPException(status_code=400, detail="Incorrect email or password") - elif not user.is_active: - raise HTTPException(status_code=400, detail="Inactive user") - - access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - return { - "access_token": security.create_access_token( - user.email, expires_delta=access_token_expires - ), - "token_type": "bearer", - } diff --git a/app/api/endpoints/users.py b/app/api/endpoints/users.py deleted file mode 100644 index 868821f1..00000000 --- a/app/api/endpoints/users.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Any, List -from fastapi import APIRouter, Depends, HTTPException -from app.api import deps -from app.core.schemas.user import User, UserCreate, UserUpdate -from app.domain.models import User as UserModel -from app.domain.repositories import IUserRepository - -router = APIRouter() - -@router.get("/", response_model=List[User]) -def read_users( - repo: IUserRepository = Depends(deps.get_user_repository), - skip: int = 0, - limit: int = 100, - current_user: UserModel = Depends(deps.get_current_active_superuser), -) -> Any: - """ - Retrieve users. - """ - users = repo.get_multi(skip=skip, limit=limit) - return users - -@router.post("/", response_model=User) -def create_user( - *, - repo: IUserRepository = Depends(deps.get_user_repository), - user_in: UserCreate, - current_user: UserModel = Depends(deps.get_current_active_superuser), -) -> Any: - """ - Create new user. - """ - user = repo.get_by_email(email=user_in.email) - if user: - raise HTTPException( - status_code=400, - detail="The user with this username already exists in the system.", - ) - user = repo.create(obj_in=user_in) - return user - -@router.get("/me", response_model=User) -def read_user_me( - current_user: UserModel = Depends(deps.get_current_active_user), -) -> Any: - """ - Get current user. - """ - return current_user - -@router.get("/{user_id}", response_model=User) -def read_user_by_id( - user_id: int, - repo: IUserRepository = Depends(deps.get_user_repository), - current_user: UserModel = Depends(deps.get_current_active_user), -) -> Any: - """ - Get a specific user by id. - """ - user = repo.get(id=user_id) - if not user: - raise HTTPException(status_code=404, detail="User not found") - if user == current_user: - return user - if not repo.is_superuser(user=current_user): - raise HTTPException( - status_code=403, detail="The user doesn't have enough privileges" - ) - return user - -@router.put("/{user_id}", response_model=User) -def update_user( - *, - repo: IUserRepository = Depends(deps.get_user_repository), - user_id: int, - user_in: UserUpdate, - current_user: UserModel = Depends(deps.get_current_active_superuser), -) -> Any: - """ - Update a user. - """ - user = repo.get(id=user_id) - if not user: - raise HTTPException( - status_code=404, - detail="The user with this username does not exist in the system", - ) - user = repo.update(db_obj=user, obj_in=user_in) - return user diff --git a/app/api/router.py b/app/api/router.py deleted file mode 100644 index 2c273ffe..00000000 --- a/app/api/router.py +++ /dev/null @@ -1,7 +0,0 @@ -from fastapi import APIRouter -from app.api.endpoints import login, users, analysis - -api_router = APIRouter() -api_router.include_router(login.router, tags=["login"]) -api_router.include_router(users.router, prefix="/users", tags=["users"]) -api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"]) diff --git a/app/core/__init__.py b/app/core/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/config.py b/app/core/config.py deleted file mode 100644 index 0ec85091..00000000 --- a/app/core/config.py +++ /dev/null @@ -1,26 +0,0 @@ -import os -from pydantic import BaseSettings -from typing import List, Optional - -class Settings(BaseSettings): - PROJECT_NAME: str = "TradingAgents Backend" - API_V1_STR: str = "/api/v1" - - # Security - SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key") - ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days - ALGORITHM: str = "HS256" - - # Database - DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db") - - # OpenAI - OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") - - # CORS - CORS_ALLOWED_ORIGINS: List[str] = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(',') - - class Config: - case_sensitive = True - -settings = Settings() \ No newline at end of file diff --git a/app/core/schemas/__init__.py b/app/core/schemas/__init__.py deleted file mode 100644 index 1cb491e2..00000000 --- a/app/core/schemas/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .user import User, UserCreate, UserUpdate -from .token import Token, TokenPayload -from .profile import Profile, ProfileCreate, ProfileUpdate -from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate diff --git a/app/core/schemas/analysis.py b/app/core/schemas/analysis.py deleted file mode 100644 index 966e7df6..00000000 --- a/app/core/schemas/analysis.py +++ /dev/null @@ -1,38 +0,0 @@ -from pydantic import BaseModel -from typing import List, Optional -from datetime import date, datetime -from app.domain.models import AnalysisStatus - -class AnalysisSessionBase(BaseModel): - ticker: str - analysts_selected: List[str] - research_depth: int - llm_provider: str - backend_url: str - shallow_thinker: str - deep_thinker: str - -class AnalysisSessionCreate(AnalysisSessionBase): - pass - -class AnalysisSessionUpdate(BaseModel): - status: Optional[AnalysisStatus] = None - final_report: Optional[str] = None - error_message: Optional[str] = None - -class AnalysisSessionInDBBase(AnalysisSessionBase): - id: int - user_id: int - analysis_date: date - status: AnalysisStatus - final_report: Optional[str] = None - error_message: Optional[str] = None - created_at: datetime - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - - class Config: - orm_mode = True - -class AnalysisSession(AnalysisSessionInDBBase): - pass \ No newline at end of file diff --git a/app/core/schemas/profile.py b/app/core/schemas/profile.py deleted file mode 100644 index 95c145bd..00000000 --- a/app/core/schemas/profile.py +++ /dev/null @@ -1,20 +0,0 @@ -from pydantic import BaseModel -from typing import Optional - -class ProfileBase(BaseModel): - default_ticker: str = "SPY" - preferred_research_depth: int = 3 - preferred_shallow_thinker: str = "gpt-4o-mini" - preferred_deep_thinker: str = "gpt-4o" - -class ProfileCreate(ProfileBase): - pass - -class ProfileUpdate(ProfileBase): - openai_api_key: Optional[str] = None - -class Profile(ProfileBase): - has_openai_api_key: bool - - class Config: - orm_mode = True diff --git a/app/core/schemas/token.py b/app/core/schemas/token.py deleted file mode 100644 index 53eb44d1..00000000 --- a/app/core/schemas/token.py +++ /dev/null @@ -1,9 +0,0 @@ -from pydantic import BaseModel -from typing import Optional - -class Token(BaseModel): - access_token: str - token_type: str - -class TokenPayload(BaseModel): - sub: Optional[int] = None diff --git a/app/core/schemas/user.py b/app/core/schemas/user.py deleted file mode 100644 index d2a28a93..00000000 --- a/app/core/schemas/user.py +++ /dev/null @@ -1,28 +0,0 @@ -from pydantic import BaseModel, EmailStr -from typing import Optional - -class UserBase(BaseModel): - email: EmailStr - username: str - first_name: Optional[str] = None - last_name: Optional[str] = None - -class UserCreate(UserBase): - password: str - -class UserUpdate(UserBase): - pass - -class UserInDBBase(UserBase): - id: int - is_active: bool - is_superuser: bool - - class Config: - orm_mode = True - -class User(UserInDBBase): - pass - -class UserInDB(UserInDBBase): - hashed_password: str diff --git a/app/core/security.py b/app/core/security.py deleted file mode 100644 index 15e5503f..00000000 --- a/app/core/security.py +++ /dev/null @@ -1,23 +0,0 @@ -from datetime import datetime, timedelta -from typing import Any, Union, Optional -from jose import jwt -from passlib.context import CryptContext -from app.core.config import settings - -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str: - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - - to_encode = {"exp": expire, "sub": str(subject)} - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) - return encoded_jwt - -def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) - -def get_password_hash(password: str) -> str: - return pwd_context.hash(password) diff --git a/app/core/services/__init__.py b/app/core/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/services/trading_analysis.py b/app/core/services/trading_analysis.py deleted file mode 100644 index fb70d28a..00000000 --- a/app/core/services/trading_analysis.py +++ /dev/null @@ -1,128 +0,0 @@ -import asyncio -import datetime -import json -from typing import Dict, List, Optional -from sqlmodel import Session, select -from app.domain.models import User, AnalysisSession, AnalysisStatus -from app.core.schemas.analysis import AnalysisSessionCreate -from app.core.config import settings -from cli.models import AnalystType -from tradingagents.graph.trading_graph import TradingAgentsGraph -from tradingagents.default_config import DEFAULT_CONFIG -from app.api.deps import get_db -from app.core.websocket_manager import WebSocketManager - -class TradingAnalysisService: - def __init__(self, user: User, db: Session): - self.user = user - self.db = db - self.websocket_manager = WebSocketManager() - - async def run_analysis(self, session_id: int): - """분석 실행""" - session = self.get_session(session_id=session_id) - if not session: - return - - try: - session.status = AnalysisStatus.RUNNING - session.started_at = datetime.datetime.utcnow() - self.db.add(session) - self.db.commit() - self.db.refresh(session) - - await self.websocket_manager.send_to_user( - self.user.id, - { - 'type': 'analysis_started', - 'session_id': session.id, - 'message': '분석을 시작합니다...' - } - ) - - # Prepare config for TradingAgentsGraph - config = DEFAULT_CONFIG.copy() - config.update({ - 'openai_api_key': settings.OPENAI_API_KEY, - 'llm_provider': session.llm_provider, - 'backend_url': session.backend_url, - 'shallow_thinking_model': session.shallow_thinker, - 'deep_thinking_model': session.deep_thinker, - }) - - # Progress callback for websocket - async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0): - progress_percent = int((step / total) * 99) if total > 0 else 0 - await self.websocket_manager.send_to_user(self.user.id, { - 'type': 'analysis_progress', - 'session_id': session.id, - 'message_type': message_type, - 'content': content, - 'agent': agent, - 'progress': progress_percent, - }) - - trading_graph = TradingAgentsGraph( - config=config, - selected_analysts=session.analysts_selected, - ) - - input_data = { - 'company_of_interest': session.ticker, - 'trade_date': session.analysis_date.strftime('%Y-%m-%d'), - } - - final_state, result = await asyncio.to_thread( - trading_graph.propagate, - input_data['company_of_interest'], - input_data['trade_date'] - ) - - session.status = AnalysisStatus.COMPLETED - session.completed_at = datetime.datetime.utcnow() - session.final_report = json.dumps(final_state) # Store full state as JSON - self.db.add(session) - self.db.commit() - - await self.websocket_manager.send_to_user( - self.user.id, - { - 'type': 'analysis_completed', - 'session_id': session.id, - 'message': '분석이 완료되었습니다.', - 'result': result - } - ) - - except Exception as e: - session.status = AnalysisStatus.FAILED - session.error_message = str(e) - self.db.add(session) - self.db.commit() - await self.websocket_manager.send_to_user( - self.user.id, - { - 'type': 'analysis_failed', - 'session_id': session.id, - 'message': f'분석 중 오류가 발생했습니다: {str(e)}' - } - ) - - def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession: - session = AnalysisSession( - **analysis_in.dict(), - user_id=self.user.id, - analysis_date=datetime.date.today() - ) - self.db.add(session) - self.db.commit() - self.db.refresh(session) - return session - - def get_session(self, *, session_id: int) -> Optional[AnalysisSession]: - statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id) - return self.db.exec(statement).first() - - def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]: - statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit) - return self.db.exec(statement).all() diff --git a/app/core/websocket_manager.py b/app/core/websocket_manager.py deleted file mode 100644 index fe5c9b5c..00000000 --- a/app/core/websocket_manager.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Dict, List -from fastapi import WebSocket - -class WebSocketManager: - def __init__(self): - self.active_connections: Dict[int, List[WebSocket]] = {} - - async def connect(self, user_id: int, websocket: WebSocket): - await websocket.accept() - if user_id not in self.active_connections: - self.active_connections[user_id] = [] - self.active_connections[user_id].append(websocket) - - def disconnect(self, user_id: int, websocket: WebSocket): - if user_id in self.active_connections: - self.active_connections[user_id].remove(websocket) - if not self.active_connections[user_id]: - del self.active_connections[user_id] - - async def send_to_user(self, user_id: int, message: dict): - if user_id in self.active_connections: - for connection in self.active_connections[user_id]: - await connection.send_json(message) diff --git a/app/domain/__init__.py b/app/domain/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/domain/models.py b/app/domain/models.py deleted file mode 100644 index 298c4329..00000000 --- a/app/domain/models.py +++ /dev/null @@ -1,56 +0,0 @@ -from datetime import date, datetime -from typing import List, Optional -from sqlmodel import Field, SQLModel, JSON, Column -import enum - - -class User(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - email: str = Field(unique=True, index=True) - username: str = Field(unique=True, index=True) - hashed_password: str - first_name: Optional[str] = None - last_name: Optional[str] = None - is_active: bool = Field(default=True) - is_superuser: bool = Field(default=False) - created_at: datetime = Field(default_factory=datetime.utcnow) - updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow}) - - -class UserProfile(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - user_id: int = Field(foreign_key="user.id", unique=True) - encrypted_openai_api_key: Optional[str] = None - default_ticker: str = Field(default="SPY") - preferred_research_depth: int = Field(default=3) - preferred_shallow_thinker: str = Field(default="gpt-4o-mini") - preferred_deep_thinker: str = Field(default="gpt-4o") - created_at: datetime = Field(default_factory=datetime.utcnow) - updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow}) - - -class AnalysisStatus(str, enum.Enum): - PENDING = "pending" - RUNNING = "running" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - - -class AnalysisSession(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - user_id: int = Field(foreign_key="user.id") - ticker: str - analysis_date: date - analysts_selected: List[str] = Field(sa_column=Column(JSON)) - research_depth: int - llm_provider: str - backend_url: str - shallow_thinker: str - deep_thinker: str - status: AnalysisStatus = Field(default=AnalysisStatus.PENDING) - final_report: Optional[str] = None - error_message: Optional[str] = None - created_at: datetime = Field(default_factory=datetime.utcnow) - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None \ No newline at end of file diff --git a/app/domain/repositories.py b/app/domain/repositories.py deleted file mode 100644 index 84964bea..00000000 --- a/app/domain/repositories.py +++ /dev/null @@ -1,48 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Generic, TypeVar, Optional, List -from sqlmodel import SQLModel -from app.core.schemas.user import UserCreate, UserUpdate -from app.domain.models import User - -ModelType = TypeVar("ModelType", bound=SQLModel) -CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel) -UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel) - -class IRepository(Generic[ModelType], ABC): - @abstractmethod - def get(self, id: int) -> Optional[ModelType]: - pass - - @abstractmethod - def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]: - pass - - @abstractmethod - def create(self, *, obj_in: CreateSchemaType) -> ModelType: - pass - - @abstractmethod - def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType: - pass - - @abstractmethod - def remove(self, *, id: int) -> ModelType: - pass - - -class IUserRepository(IRepository[User], ABC): - @abstractmethod - def get_by_email(self, *, email: str) -> Optional[User]: - pass - - @abstractmethod - def create(self, *, obj_in: UserCreate) -> User: - pass - - @abstractmethod - def update(self, *, db_obj: User, obj_in: UserUpdate) -> User: - pass - - @abstractmethod - def is_superuser(self, *, user: User) -> bool: - pass diff --git a/app/infrastructure/__init__.py b/app/infrastructure/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/infrastructure/database.py b/app/infrastructure/database.py deleted file mode 100644 index c0d2ef3b..00000000 --- a/app/infrastructure/database.py +++ /dev/null @@ -1,9 +0,0 @@ -from sqlmodel import create_engine, Session - -from app.core.config import settings - -engine = create_engine(settings.DATABASE_URL, echo=True, connect_args={"check_same_thread": False}) - -def get_db(): - with Session(engine) as session: - yield session diff --git a/app/infrastructure/repositories/__init__.py b/app/infrastructure/repositories/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/infrastructure/repositories/user.py b/app/infrastructure/repositories/user.py deleted file mode 100644 index e8840254..00000000 --- a/app/infrastructure/repositories/user.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Optional -from sqlmodel import Session, select -from app.domain.models import User -from app.core.schemas.user import UserCreate, UserUpdate -from app.domain.repositories import IUserRepository -from app.core.security import get_password_hash - -class UserRepository(IUserRepository): - def __init__(self, db: Session): - self.db = db - - def get(self, id: int) -> Optional[User]: - return self.db.get(User, id) - - def get_by_email(self, *, email: str) -> Optional[User]: - statement = select(User).where(User.email == email) - return self.db.exec(statement).first() - - def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]: - statement = select(User).offset(skip).limit(limit) - return self.db.exec(statement).all() - - def create(self, *, obj_in: UserCreate) -> User: - db_obj = User( - email=obj_in.email, - username=obj_in.username, - hashed_password=get_password_hash(obj_in.password), - first_name=obj_in.first_name, - last_name=obj_in.last_name, - ) - self.db.add(db_obj) - self.db.commit() - self.db.refresh(db_obj) - return db_obj - - def update(self, *, db_obj: User, obj_in: UserUpdate) -> User: - update_data = obj_in.dict(exclude_unset=True) - for field, value in update_data.items(): - setattr(db_obj, field, value) - - self.db.add(db_obj) - self.db.commit() - self.db.refresh(db_obj) - return db_obj - - def remove(self, *, id: int) -> User: - db_obj = self.db.get(User, id) - self.db.delete(db_obj) - self.db.commit() - return db_obj - - def is_superuser(self, *, user: User) -> bool: - return user.is_superuser diff --git a/app/main.py b/app/main.py deleted file mode 100644 index 239c0437..00000000 --- a/app/main.py +++ /dev/null @@ -1,36 +0,0 @@ -import sys -import os -from fastapi import FastAPI -from starlette.middleware.cors import CORSMiddleware - -# Add project root to path to allow importing tradingagents -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))) - -from app.api.router import api_router -from app.core.config import settings -from app.infrastructure.database import engine -from sqlmodel import SQLModel - -def create_tables(): - SQLModel.metadata.create_all(engine) - -app = FastAPI( - title=settings.PROJECT_NAME, - openapi_url=f"{settings.API_V1_STR}/openapi.json" -) - -@app.on_event("startup") -def on_startup(): - create_tables() - -# Set all CORS enabled origins -if settings.CORS_ALLOWED_ORIGINS: - app.add_middleware( - CORSMiddleware, - allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - -app.include_router(api_router, prefix=settings.API_V1_STR) \ No newline at end of file diff --git a/backend/analysis/domain/analysis.py b/backend/analysis/domain/analysis.py index e15f3c04..f3ba6d4a 100644 --- a/backend/analysis/domain/analysis.py +++ b/backend/analysis/domain/analysis.py @@ -1,13 +1,13 @@ from pydantic import BaseModel, field_validator from datetime import datetime, date -from typing import List, Dict, Union +from typing import Dict from analysis.infra.db_models.analysis import AnalysisStatus class Analysis(BaseModel): id: str | None = None - member_id: str | None = None - ticker: str | None = None - analysis_date: date | None = None + member_id: str + ticker: str + analysis_date: date analysts_selected: list[str] = [] research_depth: int = 1 llm_provider: str = "google" @@ -34,5 +34,5 @@ class Analysis(BaseModel): # 실행 결과 정보 error_message: str | None = None completed_at: datetime | None = None - created_at: datetime | None = None - updated_at: datetime | None = None \ No newline at end of file + created_at: datetime + updated_at: datetime \ No newline at end of file diff --git a/backend/config/config.py b/backend/config/config.py index f53e7de3..2b55e924 100644 --- a/backend/config/config.py +++ b/backend/config/config.py @@ -1,6 +1,6 @@ from functools import lru_cache from pydantic_settings import BaseSettings, SettingsConfigDict -from pydantic import validator, Field +from pydantic import field_validator, Field import secrets import os @@ -42,21 +42,24 @@ class Settings(BaseSettings): ENVIRONMENT: str = Field(default="development", description="Environment (development/staging/production)") DEBUG: bool = Field(default=True, description="Debug mode") - @validator('ENVIRONMENT') + @field_validator('ENVIRONMENT') + @classmethod def validate_environment(cls, v): allowed_envs = ['development', 'staging', 'production'] if v not in allowed_envs: raise ValueError(f'Environment must be one of {allowed_envs}') return v - @validator('LOG_LEVEL') + @field_validator('LOG_LEVEL') + @classmethod def validate_log_level(cls, v): allowed_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] if v.upper() not in allowed_levels: raise ValueError(f'Log level must be one of {allowed_levels}') return v.upper() - @validator('SECRET_KEY') + @field_validator('SECRET_KEY') + @classmethod def validate_secret_key(cls, v): if len(v) < 32: raise ValueError('SECRET_KEY must be at least 32 characters long') diff --git a/backend/main.py b/backend/main.py index 0c5ba1cb..59d06dd9 100644 --- a/backend/main.py +++ b/backend/main.py @@ -7,6 +7,7 @@ from utils.containers import Container from utils.middlewares import RateLimitMiddleware, LoggingMiddleware, SecurityHeadersMiddleware from utils.exceptions import BaseAPIException from contextlib import asynccontextmanager +from datetime import datetime from analysis.interface.controller.analysis_controller import router as analysis_router from member.interface.controller.member_controller import router as member_router @@ -114,7 +115,7 @@ async def health_check(): return { "status": "healthy", "environment": settings.ENVIRONMENT, - "timestamp": "2024-01-01T00:00:00Z" + "timestamp": datetime.utcnow().isoformat() + "Z" } @app.get("/") diff --git a/backend/tradingagents/dataflows/interface.py b/backend/tradingagents/dataflows/interface.py index a829008d..e049efff 100644 --- a/backend/tradingagents/dataflows/interface.py +++ b/backend/tradingagents/dataflows/interface.py @@ -14,7 +14,7 @@ from tqdm import tqdm import yfinance as yf from openai import OpenAI from .config import get_config, set_config, DATA_DIR -from .search_provider_factory import SearchProviderFactory +from .search_provider_factory import SearchProviderFactory, create_search_provider_factory def parse_date_range(curr_date: str, look_back_days: int) -> Tuple[str, str]: @@ -709,9 +709,13 @@ def get_YFin_data( return filtered_data +# Enhanced search provider factory instance (singleton) +_search_factory = create_search_provider_factory() + + def get_stock_news(ticker, curr_date): config = get_config() - search_provider = SearchProviderFactory.create_provider(config) + search_provider = _search_factory.create_provider(config) query = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period." return search_provider.search(query) @@ -719,7 +723,7 @@ def get_stock_news(ticker, curr_date): def get_global_news(curr_date): config = get_config() - search_provider = SearchProviderFactory.create_provider(config) + search_provider = _search_factory.create_provider(config) query = f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions." return search_provider.search(query) @@ -727,7 +731,7 @@ def get_global_news(curr_date): def get_fundamentals(ticker, curr_date): config = get_config() - search_provider = SearchProviderFactory.create_provider(config) + search_provider = _search_factory.create_provider(config) query = f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format." return search_provider.search(query) diff --git a/backend/tradingagents/dataflows/search_provider_factory.py b/backend/tradingagents/dataflows/search_provider_factory.py index 5508a6f0..86a1dbf8 100644 --- a/backend/tradingagents/dataflows/search_provider_factory.py +++ b/backend/tradingagents/dataflows/search_provider_factory.py @@ -1,47 +1,133 @@ -from .search_provider import ( - SearchProvider, - GoogleSearchProvider, - OpenAISearchProvider -) +from .search_provider import SearchProvider import hashlib import json +from typing import Dict, Callable, Any +from abc import ABC, abstractmethod -class SearchProviderFactory: - _cache = {} # 클래스 레벨 캐시 +class ProviderSelector(ABC): + """Abstract base class for provider selection strategies.""" - @staticmethod - def create_provider(config: dict[str, any]) -> SearchProvider: + @abstractmethod + def select_provider_type(self, config: Dict[str, Any]) -> str: + """Select provider type based on configuration.""" + pass + + +class MappingBasedProviderSelector(ProviderSelector): + """Selects provider based on URL pattern mapping table.""" + + def __init__(self, mappings: Dict[str, str], default_provider: str = "openai"): + self._mappings = mappings + self._default_provider = default_provider + + def select_provider_type(self, config: Dict[str, Any]) -> str: + backend_url = config.get("backend_url", "") + for pattern, provider_type in self._mappings.items(): + if pattern in backend_url: + return provider_type + return self._default_provider + + +class SearchProviderRegistry: + """Registry for search provider creation functions.""" + + def __init__(self): + self._providers: Dict[str, Callable[[Dict[str, Any]], SearchProvider]] = {} + + def register(self, provider_type: str, creator: Callable[[Dict[str, Any]], SearchProvider]): + """Register a provider creator function.""" + self._providers[provider_type] = creator + + def create(self, provider_type: str, config: Dict[str, Any]) -> SearchProvider: + """Create a provider instance using registered creator.""" + if provider_type not in self._providers: + raise ValueError(f"Unknown provider type: {provider_type}") + return self._providers[provider_type](config) + + def get_available_types(self) -> list[str]: + """Get list of available provider types.""" + return list(self._providers.keys()) + + +class SearchProviderFactoryImpl: + """Enhanced factory for creating SearchProvider instances with caching and extensibility.""" + + def __init__(self, registry: SearchProviderRegistry, selector: ProviderSelector): + self._registry = registry + self._selector = selector + self._cache: Dict[str, SearchProvider] = {} + + def create_provider(self, config: Dict[str, Any]) -> SearchProvider: """ Create a SearchProvider with caching to avoid creating new instances. Uses config hash as cache key for efficient reuse. """ # Create cache key from relevant config values cache_key_data = { - "backend_url": config["backend_url"], - "model": config["quick_think_llm"] + "backend_url": config.get("backend_url", ""), + "model": config.get("quick_think_llm", "") } cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest() # Return cached instance if exists - if cache_key in SearchProviderFactory._cache: - return SearchProviderFactory._cache[cache_key] + if cache_key in self._cache: + return self._cache[cache_key] - # Create new instance - backend_url = config["backend_url"] - model = config["quick_think_llm"] - - if "generativelanguage.googleapis.com" in backend_url: - provider = GoogleSearchProvider(model) - else: - provider = OpenAISearchProvider(model, backend_url) + # Select and create provider + provider_type = self._selector.select_provider_type(config) + provider = self._registry.create(provider_type, config) # Cache and return - SearchProviderFactory._cache[cache_key] = provider + self._cache[cache_key] = provider return provider + def clear_cache(self): + """Clear the provider cache (useful for testing or config changes).""" + self._cache.clear() + + def get_available_provider_types(self) -> list[str]: + """Get list of available provider types.""" + return self._registry.get_available_types() + + +def create_search_provider_factory() -> SearchProviderFactoryImpl: + """Create a configured SearchProviderFactory with default providers.""" + registry = SearchProviderRegistry() + + # Register default providers + def create_google_provider(config: Dict[str, Any]) -> SearchProvider: + from .search_provider import GoogleSearchProvider + return GoogleSearchProvider(config["quick_think_llm"]) + + def create_openai_provider(config: Dict[str, Any]) -> SearchProvider: + from .search_provider import OpenAISearchProvider + return OpenAISearchProvider(config["quick_think_llm"], config["backend_url"]) + + registry.register("google", create_google_provider) + registry.register("openai", create_openai_provider) + + # Create URL pattern mappings (easily extensible) + url_mappings = { + "generativelanguage.googleapis.com": "google", + "api.openai.com": "openai", + } + + selector = MappingBasedProviderSelector(url_mappings, default_provider="openai") + return SearchProviderFactoryImpl(registry, selector) + + +# Backward compatibility - singleton instance +_default_factory = create_search_provider_factory() + + +class SearchProviderFactory: + """Backward compatibility wrapper for the old static factory.""" + + @staticmethod + def create_provider(config: Dict[str, Any]) -> SearchProvider: + return _default_factory.create_provider(config) + @staticmethod def clear_cache(): - """Clear the provider cache (useful for testing or config changes).""" - SearchProviderFactory._cache.clear() - + _default_factory.clear_cache() \ No newline at end of file diff --git a/backend/tradingagents/default_config.py b/backend/tradingagents/default_config.py index 089e9c24..3489dced 100644 --- a/backend/tradingagents/default_config.py +++ b/backend/tradingagents/default_config.py @@ -3,7 +3,7 @@ import os DEFAULT_CONFIG = { "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), - "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", + "data_dir": os.getenv("TRADINGAGENTS_DATA_DIR", "./data"), "data_cache_dir": os.path.join( os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "dataflows/data_cache", diff --git a/backend/tradingagents/graph/trading_graph.py b/backend/tradingagents/graph/trading_graph.py index 6795dfe0..ab7d50d0 100644 --- a/backend/tradingagents/graph/trading_graph.py +++ b/backend/tradingagents/graph/trading_graph.py @@ -127,7 +127,7 @@ class TradingAgentsGraph: # online tools self.toolkit.get_stock_news, # offline tools - self.toolkit.get_reddit_stock_info, + # self.toolkit.get_reddit_stock_info, ] ), "news": ToolNode( @@ -136,8 +136,8 @@ class TradingAgentsGraph: self.toolkit.get_global_news, self.toolkit.get_google_news, # offline tools - self.toolkit.get_finnhub_news, - self.toolkit.get_reddit_news, + # self.toolkit.get_finnhub_news, + # self.toolkit.get_reddit_news, ] ), "fundamentals": ToolNode( @@ -145,11 +145,11 @@ class TradingAgentsGraph: # online tools self.toolkit.get_fundamentals, # offline tools - self.toolkit.get_finnhub_company_insider_sentiment, - self.toolkit.get_finnhub_company_insider_transactions, - self.toolkit.get_simfin_balance_sheet, - self.toolkit.get_simfin_cashflow, - self.toolkit.get_simfin_income_stmt, + # self.toolkit.get_finnhub_company_insider_sentiment, + # self.toolkit.get_finnhub_company_insider_transactions, + # self.toolkit.get_simfin_balance_sheet, + # self.toolkit.get_simfin_cashflow, + # self.toolkit.get_simfin_income_stmt, ] ), }