[add] remove app

This commit is contained in:
kimheesu 2025-07-08 14:27:34 +09:00
parent d35b62e999
commit c3e609730b
35 changed files with 145 additions and 843 deletions

3
.gitignore vendored
View File

@ -8,4 +8,5 @@ eval_data/
*.egg-info/
results/
.env
tradingagents/dataflows/data_cache/
tradingagents/dataflows/data_cache/
CLAUDE.md

View File

View File

View File

@ -1,67 +0,0 @@
from typing import Generator, Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import jwt, JWTError
from pydantic import BaseModel
from sqlmodel import Session
from app.core.config import settings
from app.infrastructure.database import get_db
from app.domain.models import User
from app.infrastructure.repositories.user import UserRepository
from app.core.services.trading_analysis import TradingAnalysisService
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
class TokenData(BaseModel):
    """Decoded JWT payload carried between token helpers."""
    # NOTE(review): despite the field name, callers fill this from the JWT
    # "sub" claim, which login sets to the user's *email* — confirm naming.
    username: Optional[str] = None
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
    """FastAPI dependency: a UserRepository bound to the request's DB session."""
    return UserRepository(db)
def get_user_from_token(token: str, db: Session) -> Optional[User]:
    """Decode a JWT and resolve it to a User, or return None.

    Returns None for any invalid token: bad signature, expired token,
    or a payload that carries no "sub" claim.
    """
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        token_data = TokenData(username=payload.get("sub"))
    except JWTError:
        return None
    # Fix: a token without a "sub" claim previously fell through and queried
    # the repository with email=None; reject it explicitly instead.
    if token_data.username is None:
        return None
    user_repo = UserRepository(db)
    return user_repo.get_by_email(email=token_data.username)
def get_current_user(
    db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
) -> User:
    """FastAPI dependency: resolve the bearer token to a User or raise 401."""
    user = get_user_from_token(token=token, db=db)
    if not user:
        # 401 with WWW-Authenticate so OAuth2-aware clients re-authenticate.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user
