Feat: 모델 생성/학습된 모델 생성 API 구현
This commit is contained in:
parent
aae8faf11e
commit
4e39dccd55
39
ai/app/api/yolo/model.py
Normal file
39
ai/app/api/yolo/model.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from schemas.model_create_request import ModelCreateRequest
|
||||||
|
from services.init_model import create_pretrained_model, create_default_model
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# modelType(detection/segmentation/classification), (default, pretrained), labelCategories
|
||||||
|
@router.get("/info")
|
||||||
|
def get_model_info(project_id:int, model_path:str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
#
|
||||||
|
@router.get("/list")
|
||||||
|
def get_model_list(project_id:int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@router.post("/create", status_code=201)
|
||||||
|
def model_create(request: ModelCreateRequest):
|
||||||
|
if request.type not in ["seg", "det", "cls"]:
|
||||||
|
raise HTTPException(status_code=400,
|
||||||
|
detail= f"Invalid type '{request.type}'. Must be one of \"seg\", \"det\", \"cls\".")
|
||||||
|
if request.pretrained:
|
||||||
|
model_path = create_pretrained_model(request.project_id, request.type)
|
||||||
|
else:
|
||||||
|
labels = list(map(lambda x:x.label, request.labelCategories))
|
||||||
|
model_path = create_default_model(request.project_id, request.type, labels)
|
||||||
|
return {"model_path": model_path}
|
||||||
|
|
||||||
|
@router.delete("/delete", status_code=204)
|
||||||
|
def model_delete():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@router.post("/upload")
|
||||||
|
def model_upload():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@router.get("/download")
|
||||||
|
def model_download():
|
||||||
|
pass
|
@ -1,12 +1,14 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from api.yolo.detection import router as yolo_detection_router
|
from api.yolo.detection import router as yolo_detection_router
|
||||||
from api.yolo.segmentation import router as yolo_segmentation_router
|
from api.yolo.segmentation import router as yolo_segmentation_router
|
||||||
|
from api.yolo.model import router as yolo_model_router
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# 각 기능별 라우터를 애플리케이션에 등록
|
# 각 기능별 라우터를 애플리케이션에 등록
|
||||||
app.include_router(yolo_detection_router, prefix="/api/detection", tags=["Detection"])
|
app.include_router(yolo_detection_router, prefix="/api/detection", tags=["Detection"])
|
||||||
app.include_router(yolo_segmentation_router, prefix="/api/segmentation", tags=["Segmentation"])
|
app.include_router(yolo_segmentation_router, prefix="/api/segmentation", tags=["Segmentation"])
|
||||||
|
app.include_router(yolo_model_router, prefix="/api/model", tags=["Model"])
|
||||||
|
|
||||||
# 애플리케이션 실행
|
# 애플리케이션 실행
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
13
ai/app/schemas/model_create_request.py
Normal file
13
ai/app/schemas/model_create_request.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class LabelCategory(BaseModel):
|
||||||
|
id:int
|
||||||
|
label:str
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCreateRequest(BaseModel):
|
||||||
|
project_id: int
|
||||||
|
type: str
|
||||||
|
pretrained:bool = True
|
||||||
|
labelCategories:Optional[list[LabelCategory]] = None
|
28
ai/app/services/init_model.py
Normal file
28
ai/app/services/init_model.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
def create_pretrained_model(project_id: id, type:str):
|
||||||
|
suffix = ""
|
||||||
|
if type in ["seg", "cls"]:
|
||||||
|
suffix = "-"+type
|
||||||
|
# 학습된 기본 모델 로드
|
||||||
|
model = YOLO(os.path.join("resources", "models" ,f"yolov8n{suffix}.pt"))
|
||||||
|
|
||||||
|
# 모델을 저장할 폴더 경로
|
||||||
|
base_path = os.path.join("resources","projects",str(project_id))
|
||||||
|
os.makedirs(base_path, exist_ok=True)
|
||||||
|
|
||||||
|
# 고유값 id 생성
|
||||||
|
unique_id = uuid.uuid4()
|
||||||
|
while os.path.exists(os.path.join(base_path, f"{unique_id}.pt")):
|
||||||
|
unique_id = uuid.uuid4()
|
||||||
|
model_path = os.path.join(base_path, f"{unique_id}.pt")
|
||||||
|
|
||||||
|
# 기본 모델 저장
|
||||||
|
model.save(filename=model_path)
|
||||||
|
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
def create_default_model(project_id: id, type:str, labels:list[str]):
|
||||||
|
pass
|
Loading…
Reference in New Issue
Block a user