Refactor: 요청 응답 객체 구조 변경
This commit is contained in:
parent
8e77c225da
commit
84adb79c5e
3
ai/.gitignore
vendored
3
ai/.gitignore
vendored
@ -30,3 +30,6 @@ dist/
|
|||||||
|
|
||||||
# MacOS 관련 파일
|
# MacOS 관련 파일
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
# 테스트 파일
|
||||||
|
test-data/
|
@ -1,9 +1,8 @@
|
|||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from schemas.predict_request import PredictRequest
|
from schemas.predict_request import PredictRequest
|
||||||
from schemas.predict_response import PredictResponse
|
from schemas.predict_response import PredictResponse, LabelData
|
||||||
from services.ai_service import load_detection_model
|
from services.ai_service import load_detection_model
|
||||||
from typing import List
|
from typing import List
|
||||||
import os
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -11,21 +10,30 @@ router = APIRouter()
|
|||||||
def predict(request: PredictRequest):
|
def predict(request: PredictRequest):
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
||||||
|
# 모델 로드
|
||||||
try:
|
try:
|
||||||
model = load_detection_model()
|
model = load_detection_model()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="load model exception: "+str(e))
|
raise HTTPException(status_code=500, detail="load model exception: "+str(e))
|
||||||
print(model)
|
|
||||||
|
# 추론
|
||||||
|
results = []
|
||||||
try:
|
try:
|
||||||
results = model.predict(
|
for image in request.image_list:
|
||||||
source=request.image_path,
|
predict_results = model.predict(
|
||||||
|
source=image.image_url,
|
||||||
iou=request.iou_threshold,
|
iou=request.iou_threshold,
|
||||||
conf=request.conf_threshold,
|
conf=request.conf_threshold,
|
||||||
classes=request.classes)
|
classes=request.classes)
|
||||||
|
results.append(predict_results[0])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="model predict exception: "+str(e))
|
raise HTTPException(status_code=500, detail="model predict exception: "+str(e))
|
||||||
|
|
||||||
|
# 추론 결과 -> 레이블 객체 파싱
|
||||||
|
response = []
|
||||||
try:
|
try:
|
||||||
response = [{
|
for (image, result) in zip(request.image_list, results):
|
||||||
|
label_data:LabelData = {
|
||||||
"version": version,
|
"version": version,
|
||||||
"task_type": "det",
|
"task_type": "det",
|
||||||
"shapes": [
|
"shapes": [
|
||||||
@ -40,15 +48,17 @@ def predict(request: PredictRequest):
|
|||||||
"shape_type": "rectangle",
|
"shape_type": "rectangle",
|
||||||
"flags": {}
|
"flags": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
for summary in result.summary()
|
for summary in result.summary()
|
||||||
],
|
],
|
||||||
"split": "none",
|
"split": "none",
|
||||||
"imageHeight": result.orig_shape[0],
|
"imageHeight": result.orig_shape[0],
|
||||||
"imageWidth": result.orig_shape[1],
|
"imageWidth": result.orig_shape[1],
|
||||||
"imageDepth": 1
|
"imageDepth": 1
|
||||||
} for result in results
|
}
|
||||||
]
|
response.append({
|
||||||
|
"image_id":image.image_id,
|
||||||
|
"data":label_data
|
||||||
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="label parsing exception: "+str(e))
|
raise HTTPException(status_code=500, detail="label parsing exception: "+str(e))
|
||||||
return response
|
return response
|
@ -1,9 +1,13 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
class ImageInfo(BaseModel):
|
||||||
|
image_id: int
|
||||||
|
image_url: str
|
||||||
|
|
||||||
class PredictRequest(BaseModel):
|
class PredictRequest(BaseModel):
|
||||||
projectId: int
|
project_id: int
|
||||||
image_path: str
|
image_list: List[ImageInfo]
|
||||||
version: Optional[str] = "latest"
|
version: Optional[str] = "latest"
|
||||||
conf_threshold: Optional[float] = 0.25
|
conf_threshold: Optional[float] = 0.25
|
||||||
iou_threshold: Optional[float] = 0.45
|
iou_threshold: Optional[float] = 0.45
|
||||||
|
@ -7,9 +7,9 @@ class Shape(BaseModel):
|
|||||||
points: List[Tuple[float, float]]
|
points: List[Tuple[float, float]]
|
||||||
group_id: Optional[int] = None
|
group_id: Optional[int] = None
|
||||||
shape_type: str
|
shape_type: str
|
||||||
flags: Dict[str, Optional[bool]] = {} # key는 문자열, value는 boolean 또는 None
|
flags: Dict[str, Optional[bool]] = {}
|
||||||
|
|
||||||
class PredictResponse(BaseModel):
|
class LabelData(BaseModel):
|
||||||
version: str
|
version: str
|
||||||
task_type: str
|
task_type: str
|
||||||
shapes: List[Shape]
|
shapes: List[Shape]
|
||||||
@ -17,3 +17,7 @@ class PredictResponse(BaseModel):
|
|||||||
imageHeight: int
|
imageHeight: int
|
||||||
imageWidth: int
|
imageWidth: int
|
||||||
imageDepth: int
|
imageDepth: int
|
||||||
|
|
||||||
|
class PredictResponse(BaseModel):
|
||||||
|
image_id: int
|
||||||
|
data: LabelData
|
@ -1,9 +1,10 @@
|
|||||||
# ai_service.py
|
# ai_service.py
|
||||||
|
|
||||||
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
|
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
|
||||||
|
from typing import List
|
||||||
import os
|
import os
|
||||||
|
|
||||||
def load_detection_model(model_path: str = "test/model/initial.pt", device:str ="cpu"):
|
def load_detection_model(model_path: str = "test-data/model/yolov8n.pt", device:str ="cpu"):
|
||||||
"""
|
"""
|
||||||
지정된 경로에서 YOLO 모델을 로드합니다.
|
지정된 경로에서 YOLO 모델을 로드합니다.
|
||||||
|
|
||||||
@ -16,12 +17,14 @@ def load_detection_model(model_path: str = "test/model/initial.pt", device:str =
|
|||||||
YOLO: 로드된 YOLO 모델 인스턴스
|
YOLO: 로드된 YOLO 모델 인스턴스
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n.pt":
|
||||||
raise FileNotFoundError(f"Model file not found at path: {model_path}")
|
raise FileNotFoundError(f"Model file not found at path: {model_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = YOLO(model_path)
|
model = YOLO(model_path)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
# Detection 모델인지 검증
|
||||||
|
# 코드 추가
|
||||||
return model
|
return model
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load the model from {model_path}. Error: {str(e)}")
|
raise RuntimeError(f"Failed to load the model from {model_path}. Error: {str(e)}")
|
||||||
|
Binary file not shown.
Before Width: | Height: | Size: 43 KiB |
Binary file not shown.
Before Width: | Height: | Size: 31 KiB |
Binary file not shown.
Before Width: | Height: | Size: 40 KiB |
Binary file not shown.
Loading…
Reference in New Issue
Block a user