Feat: 모델 생성/학습된 모델 생성 API 구현

This commit is contained in:
김진현 2024-09-18 00:13:01 +09:00
parent aae8faf11e
commit 4e39dccd55
5 changed files with 82 additions and 0 deletions

39
ai/app/api/yolo/model.py Normal file
View File

@ -0,0 +1,39 @@
from fastapi import APIRouter, HTTPException
from schemas.model_create_request import ModelCreateRequest
from services.init_model import create_pretrained_model, create_default_model
router = APIRouter()
# modelType(detection/segmentation/classification), (default, pretrained), labelCategories
@router.get("/info")
def get_model_info(project_id:int, model_path:str):
pass
#
@router.get("/list")
def get_model_list(project_id:int):
pass
@router.post("/create", status_code=201)
def model_create(request: ModelCreateRequest):
if request.type not in ["seg", "det", "cls"]:
raise HTTPException(status_code=400,
detail= f"Invalid type '{request.type}'. Must be one of \"seg\", \"det\", \"cls\".")
if request.pretrained:
model_path = create_pretrained_model(request.project_id, request.type)
else:
labels = list(map(lambda x:x.label, request.labelCategories))
model_path = create_default_model(request.project_id, request.type, labels)
return {"model_path": model_path}
@router.delete("/delete", status_code=204)
def model_delete():
pass
@router.post("/upload")
def model_upload():
pass
@router.get("/download")
def model_download():
pass

View File

@ -1,12 +1,14 @@
from fastapi import FastAPI from fastapi import FastAPI
from api.yolo.detection import router as yolo_detection_router from api.yolo.detection import router as yolo_detection_router
from api.yolo.segmentation import router as yolo_segmentation_router from api.yolo.segmentation import router as yolo_segmentation_router
from api.yolo.model import router as yolo_model_router
app = FastAPI() app = FastAPI()
# 각 기능별 라우터를 애플리케이션에 등록 # 각 기능별 라우터를 애플리케이션에 등록
app.include_router(yolo_detection_router, prefix="/api/detection", tags=["Detection"]) app.include_router(yolo_detection_router, prefix="/api/detection", tags=["Detection"])
app.include_router(yolo_segmentation_router, prefix="/api/segmentation", tags=["Segmentation"]) app.include_router(yolo_segmentation_router, prefix="/api/segmentation", tags=["Segmentation"])
app.include_router(yolo_model_router, prefix="/api/model", tags=["Model"])
# 애플리케이션 실행 # 애플리케이션 실행
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -0,0 +1,13 @@
from pydantic import BaseModel
from typing import Optional
class LabelCategory(BaseModel):
id:int
label:str
class ModelCreateRequest(BaseModel):
project_id: int
type: str
pretrained:bool = True
labelCategories:Optional[list[LabelCategory]] = None

View File

@ -0,0 +1,28 @@
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
import os
import uuid
def create_pretrained_model(project_id: id, type:str):
suffix = ""
if type in ["seg", "cls"]:
suffix = "-"+type
# 학습된 기본 모델 로드
model = YOLO(os.path.join("resources", "models" ,f"yolov8n{suffix}.pt"))
# 모델을 저장할 폴더 경로
base_path = os.path.join("resources","projects",str(project_id))
os.makedirs(base_path, exist_ok=True)
# 고유값 id 생성
unique_id = uuid.uuid4()
while os.path.exists(os.path.join(base_path, f"{unique_id}.pt")):
unique_id = uuid.uuid4()
model_path = os.path.join(base_path, f"{unique_id}.pt")
# 기본 모델 저장
model.save(filename=model_path)
return model_path
def create_default_model(project_id: id, type:str, labels:list[str]):
pass