def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Dependency: require the authenticated user to be active (else 400)."""
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
def get_current_active_superuser(
    current_user: User = Depends(get_current_active_user),
) -> User:
    """Dependency: require the active, authenticated user to be a superuser."""
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=403, detail="The user doesn't have enough privileges"
    )
def get_analysis_service(
    db: Session = Depends(get_db),
    user: User = Depends(get_current_active_user)
) -> TradingAnalysisService:
    """Dependency: a TradingAnalysisService scoped to the authenticated user."""
    return TradingAnalysisService(user=user, db=db)

View File

@ -1,94 +0,0 @@
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
from app.api import deps
from app.core.schemas.analysis import AnalysisSession, AnalysisSessionCreate
from app.domain.models import User as UserModel
from app.core.services.trading_analysis import TradingAnalysisService
from app.core.websocket_manager import WebSocketManager
from sqlmodel import Session
from cli.utils import SHALLOW_AGENT_OPTIONS, DEEP_AGENT_OPTIONS, BASE_URLS
router = APIRouter()
manager = WebSocketManager()
@router.post("/start", response_model=AnalysisSession)
def start_analysis(
*,
analysis_in: AnalysisSessionCreate,
background_tasks: BackgroundTasks,
service: TradingAnalysisService = Depends(deps.get_analysis_service),
) -> Any:
"""
Start a new analysis session.
"""
session = service.create_session(analysis_in=analysis_in)
background_tasks.add_task(service.run_analysis, session_id=session.id)
return session
@router.get("/history", response_model=List[AnalysisSession])
def get_analysis_history(
service: TradingAnalysisService = Depends(deps.get_analysis_service),
skip: int = 0,
limit: int = 100,
) -> Any:
"""
Get analysis history for the current user.
"""
return service.get_user_sessions(skip=skip, limit=limit)
@router.get("/options")
def get_analysis_options():
"""
Get available options for analysis.
"""
return {
'analysts': [
{'value': 'market', 'label': 'Market Analyst'},
{'value': 'social', 'label': 'Social Analyst'},
{'value': 'news', 'label': 'News Analyst'},
{'value': 'fundamentals', 'label': 'Fundamentals Analyst'},
],
'research_depths': [
{'value': 1, 'label': 'Shallow'},
{'value': 3, 'label': 'Medium'},
{'value': 5, 'label': 'Deep'},
],
'llm_providers': [{'name': p[0], 'url': p[1]} for p in BASE_URLS],
'shallow_thinkers': SHALLOW_AGENT_OPTIONS,
'deep_thinkers': DEEP_AGENT_OPTIONS,
}
@router.get("/{session_id}", response_model=AnalysisSession)
def get_analysis_session(
session_id: int,
service: TradingAnalysisService = Depends(deps.get_analysis_service),
) -> Any:
"""
Get a specific analysis session by ID.
"""
session = service.get_session(session_id=session_id)
if not session:
raise HTTPException(status_code=404, detail="Analysis session not found")
return session
@router.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
token: str,
db: Session = Depends(deps.get_db)
):
"""
WebSocket endpoint for real-time analysis updates.
"""
user = deps.get_user_from_token(token=token, db=db)
if not user or not user.is_active:
await websocket.close(code=1008)
return
await manager.connect(user.id, websocket)
try:
while True:
# Keep the connection alive
await websocket.receive_text()
except WebSocketDisconnect:
manager.disconnect(user.id, websocket)

View File

@ -1,35 +0,0 @@
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
from app.api import deps
from app.core.config import settings
from app.core.schemas.token import Token
from app.core import security
from app.infrastructure.repositories.user import UserRepository
router = APIRouter()
@router.post("/login/access-token", response_model=Token)
def login_access_token(
db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()
):
"""
OAuth2 compatible token login, get an access token for future requests
"""
user_repo = UserRepository(db)
user = user_repo.get_by_email(email=form_data.username)
if not user or not security.verify_password(form_data.password, user.hashed_password):
raise HTTPException(status_code=400, detail="Incorrect email or password")
elif not user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
return {
"access_token": security.create_access_token(
user.email, expires_delta=access_token_expires
),
"token_type": "bearer",
}

View File

@ -1,89 +0,0 @@
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from app.api import deps
from app.core.schemas.user import User, UserCreate, UserUpdate
from app.domain.models import User as UserModel
from app.domain.repositories import IUserRepository
router = APIRouter()
@router.get("/", response_model=List[User])
def read_users(
repo: IUserRepository = Depends(deps.get_user_repository),
skip: int = 0,
limit: int = 100,
current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any:
"""
Retrieve users.
"""
users = repo.get_multi(skip=skip, limit=limit)
return users
@router.post("/", response_model=User)
def create_user(
*,
repo: IUserRepository = Depends(deps.get_user_repository),
user_in: UserCreate,
current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any:
"""
Create new user.
"""
user = repo.get_by_email(email=user_in.email)
if user:
raise HTTPException(
status_code=400,
detail="The user with this username already exists in the system.",
)
user = repo.create(obj_in=user_in)
return user
@router.get("/me", response_model=User)
def read_user_me(
current_user: UserModel = Depends(deps.get_current_active_user),
) -> Any:
"""
Get current user.
"""
return current_user
@router.get("/{user_id}", response_model=User)
def read_user_by_id(
user_id: int,
repo: IUserRepository = Depends(deps.get_user_repository),
current_user: UserModel = Depends(deps.get_current_active_user),
) -> Any:
"""
Get a specific user by id.
"""
user = repo.get(id=user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
if user == current_user:
return user
if not repo.is_superuser(user=current_user):
raise HTTPException(
status_code=403, detail="The user doesn't have enough privileges"
)
return user
@router.put("/{user_id}", response_model=User)
def update_user(
*,
repo: IUserRepository = Depends(deps.get_user_repository),
user_id: int,
user_in: UserUpdate,
current_user: UserModel = Depends(deps.get_current_active_superuser),
) -> Any:
"""
Update a user.
"""
user = repo.get(id=user_id)
if not user:
raise HTTPException(
status_code=404,
detail="The user with this username does not exist in the system",
)
user = repo.update(db_obj=user, obj_in=user_in)
return user

View File

@ -1,7 +0,0 @@
from fastapi import APIRouter
from app.api.endpoints import login, users, analysis
api_router = APIRouter()
api_router.include_router(login.router, tags=["login"])
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"])

View File

View File

@ -1,26 +0,0 @@
import os
from pydantic import BaseSettings
from typing import List, Optional
class Settings(BaseSettings):
    """Application settings, populated from environment variables at import."""

    PROJECT_NAME: str = "TradingAgents Backend"
    API_V1_STR: str = "/api/v1"
    # Security
    # NOTE(review): the fallback SECRET_KEY is unsafe outside development;
    # ensure production environments always provide a real key.
    SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8  # 8 days
    ALGORITHM: str = "HS256"
    # Database
    DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./tradingagents.db")
    # OpenAI
    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
    # CORS
    # Fix: strip whitespace around each origin and drop empty entries so a
    # value like "http://a.com, http://b.com," parses into clean origins.
    CORS_ALLOWED_ORIGINS: List[str] = [
        origin.strip()
        for origin in os.getenv(
            "CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000"
        ).split(',')
        if origin.strip()
    ]

    class Config:
        case_sensitive = True

# Module-level singleton imported throughout the app.
settings = Settings()

View File

@ -1,4 +0,0 @@
from .user import User, UserCreate, UserUpdate
from .token import Token, TokenPayload
from .profile import Profile, ProfileCreate, ProfileUpdate
from .analysis import AnalysisSession, AnalysisSessionCreate, AnalysisSessionUpdate

View File

@ -1,38 +0,0 @@
from pydantic import BaseModel
from typing import List, Optional
from datetime import date, datetime
from app.domain.models import AnalysisStatus
class AnalysisSessionBase(BaseModel):
    """Fields shared by create requests and stored sessions."""
    ticker: str
    analysts_selected: List[str]
    research_depth: int
    llm_provider: str
    backend_url: str
    shallow_thinker: str
    deep_thinker: str

class AnalysisSessionCreate(AnalysisSessionBase):
    """Request body for starting an analysis (all base fields required)."""
    pass

class AnalysisSessionUpdate(BaseModel):
    """Partial update: only lifecycle/result fields may change after creation."""
    status: Optional[AnalysisStatus] = None
    final_report: Optional[str] = None
    error_message: Optional[str] = None

class AnalysisSessionInDBBase(AnalysisSessionBase):
    """Fields present once a session row exists in the database."""
    id: int
    user_id: int
    analysis_date: date
    status: AnalysisStatus
    final_report: Optional[str] = None
    error_message: Optional[str] = None
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    class Config:
        # Allow building from ORM objects (pydantic v1 name).
        orm_mode = True

class AnalysisSession(AnalysisSessionInDBBase):
    """API response shape for an analysis session."""
    pass

View File

@ -1,20 +0,0 @@
from pydantic import BaseModel
from typing import Optional
class ProfileBase(BaseModel):
    """User preferences with defaults mirroring the UserProfile DB model."""
    default_ticker: str = "SPY"
    preferred_research_depth: int = 3
    preferred_shallow_thinker: str = "gpt-4o-mini"
    preferred_deep_thinker: str = "gpt-4o"

class ProfileCreate(ProfileBase):
    """Request body for creating a profile."""
    pass

class ProfileUpdate(ProfileBase):
    """Update body; the API key is write-only (never echoed back)."""
    openai_api_key: Optional[str] = None

class Profile(ProfileBase):
    """Response shape; exposes only *whether* a key is stored, not the key."""
    has_openai_api_key: bool

    class Config:
        orm_mode = True

View File

@ -1,9 +0,0 @@
from pydantic import BaseModel
from typing import Optional
class Token(BaseModel):
    """Response body for the login endpoint."""
    access_token: str
    token_type: str

class TokenPayload(BaseModel):
    """Decoded JWT claims."""
    # NOTE(review): login encodes "sub" as the user's email (a string), but
    # this declares int — confirm which is intended.
    sub: Optional[int] = None

View File

@ -1,28 +0,0 @@
from pydantic import BaseModel, EmailStr
from typing import Optional
class UserBase(BaseModel):
    """Fields clients may supply for a user."""
    email: EmailStr
    username: str
    first_name: Optional[str] = None
    last_name: Optional[str] = None

class UserCreate(UserBase):
    """Creation body; the plaintext password is hashed by the repository."""
    password: str

class UserUpdate(UserBase):
    """Update body (no password change via this schema)."""
    pass

class UserInDBBase(UserBase):
    """Fields present once a user row exists."""
    id: int
    is_active: bool
    is_superuser: bool

    class Config:
        orm_mode = True

class User(UserInDBBase):
    """API response shape — never includes the password hash."""
    pass

class UserInDB(UserInDBBase):
    """Internal shape including the stored hash; never returned by the API."""
    hashed_password: str

View File

@ -1,23 +0,0 @@
from datetime import datetime, timedelta
from typing import Any, Union, Optional
from jose import jwt
from passlib.context import CryptContext
from app.core.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT whose "sub" claim is str(subject).

    Expiry is now + expires_delta, falling back to the configured
    ACCESS_TOKEN_EXPIRE_MINUTES.
    """
    # Local import: the module only imports datetime/timedelta at the top.
    from datetime import timezone

    # Fix: use timezone-aware UTC — datetime.utcnow() is naive and deprecated;
    # jose encodes an aware "exp" identically, so tokens are unchanged.
    now = datetime.now(timezone.utc)
    if expires_delta:
        expire = now + expires_delta
    else:
        expire = now + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = {"exp": expire, "sub": str(subject)}
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
    return encoded_jwt
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Constant-time check of a plaintext password against its bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the configured bcrypt context."""
    return pwd_context.hash(password)

View File

@ -1,128 +0,0 @@
import asyncio
import datetime
import json
from typing import Dict, List, Optional
from sqlmodel import Session, select
from app.domain.models import User, AnalysisSession, AnalysisStatus
from app.core.schemas.analysis import AnalysisSessionCreate
from app.core.config import settings
from cli.models import AnalystType
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from app.api.deps import get_db
from app.core.websocket_manager import WebSocketManager
class TradingAnalysisService:
    """Runs trading analyses for one user and persists their lifecycle.

    Bound to a single User and DB session (built per-request by
    deps.get_analysis_service); progress and results are pushed to the
    user over a WebSocketManager.
    """

    def __init__(self, user: User, db: Session):
        self.user = user
        self.db = db
        # NOTE(review): this creates a NEW manager per service instance, while
        # the /ws endpoint registers sockets on its own module-level manager —
        # confirm these are meant to share connection state (they do not here).
        self.websocket_manager = WebSocketManager()

    async def run_analysis(self, session_id: int):
        """Execute the analysis pipeline for *session_id*.

        Marks the session RUNNING, runs TradingAgentsGraph.propagate in a
        worker thread, stores the final state as JSON, and notifies the
        user of start / completion / failure via websocket.
        """
        session = self.get_session(session_id=session_id)
        if not session:
            # Unknown id or owned by another user — nothing to do.
            return
        try:
            session.status = AnalysisStatus.RUNNING
            session.started_at = datetime.datetime.utcnow()
            self.db.add(session)
            self.db.commit()
            self.db.refresh(session)
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_started',
                    'session_id': session.id,
                    'message': '분석을 시작합니다...'
                }
            )
            # Prepare config for TradingAgentsGraph
            config = DEFAULT_CONFIG.copy()
            config.update({
                'openai_api_key': settings.OPENAI_API_KEY,
                'llm_provider': session.llm_provider,
                'backend_url': session.backend_url,
                'shallow_thinking_model': session.shallow_thinker,
                'deep_thinking_model': session.deep_thinker,
            })
            # Progress callback for websocket.
            # NOTE(review): defined but never passed to TradingAgentsGraph
            # below, so intermediate progress is never emitted — confirm intent.
            async def progress_callback(message_type: str, content: str, agent: str = None, step: int = 0, total: int = 0):
                # Cap at 99% so only completion reports 100%.
                progress_percent = int((step / total) * 99) if total > 0 else 0
                await self.websocket_manager.send_to_user(self.user.id, {
                    'type': 'analysis_progress',
                    'session_id': session.id,
                    'message_type': message_type,
                    'content': content,
                    'agent': agent,
                    'progress': progress_percent,
                })
            trading_graph = TradingAgentsGraph(
                config=config,
                selected_analysts=session.analysts_selected,
            )
            input_data = {
                'company_of_interest': session.ticker,
                'trade_date': session.analysis_date.strftime('%Y-%m-%d'),
            }
            # propagate() is synchronous and slow; run it off the event loop.
            final_state, result = await asyncio.to_thread(
                trading_graph.propagate,
                input_data['company_of_interest'],
                input_data['trade_date']
            )
            session.status = AnalysisStatus.COMPLETED
            session.completed_at = datetime.datetime.utcnow()
            session.final_report = json.dumps(final_state)  # Store full state as JSON
            self.db.add(session)
            self.db.commit()
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_completed',
                    'session_id': session.id,
                    'message': '분석이 완료되었습니다.',
                    'result': result
                }
            )
        except Exception as e:
            # Persist the failure so the UI can display it, then notify.
            session.status = AnalysisStatus.FAILED
            session.error_message = str(e)
            self.db.add(session)
            self.db.commit()
            await self.websocket_manager.send_to_user(
                self.user.id,
                {
                    'type': 'analysis_failed',
                    'session_id': session.id,
                    'message': f'분석 중 오류가 발생했습니다: {str(e)}'
                }
            )

    def create_session(self, *, analysis_in: AnalysisSessionCreate) -> AnalysisSession:
        """Persist a new PENDING session, dated today, owned by the current user."""
        session = AnalysisSession(
            **analysis_in.dict(),
            user_id=self.user.id,
            analysis_date=datetime.date.today()
        )
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def get_session(self, *, session_id: int) -> Optional[AnalysisSession]:
        """Fetch a session by id, scoped to the current user (None if absent)."""
        statement = select(AnalysisSession).where(AnalysisSession.id == session_id, AnalysisSession.user_id == self.user.id)
        return self.db.exec(statement).first()

    def get_user_sessions(self, *, skip: int = 0, limit: int = 100) -> List[AnalysisSession]:
        """List the current user's sessions, newest first, paginated."""
        statement = select(AnalysisSession).where(AnalysisSession.user_id == self.user.id).order_by(AnalysisSession.created_at.desc()).offset(skip).limit(limit)
        return self.db.exec(statement).all()

