Feat: Detection 오토레이블링 API 구현 - S11P21S002-53
This commit is contained in:
parent
03bba85028
commit
8e77c225da
54
ai/app/api/yolo/detection.py
Normal file
54
ai/app/api/yolo/detection.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from schemas.predict_request import PredictRequest
|
||||||
|
from schemas.predict_response import PredictResponse
|
||||||
|
from services.ai_service import load_detection_model
|
||||||
|
from typing import List
|
||||||
|
import os
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.post("/predict", response_model=List[PredictResponse])
|
||||||
|
def predict(request: PredictRequest):
|
||||||
|
version = "0.1.0"
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = load_detection_model()
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="load model exception: "+str(e))
|
||||||
|
print(model)
|
||||||
|
try:
|
||||||
|
results = model.predict(
|
||||||
|
source=request.image_path,
|
||||||
|
iou=request.iou_threshold,
|
||||||
|
conf=request.conf_threshold,
|
||||||
|
classes=request.classes)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="model predict exception: "+str(e))
|
||||||
|
try:
|
||||||
|
response = [{
|
||||||
|
"version": version,
|
||||||
|
"task_type": "det",
|
||||||
|
"shapes": [
|
||||||
|
{
|
||||||
|
"label": summary['name'],
|
||||||
|
"color": "#ff0000",
|
||||||
|
"points": [
|
||||||
|
[summary['box']['x1'], summary['box']['y1']],
|
||||||
|
[summary['box']['x2'], summary['box']['y2']]
|
||||||
|
],
|
||||||
|
"group_id": summary['class'],
|
||||||
|
"shape_type": "rectangle",
|
||||||
|
"flags": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
for summary in result.summary()
|
||||||
|
],
|
||||||
|
"split": "none",
|
||||||
|
"imageHeight": result.orig_shape[0],
|
||||||
|
"imageWidth": result.orig_shape[1],
|
||||||
|
"imageDepth": 1
|
||||||
|
} for result in results
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="label parsing exception: "+str(e))
|
||||||
|
return response
|
@ -1,6 +1,12 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from app.api.endpoints import router
|
from api.yolo.detection import router as yolo_detection_router
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
app.include_router(router)
|
# 각 기능별 라우터를 애플리케이션에 등록
|
||||||
|
app.include_router(yolo_detection_router, prefix="/api/yolo/detection")
|
||||||
|
|
||||||
|
# 애플리케이션 실행
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run("main:app", reload=True)
|
||||||
|
10
ai/app/schemas/predict_request.py
Normal file
10
ai/app/schemas/predict_request.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
class PredictRequest(BaseModel):
|
||||||
|
projectId: int
|
||||||
|
image_path: str
|
||||||
|
version: Optional[str] = "latest"
|
||||||
|
conf_threshold: Optional[float] = 0.25
|
||||||
|
iou_threshold: Optional[float] = 0.45
|
||||||
|
classes: Optional[List[int]] = None
|
19
ai/app/schemas/predict_response.py
Normal file
19
ai/app/schemas/predict_response.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional, Tuple, Dict
|
||||||
|
|
||||||
|
class Shape(BaseModel):
|
||||||
|
label: str
|
||||||
|
color: str
|
||||||
|
points: List[Tuple[float, float]]
|
||||||
|
group_id: Optional[int] = None
|
||||||
|
shape_type: str
|
||||||
|
flags: Dict[str, Optional[bool]] = {} # key는 문자열, value는 boolean 또는 None
|
||||||
|
|
||||||
|
class PredictResponse(BaseModel):
|
||||||
|
version: str
|
||||||
|
task_type: str
|
||||||
|
shapes: List[Shape]
|
||||||
|
split: str
|
||||||
|
imageHeight: int
|
||||||
|
imageWidth: int
|
||||||
|
imageDepth: int
|
@ -0,0 +1,28 @@
|
|||||||
|
# ai_service.py
|
||||||
|
|
||||||
|
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
|
||||||
|
import os
|
||||||
|
|
||||||
|
def load_detection_model(model_path: str = "test/model/initial.pt", device:str ="cpu"):
|
||||||
|
"""
|
||||||
|
지정된 경로에서 YOLO 모델을 로드합니다.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str): 모델 파일 경로.
|
||||||
|
device (str): 모델을 로드할 장치. 기본값은 'cpu'.
|
||||||
|
'cpu' 또는 'cuda'와 같은 장치를 지정할 수 있습니다.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
YOLO: 로드된 YOLO 모델 인스턴스
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(f"Model file not found at path: {model_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = YOLO(model_path)
|
||||||
|
model.to(device)
|
||||||
|
return model
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to load the model from {model_path}. Error: {str(e)}")
|
||||||
|
|
@ -1,7 +1,7 @@
|
|||||||
fastapi==0.104.1
|
fastapi==0.104.1
|
||||||
uvicorn==0.30.6
|
uvicorn==0.30.6
|
||||||
torch==2.4.0
|
torch==2.4.0 -f https://download.pytorch.org/whl/cpu
|
||||||
torchaudio==2.4.0
|
torchaudio==2.4.0 -f https://download.pytorch.org/whl/cpu
|
||||||
torchvision==0.19.0
|
torchvision==0.19.0 -f https://download.pytorch.org/whl/cpu
|
||||||
ultralytics==8.2.82
|
ultralytics==8.2.82
|
||||||
ultralytics-thop==2.0.5
|
ultralytics-thop==2.0.5
|
||||||
|
Binary file not shown.
After Width: | Height: | Size: 43 KiB |
Binary file not shown.
After Width: | Height: | Size: 31 KiB |
Binary file not shown.
After Width: | Height: | Size: 40 KiB |
BIN
ai/test/model/initial.pt
Normal file
BIN
ai/test/model/initial.pt
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user