Merge branch 'ai/feat/53-detection-autolabel' into 'ai/develop'
Feat: Detection 오토 레이블링 API 구현 - S11P21S002-53 See merge request s11-s-project/S11P21S002!37
This commit is contained in:
commit
c7379d52ed
4
ai/.gitignore
vendored
4
ai/.gitignore
vendored
@ -9,6 +9,7 @@ __pycache__/
|
||||
*.env
|
||||
|
||||
# 패키지 디렉토리
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
|
||||
@ -29,3 +30,6 @@ dist/
|
||||
|
||||
# MacOS 관련 파일
|
||||
.DS_Store
|
||||
|
||||
# 테스트 파일
|
||||
test-data/
|
64
ai/app/api/yolo/detection.py
Normal file
64
ai/app/api/yolo/detection.py
Normal file
@ -0,0 +1,64 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from schemas.predict_request import PredictRequest
|
||||
from schemas.predict_response import PredictResponse, LabelData
|
||||
from services.ai_service import load_detection_model
|
||||
from typing import List
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/predict", response_model=List[PredictResponse])
|
||||
def predict(request: PredictRequest):
|
||||
version = "0.1.0"
|
||||
|
||||
# 모델 로드
|
||||
try:
|
||||
model = load_detection_model()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="load model exception: "+str(e))
|
||||
|
||||
# 추론
|
||||
results = []
|
||||
try:
|
||||
for image in request.image_list:
|
||||
predict_results = model.predict(
|
||||
source=image.image_url,
|
||||
iou=request.iou_threshold,
|
||||
conf=request.conf_threshold,
|
||||
classes=request.classes)
|
||||
results.append(predict_results[0])
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="model predict exception: "+str(e))
|
||||
|
||||
# 추론 결과 -> 레이블 객체 파싱
|
||||
response = []
|
||||
try:
|
||||
for (image, result) in zip(request.image_list, results):
|
||||
label_data:LabelData = {
|
||||
"version": version,
|
||||
"task_type": "det",
|
||||
"shapes": [
|
||||
{
|
||||
"label": summary['name'],
|
||||
"color": "#ff0000",
|
||||
"points": [
|
||||
[summary['box']['x1'], summary['box']['y1']],
|
||||
[summary['box']['x2'], summary['box']['y2']]
|
||||
],
|
||||
"group_id": summary['class'],
|
||||
"shape_type": "rectangle",
|
||||
"flags": {}
|
||||
}
|
||||
for summary in result.summary()
|
||||
],
|
||||
"split": "none",
|
||||
"imageHeight": result.orig_shape[0],
|
||||
"imageWidth": result.orig_shape[1],
|
||||
"imageDepth": 1
|
||||
}
|
||||
response.append({
|
||||
"image_id":image.image_id,
|
||||
"data":label_data
|
||||
})
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="label parsing exception: "+str(e))
|
||||
return response
|
@ -1,6 +1,12 @@
|
||||
from fastapi import FastAPI
|
||||
from app.api.endpoints import router
|
||||
from api.yolo.detection import router as yolo_detection_router
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.include_router(router)
|
||||
# 각 기능별 라우터를 애플리케이션에 등록
|
||||
app.include_router(yolo_detection_router, prefix="/api/yolo/detection")
|
||||
|
||||
# 애플리케이션 실행
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run("main:app", reload=True)
|
||||
|
14
ai/app/schemas/predict_request.py
Normal file
14
ai/app/schemas/predict_request.py
Normal file
@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
class ImageInfo(BaseModel):
|
||||
image_id: int
|
||||
image_url: str
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
project_id: int
|
||||
image_list: List[ImageInfo]
|
||||
version: Optional[str] = "latest"
|
||||
conf_threshold: Optional[float] = 0.25
|
||||
iou_threshold: Optional[float] = 0.45
|
||||
classes: Optional[List[int]] = None
|
23
ai/app/schemas/predict_response.py
Normal file
23
ai/app/schemas/predict_response.py
Normal file
@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
|
||||
class Shape(BaseModel):
|
||||
label: str
|
||||
color: str
|
||||
points: List[Tuple[float, float]]
|
||||
group_id: Optional[int] = None
|
||||
shape_type: str
|
||||
flags: Dict[str, Optional[bool]] = {}
|
||||
|
||||
class LabelData(BaseModel):
|
||||
version: str
|
||||
task_type: str
|
||||
shapes: List[Shape]
|
||||
split: str
|
||||
imageHeight: int
|
||||
imageWidth: int
|
||||
imageDepth: int
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
image_id: int
|
||||
data: LabelData
|
@ -0,0 +1,31 @@
|
||||
# ai_service.py
|
||||
|
||||
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
|
||||
from typing import List
|
||||
import os
|
||||
|
||||
def load_detection_model(model_path: str = "test-data/model/yolov8n.pt", device:str ="cpu"):
|
||||
"""
|
||||
지정된 경로에서 YOLO 모델을 로드합니다.
|
||||
|
||||
Args:
|
||||
model_path (str): 모델 파일 경로.
|
||||
device (str): 모델을 로드할 장치. 기본값은 'cpu'.
|
||||
'cpu' 또는 'cuda'와 같은 장치를 지정할 수 있습니다.
|
||||
|
||||
Returns:
|
||||
YOLO: 로드된 YOLO 모델 인스턴스
|
||||
"""
|
||||
|
||||
if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n.pt":
|
||||
raise FileNotFoundError(f"Model file not found at path: {model_path}")
|
||||
|
||||
try:
|
||||
model = YOLO(model_path)
|
||||
model.to(device)
|
||||
# Detection 모델인지 검증
|
||||
# 코드 추가
|
||||
return model
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load the model from {model_path}. Error: {str(e)}")
|
||||
|
@ -1,16 +1,7 @@
|
||||
# FastAPI 웹 프레임워크
|
||||
fastapi
|
||||
|
||||
# ASGI 서버를 위한 Uvicorn
|
||||
uvicorn
|
||||
|
||||
# YOLOv8 모델을 위한 ultralytics
|
||||
ultralytics
|
||||
|
||||
# 테스트 도구
|
||||
# pytest
|
||||
# pytest-asyncio # 비동기 테스트 지원
|
||||
|
||||
# 환경 변수 로드
|
||||
python-dotenv
|
||||
|
||||
fastapi==0.104.1
|
||||
uvicorn==0.30.6
|
||||
torch==2.4.0 -f https://download.pytorch.org/whl/cpu
|
||||
torchaudio==2.4.0 -f https://download.pytorch.org/whl/cpu
|
||||
torchvision==0.19.0 -f https://download.pytorch.org/whl/cpu
|
||||
ultralytics==8.2.82
|
||||
ultralytics-thop==2.0.5
|
||||
|
Loading…
Reference in New Issue
Block a user