View File

@ -1,23 +0,0 @@
from typing import Dict, List
from fastapi import WebSocket
class WebSocketManager:
    """Tracks open WebSocket connections per user id and pushes JSON events."""

    def __init__(self):
        # user_id -> open sockets for that user (multiple tabs/devices).
        self.active_connections: Dict[int, List[WebSocket]] = {}

    async def connect(self, user_id: int, websocket: WebSocket):
        """Accept the socket and register it under *user_id*."""
        await websocket.accept()
        if user_id not in self.active_connections:
            self.active_connections[user_id] = []
        self.active_connections[user_id].append(websocket)

    def disconnect(self, user_id: int, websocket: WebSocket):
        """Unregister a socket; drop the user entry once its list is empty."""
        connections = self.active_connections.get(user_id)
        if connections is None:
            return
        # Fix: guard the removal — a duplicate disconnect (e.g. the endpoint's
        # cleanup racing another path) made list.remove raise ValueError.
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            del self.active_connections[user_id]

    async def send_to_user(self, user_id: int, message: dict):
        """Send a JSON message to every open socket of the user."""
        # Fix: iterate a snapshot — a send may trigger disconnect(), which
        # mutates the underlying list mid-iteration.
        for connection in list(self.active_connections.get(user_id, [])):
            await connection.send_json(message)

