Feat: 모델 생성/학습된 모델 생성 API 구현
This commit is contained in:
parent
aae8faf11e
commit
4e39dccd55
39
ai/app/api/yolo/model.py
Normal file
39
ai/app/api/yolo/model.py
Normal file
@ -0,0 +1,39 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from schemas.model_create_request import ModelCreateRequest
|
||||
from services.init_model import create_pretrained_model, create_default_model
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# modelType(detection/segmentation/classification), (default, pretrained), labelCategories
|
||||
@router.get("/info")
|
||||
def get_model_info(project_id:int, model_path:str):
|
||||
pass
|
||||
|
||||
#
|
||||
@router.get("/list")
|
||||
def get_model_list(project_id:int):
|
||||
pass
|
||||
|
||||
@router.post("/create", status_code=201)
|
||||
def model_create(request: ModelCreateRequest):
|
||||
if request.type not in ["seg", "det", "cls"]:
|
||||
raise HTTPException(status_code=400,
|
||||
detail= f"Invalid type '{request.type}'. Must be one of \"seg\", \"det\", \"cls\".")
|
||||
if request.pretrained:
|
||||
model_path = create_pretrained_model(request.project_id, request.type)
|
||||
else:
|
||||
labels = list(map(lambda x:x.label, request.labelCategories))
|
||||
model_path = create_default_model(request.project_id, request.type, labels)
|
||||
return {"model_path": model_path}
|
||||
|
||||
@router.delete("/delete", status_code=204)
|
||||
def model_delete():
|
||||
pass
|
||||
|
||||
@router.post("/upload")
|
||||
def model_upload():
|
||||
pass
|
||||
|
||||
@router.get("/download")
|
||||
def model_download():
|
||||
pass
|
@ -1,12 +1,14 @@
|
||||
from fastapi import FastAPI
|
||||
from api.yolo.detection import router as yolo_detection_router
|
||||
from api.yolo.segmentation import router as yolo_segmentation_router
|
||||
from api.yolo.model import router as yolo_model_router
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 각 기능별 라우터를 애플리케이션에 등록
|
||||
app.include_router(yolo_detection_router, prefix="/api/detection", tags=["Detection"])
|
||||
app.include_router(yolo_segmentation_router, prefix="/api/segmentation", tags=["Segmentation"])
|
||||
app.include_router(yolo_model_router, prefix="/api/model", tags=["Model"])
|
||||
|
||||
# 애플리케이션 실행
|
||||
if __name__ == "__main__":
|
||||
|
13
ai/app/schemas/model_create_request.py
Normal file
13
ai/app/schemas/model_create_request.py
Normal file
@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class LabelCategory(BaseModel):
|
||||
id:int
|
||||
label:str
|
||||
|
||||
|
||||
class ModelCreateRequest(BaseModel):
|
||||
project_id: int
|
||||
type: str
|
||||
pretrained:bool = True
|
||||
labelCategories:Optional[list[LabelCategory]] = None
|
28
ai/app/services/init_model.py
Normal file
28
ai/app/services/init_model.py
Normal file
@ -0,0 +1,28 @@
|
||||
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
|
||||
import os
|
||||
import uuid
|
||||
|
||||
def create_pretrained_model(project_id: id, type:str):
|
||||
suffix = ""
|
||||
if type in ["seg", "cls"]:
|
||||
suffix = "-"+type
|
||||
# 학습된 기본 모델 로드
|
||||
model = YOLO(os.path.join("resources", "models" ,f"yolov8n{suffix}.pt"))
|
||||
|
||||
# 모델을 저장할 폴더 경로
|
||||
base_path = os.path.join("resources","projects",str(project_id))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
|
||||
# 고유값 id 생성
|
||||
unique_id = uuid.uuid4()
|
||||
while os.path.exists(os.path.join(base_path, f"{unique_id}.pt")):
|
||||
unique_id = uuid.uuid4()
|
||||
model_path = os.path.join(base_path, f"{unique_id}.pt")
|
||||
|
||||
# 기본 모델 저장
|
||||
model.save(filename=model_path)
|
||||
|
||||
return model_path
|
||||
|
||||
def create_default_model(project_id: id, type:str, labels:list[str]):
|
||||
pass
|
Loading…
Reference in New Issue
Block a user