Refactor: 요청 응답 객체 구조 변경

This commit is contained in:
김진현 2024-09-03 15:45:05 +09:00
parent 8e77c225da
commit 84adb79c5e
9 changed files with 64 additions and 40 deletions

3
ai/.gitignore vendored
View File

@ -30,3 +30,6 @@ dist/
# MacOS 관련 파일 # MacOS 관련 파일
.DS_Store .DS_Store
# 테스트 파일
test-data/

View File

@ -1,9 +1,8 @@
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from schemas.predict_request import PredictRequest from schemas.predict_request import PredictRequest
from schemas.predict_response import PredictResponse from schemas.predict_response import PredictResponse, LabelData
from services.ai_service import load_detection_model from services.ai_service import load_detection_model
from typing import List from typing import List
import os
router = APIRouter() router = APIRouter()
@ -11,21 +10,30 @@ router = APIRouter()
def predict(request: PredictRequest): def predict(request: PredictRequest):
version = "0.1.0" version = "0.1.0"
# 모델 로드
try: try:
model = load_detection_model() model = load_detection_model()
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="load model exception: "+str(e)) raise HTTPException(status_code=500, detail="load model exception: "+str(e))
print(model)
# 추론
results = []
try: try:
results = model.predict( for image in request.image_list:
source=request.image_path, predict_results = model.predict(
source=image.image_url,
iou=request.iou_threshold, iou=request.iou_threshold,
conf=request.conf_threshold, conf=request.conf_threshold,
classes=request.classes) classes=request.classes)
results.append(predict_results[0])
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="model predict exception: "+str(e)) raise HTTPException(status_code=500, detail="model predict exception: "+str(e))
# 추론 결과 -> 레이블 객체 파싱
response = []
try: try:
response = [{ for (image, result) in zip(request.image_list, results):
label_data:LabelData = {
"version": version, "version": version,
"task_type": "det", "task_type": "det",
"shapes": [ "shapes": [
@ -40,15 +48,17 @@ def predict(request: PredictRequest):
"shape_type": "rectangle", "shape_type": "rectangle",
"flags": {} "flags": {}
} }
for summary in result.summary() for summary in result.summary()
], ],
"split": "none", "split": "none",
"imageHeight": result.orig_shape[0], "imageHeight": result.orig_shape[0],
"imageWidth": result.orig_shape[1], "imageWidth": result.orig_shape[1],
"imageDepth": 1 "imageDepth": 1
} for result in results }
] response.append({
"image_id":image.image_id,
"data":label_data
})
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="label parsing exception: "+str(e)) raise HTTPException(status_code=500, detail="label parsing exception: "+str(e))
return response return response

View File

@ -1,9 +1,13 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional
class ImageInfo(BaseModel):
image_id: int
image_url: str
class PredictRequest(BaseModel): class PredictRequest(BaseModel):
projectId: int project_id: int
image_path: str image_list: List[ImageInfo]
version: Optional[str] = "latest" version: Optional[str] = "latest"
conf_threshold: Optional[float] = 0.25 conf_threshold: Optional[float] = 0.25
iou_threshold: Optional[float] = 0.45 iou_threshold: Optional[float] = 0.45

View File

@ -7,9 +7,9 @@ class Shape(BaseModel):
points: List[Tuple[float, float]] points: List[Tuple[float, float]]
group_id: Optional[int] = None group_id: Optional[int] = None
shape_type: str shape_type: str
flags: Dict[str, Optional[bool]] = {} # key는 문자열, value는 boolean 또는 None flags: Dict[str, Optional[bool]] = {}
class PredictResponse(BaseModel): class LabelData(BaseModel):
version: str version: str
task_type: str task_type: str
shapes: List[Shape] shapes: List[Shape]
@ -17,3 +17,7 @@ class PredictResponse(BaseModel):
imageHeight: int imageHeight: int
imageWidth: int imageWidth: int
imageDepth: int imageDepth: int
class PredictResponse(BaseModel):
image_id: int
data: LabelData

View File

@ -1,9 +1,10 @@
# ai_service.py # ai_service.py
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기 from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
from typing import List
import os import os
def load_detection_model(model_path: str = "test/model/initial.pt", device:str ="cpu"): def load_detection_model(model_path: str = "test-data/model/yolov8n.pt", device:str ="cpu"):
""" """
지정된 경로에서 YOLO 모델을 로드합니다. 지정된 경로에서 YOLO 모델을 로드합니다.
@ -16,12 +17,14 @@ def load_detection_model(model_path: str = "test/model/initial.pt", device:str =
YOLO: 로드된 YOLO 모델 인스턴스 YOLO: 로드된 YOLO 모델 인스턴스
""" """
if not os.path.exists(model_path): if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n.pt":
raise FileNotFoundError(f"Model file not found at path: {model_path}") raise FileNotFoundError(f"Model file not found at path: {model_path}")
try: try:
model = YOLO(model_path) model = YOLO(model_path)
model.to(device) model.to(device)
# Detection 모델인지 검증
# 코드 추가
return model return model
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load the model from {model_path}. Error: {str(e)}") raise RuntimeError(f"Failed to load the model from {model_path}. Error: {str(e)}")

Binary file not shown.

Before

Width:  |  Height:  |  Size: 43 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 40 KiB

Binary file not shown.