View File

View File

@ -1,56 +0,0 @@
from datetime import date, datetime
from typing import List, Optional
from sqlmodel import Field, SQLModel, JSON, Column
import enum
class User(SQLModel, table=True):
    """Account record; passwords are stored only as bcrypt hashes."""
    id: Optional[int] = Field(default=None, primary_key=True)
    email: str = Field(unique=True, index=True)
    username: str = Field(unique=True, index=True)
    hashed_password: str
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    is_active: bool = Field(default=True)
    is_superuser: bool = Field(default=False)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    # onupdate refreshes the timestamp at the DB layer on every UPDATE.
    updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})

class UserProfile(SQLModel, table=True):
    """Per-user preferences; one row per user (unique FK)."""
    id: Optional[int] = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="user.id", unique=True)
    # Stored encrypted; never expose the raw key through the API.
    encrypted_openai_api_key: Optional[str] = None
    default_ticker: str = Field(default="SPY")
    preferred_research_depth: int = Field(default=3)
    preferred_shallow_thinker: str = Field(default="gpt-4o-mini")
    preferred_deep_thinker: str = Field(default="gpt-4o")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column_kwargs={"onupdate": datetime.utcnow})

class AnalysisStatus(str, enum.Enum):
    """Lifecycle states of an analysis session (persisted as plain strings)."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

class AnalysisSession(SQLModel, table=True):
    """One analysis run: requested parameters plus lifecycle and result."""
    id: Optional[int] = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="user.id")
    ticker: str
    analysis_date: date
    # Stored as a JSON column since SQL has no native list type here.
    analysts_selected: List[str] = Field(sa_column=Column(JSON))
    research_depth: int
    llm_provider: str
    backend_url: str
    shallow_thinker: str
    deep_thinker: str
    status: AnalysisStatus = Field(default=AnalysisStatus.PENDING)
    # Serialized final state (JSON string) on success.
    final_report: Optional[str] = None
    error_message: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

View File

@ -1,48 +0,0 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Optional, List
from sqlmodel import SQLModel
from app.core.schemas.user import UserCreate, UserUpdate
from app.domain.models import User
ModelType = TypeVar("ModelType", bound=SQLModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=SQLModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=SQLModel)
class IRepository(Generic[ModelType], ABC):
    """Generic CRUD contract implemented by concrete repositories."""

    @abstractmethod
    def get(self, id: int) -> Optional[ModelType]:
        """Fetch one row by primary key, or None."""
        pass

    @abstractmethod
    def get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
        """Fetch a page of rows."""
        pass

    @abstractmethod
    def create(self, *, obj_in: CreateSchemaType) -> ModelType:
        """Insert a row from a creation schema and return it."""
        pass

    @abstractmethod
    def update(self, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
        """Apply schema fields to an existing row and return it."""
        pass

    @abstractmethod
    def remove(self, *, id: int) -> ModelType:
        """Delete a row by primary key and return the deleted object."""
        pass

class IUserRepository(IRepository[User], ABC):
    """User-specific repository contract with typed create/update."""

    @abstractmethod
    def get_by_email(self, *, email: str) -> Optional[User]:
        """Fetch a user by email, or None."""
        pass

    @abstractmethod
    def create(self, *, obj_in: UserCreate) -> User:
        pass

    @abstractmethod
    def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
        pass

    @abstractmethod
    def is_superuser(self, *, user: User) -> bool:
        """True when the user has superuser privileges."""
        pass

View File

@ -1,9 +0,0 @@
from sqlmodel import create_engine, Session
from app.core.config import settings
# check_same_thread=False lets FastAPI touch the same SQLite connection from
# different threads. Fix: the flag is SQLite-specific — other drivers (e.g.
# PostgreSQL) reject the unknown connect argument, so pass it conditionally.
# NOTE(review): echo=True logs every SQL statement; confirm it is wanted
# outside development.
_connect_args = (
    {"check_same_thread": False}
    if settings.DATABASE_URL.startswith("sqlite")
    else {}
)
engine = create_engine(settings.DATABASE_URL, echo=True, connect_args=_connect_args)

def get_db():
    """FastAPI dependency: yield a DB session closed automatically per request."""
    with Session(engine) as session:
        yield session

View File

@ -1,53 +0,0 @@
from typing import Optional
from sqlmodel import Session, select
from app.domain.models import User
from app.core.schemas.user import UserCreate, UserUpdate
from app.domain.repositories import IUserRepository
from app.core.security import get_password_hash
class UserRepository(IUserRepository):
    """SQLModel-backed implementation of IUserRepository."""

    def __init__(self, db: Session):
        self.db = db

    def get(self, id: int) -> Optional[User]:
        """Fetch a user by primary key, or None."""
        return self.db.get(User, id)

    def get_by_email(self, *, email: str) -> Optional[User]:
        """Fetch a user by exact email match, or None."""
        statement = select(User).where(User.email == email)
        return self.db.exec(statement).first()

    def get_multi(self, *, skip: int = 0, limit: int = 100) -> list[User]:
        """Fetch a page of users."""
        statement = select(User).offset(skip).limit(limit)
        return self.db.exec(statement).all()

    def create(self, *, obj_in: UserCreate) -> User:
        """Insert a new user; the plaintext password is hashed here."""
        db_obj = User(
            email=obj_in.email,
            username=obj_in.username,
            hashed_password=get_password_hash(obj_in.password),
            first_name=obj_in.first_name,
            last_name=obj_in.last_name,
        )
        self.db.add(db_obj)
        self.db.commit()
        self.db.refresh(db_obj)
        return db_obj

    def update(self, *, db_obj: User, obj_in: UserUpdate) -> User:
        """Apply only the fields the client actually sent (exclude_unset)."""
        update_data = obj_in.dict(exclude_unset=True)
        for field, value in update_data.items():
            setattr(db_obj, field, value)
        self.db.add(db_obj)
        self.db.commit()
        self.db.refresh(db_obj)
        return db_obj

    def remove(self, *, id: int) -> User:
        """Delete a user by primary key and return the deleted object."""
        # NOTE(review): an unknown id makes db.get return None and delete(None)
        # fail downstream — confirm callers always pass an existing id.
        db_obj = self.db.get(User, id)
        self.db.delete(db_obj)
        self.db.commit()
        return db_obj

    def is_superuser(self, *, user: User) -> bool:
        """True when the user has superuser privileges."""
        return user.is_superuser

View File

@ -1,36 +0,0 @@
import sys
import os
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
# Add project root to path to allow importing tradingagents
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
from app.api.router import api_router
from app.core.config import settings
from app.infrastructure.database import engine
from sqlmodel import SQLModel
def create_tables():
    """Create all SQLModel tables (no-op for tables that already exist)."""
    SQLModel.metadata.create_all(engine)

app = FastAPI(
    title=settings.PROJECT_NAME,
    openapi_url=f"{settings.API_V1_STR}/openapi.json"
)

@app.on_event("startup")
def on_startup():
    # Schema is created lazily at startup rather than via migrations.
    create_tables()

# Set all CORS enabled origins
if settings.CORS_ALLOWED_ORIGINS:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[str(origin) for origin in settings.CORS_ALLOWED_ORIGINS],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# All API routes live under the versioned prefix (e.g. /api/v1).
app.include_router(api_router, prefix=settings.API_V1_STR)

View File

@ -1,13 +1,13 @@
from pydantic import BaseModel, field_validator
from datetime import datetime, date
from typing import List, Dict, Union
from typing import Dict
from analysis.infra.db_models.analysis import AnalysisStatus
class Analysis(BaseModel):
id: str | None = None
member_id: str | None = None
ticker: str | None = None
analysis_date: date | None = None
member_id: str
ticker: str
analysis_date: date
analysts_selected: list[str] = []
research_depth: int = 1
llm_provider: str = "google"
@ -34,5 +34,5 @@ class Analysis(BaseModel):
# 실행 결과 정보
error_message: str | None = None
completed_at: datetime | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
created_at: datetime
updated_at: datetime

View File

@ -1,6 +1,6 @@
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import validator, Field
from pydantic import field_validator, Field
import secrets
import os
@ -42,21 +42,24 @@ class Settings(BaseSettings):
ENVIRONMENT: str = Field(default="development", description="Environment (development/staging/production)")
DEBUG: bool = Field(default=True, description="Debug mode")
@validator('ENVIRONMENT')
@field_validator('ENVIRONMENT')
@classmethod
def validate_environment(cls, v):
allowed_envs = ['development', 'staging', 'production']
if v not in allowed_envs:
raise ValueError(f'Environment must be one of {allowed_envs}')
return v
@validator('LOG_LEVEL')
@field_validator('LOG_LEVEL')
@classmethod
def validate_log_level(cls, v):
allowed_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
if v.upper() not in allowed_levels:
raise ValueError(f'Log level must be one of {allowed_levels}')
return v.upper()
@validator('SECRET_KEY')
@field_validator('SECRET_KEY')
@classmethod
def validate_secret_key(cls, v):
if len(v) < 32:
raise ValueError('SECRET_KEY must be at least 32 characters long')

View File

@ -7,6 +7,7 @@ from utils.containers import Container
from utils.middlewares import RateLimitMiddleware, LoggingMiddleware, SecurityHeadersMiddleware
from utils.exceptions import BaseAPIException
from contextlib import asynccontextmanager
from datetime import datetime
from analysis.interface.controller.analysis_controller import router as analysis_router
from member.interface.controller.member_controller import router as member_router
@ -114,7 +115,7 @@ async def health_check():
return {
"status": "healthy",
"environment": settings.ENVIRONMENT,
"timestamp": "2024-01-01T00:00:00Z"
"timestamp": datetime.utcnow().isoformat() + "Z"
}
@app.get("/")

View File

@ -14,7 +14,7 @@ from tqdm import tqdm
import yfinance as yf
from openai import OpenAI
from .config import get_config, set_config, DATA_DIR
from .search_provider_factory import SearchProviderFactory
from .search_provider_factory import SearchProviderFactory, create_search_provider_factory
def parse_date_range(curr_date: str, look_back_days: int) -> Tuple[str, str]:
@ -709,9 +709,13 @@ def get_YFin_data(
return filtered_data
# Enhanced search provider factory instance (singleton)
_search_factory = create_search_provider_factory()
def get_stock_news(ticker, curr_date):
config = get_config()
search_provider = SearchProviderFactory.create_provider(config)
search_provider = _search_factory.create_provider(config)
query = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period."
return search_provider.search(query)
@ -719,7 +723,7 @@ def get_stock_news(ticker, curr_date):
def get_global_news(curr_date):
config = get_config()
search_provider = SearchProviderFactory.create_provider(config)
search_provider = _search_factory.create_provider(config)
query = f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions."
return search_provider.search(query)
@ -727,7 +731,7 @@ def get_global_news(curr_date):
def get_fundamentals(ticker, curr_date):
config = get_config()
search_provider = SearchProviderFactory.create_provider(config)
search_provider = _search_factory.create_provider(config)
query = f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format."
return search_provider.search(query)

View File

@ -1,47 +1,133 @@
from .search_provider import (
SearchProvider,
GoogleSearchProvider,
OpenAISearchProvider
)
from .search_provider import SearchProvider
import hashlib
import json
from typing import Dict, Callable, Any
from abc import ABC, abstractmethod
class SearchProviderFactory:
_cache = {} # 클래스 레벨 캐시
class ProviderSelector(ABC):
"""Abstract base class for provider selection strategies."""
@staticmethod
def create_provider(config: dict[str, any]) -> SearchProvider:
@abstractmethod
def select_provider_type(self, config: Dict[str, Any]) -> str:
"""Select provider type based on configuration."""
pass
class MappingBasedProviderSelector(ProviderSelector):
"""Selects provider based on URL pattern mapping table."""
def __init__(self, mappings: Dict[str, str], default_provider: str = "openai"):
self._mappings = mappings
self._default_provider = default_provider
def select_provider_type(self, config: Dict[str, Any]) -> str:
backend_url = config.get("backend_url", "")
for pattern, provider_type in self._mappings.items():
if pattern in backend_url:
return provider_type
return self._default_provider
class SearchProviderRegistry:
"""Registry for search provider creation functions."""
def __init__(self):
self._providers: Dict[str, Callable[[Dict[str, Any]], SearchProvider]] = {}
def register(self, provider_type: str, creator: Callable[[Dict[str, Any]], SearchProvider]):
"""Register a provider creator function."""
self._providers[provider_type] = creator
def create(self, provider_type: str, config: Dict[str, Any]) -> SearchProvider:
"""Create a provider instance using registered creator."""
if provider_type not in self._providers:
raise ValueError(f"Unknown provider type: {provider_type}")
return self._providers[provider_type](config)
def get_available_types(self) -> list[str]:
"""Get list of available provider types."""
return list(self._providers.keys())
class SearchProviderFactoryImpl:
"""Enhanced factory for creating SearchProvider instances with caching and extensibility."""
def __init__(self, registry: SearchProviderRegistry, selector: ProviderSelector):
self._registry = registry
self._selector = selector
self._cache: Dict[str, SearchProvider] = {}
def create_provider(self, config: Dict[str, Any]) -> SearchProvider:
"""
Create a SearchProvider with caching to avoid creating new instances.
Uses config hash as cache key for efficient reuse.
"""
# Create cache key from relevant config values
cache_key_data = {
"backend_url": config["backend_url"],
"model": config["quick_think_llm"]
"backend_url": config.get("backend_url", ""),
"model": config.get("quick_think_llm", "")
}
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
# Return cached instance if exists
if cache_key in SearchProviderFactory._cache:
return SearchProviderFactory._cache[cache_key]
if cache_key in self._cache:
return self._cache[cache_key]
# Create new instance
backend_url = config["backend_url"]
model = config["quick_think_llm"]
if "generativelanguage.googleapis.com" in backend_url:
provider = GoogleSearchProvider(model)
else:
provider = OpenAISearchProvider(model, backend_url)
# Select and create provider
provider_type = self._selector.select_provider_type(config)
provider = self._registry.create(provider_type, config)
# Cache and return
SearchProviderFactory._cache[cache_key] = provider
self._cache[cache_key] = provider
return provider
def clear_cache(self):
"""Clear the provider cache (useful for testing or config changes)."""
self._cache.clear()
def get_available_provider_types(self) -> list[str]:
"""Get list of available provider types."""
return self._registry.get_available_types()
def create_search_provider_factory() -> SearchProviderFactoryImpl:
"""Create a configured SearchProviderFactory with default providers."""
registry = SearchProviderRegistry()
# Register default providers
def create_google_provider(config: Dict[str, Any]) -> SearchProvider:
from .search_provider import GoogleSearchProvider
return GoogleSearchProvider(config["quick_think_llm"])
def create_openai_provider(config: Dict[str, Any]) -> SearchProvider:
from .search_provider import OpenAISearchProvider
return OpenAISearchProvider(config["quick_think_llm"], config["backend_url"])
registry.register("google", create_google_provider)
registry.register("openai", create_openai_provider)
# Create URL pattern mappings (easily extensible)
url_mappings = {
"generativelanguage.googleapis.com": "google",
"api.openai.com": "openai",
}
selector = MappingBasedProviderSelector(url_mappings, default_provider="openai")
return SearchProviderFactoryImpl(registry, selector)
# Backward compatibility - singleton instance
_default_factory = create_search_provider_factory()
class SearchProviderFactory:
"""Backward compatibility wrapper for the old static factory."""
@staticmethod
def create_provider(config: Dict[str, Any]) -> SearchProvider:
return _default_factory.create_provider(config)
@staticmethod
def clear_cache():
"""Clear the provider cache (useful for testing or config changes)."""
SearchProviderFactory._cache.clear()
_default_factory.clear_cache()

View File

@ -3,7 +3,7 @@ import os
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
"data_dir": os.getenv("TRADINGAGENTS_DATA_DIR", "./data"),
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",

View File

@ -127,7 +127,7 @@ class TradingAgentsGraph:
# online tools
self.toolkit.get_stock_news,
# offline tools
self.toolkit.get_reddit_stock_info,
# self.toolkit.get_reddit_stock_info,
]
),
"news": ToolNode(
@ -136,8 +136,8 @@ class TradingAgentsGraph:
self.toolkit.get_global_news,
self.toolkit.get_google_news,
# offline tools
self.toolkit.get_finnhub_news,
self.toolkit.get_reddit_news,
# self.toolkit.get_finnhub_news,
# self.toolkit.get_reddit_news,
]
),
"fundamentals": ToolNode(
@ -145,11 +145,11 @@ class TradingAgentsGraph:
# online tools
self.toolkit.get_fundamentals,
# offline tools
self.toolkit.get_finnhub_company_insider_sentiment,
self.toolkit.get_finnhub_company_insider_transactions,
self.toolkit.get_simfin_balance_sheet,
self.toolkit.get_simfin_cashflow,
self.toolkit.get_simfin_income_stmt,
# self.toolkit.get_finnhub_company_insider_sentiment,
# self.toolkit.get_finnhub_company_insider_transactions,
# self.toolkit.get_simfin_balance_sheet,
# self.toolkit.get_simfin_cashflow,
# self.toolkit.get_simfin_income_stmt,
]
),
}