Fix: rename 과정에서 생긴 import app.~에 의한 버그 수정
- 추가적으로 api endpoint 이름과 swagger 관련 설정
This commit is contained in:
parent
e2008e1d4d
commit
aae8faf11e
@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException
|
||||
from schemas.predict_request import PredictRequest
|
||||
from schemas.train_request import TrainRequest
|
||||
from schemas.predict_response import PredictResponse, LabelData
|
||||
from app.services.load_model import load_detection_model
|
||||
from services.load_model import load_detection_model
|
||||
from utils.dataset_utils import split_data
|
||||
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path
|
||||
from typing import List
|
||||
@ -14,9 +14,8 @@ import asyncio
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/detection", response_model=List[PredictResponse])
|
||||
async def predict(request: PredictRequest):
|
||||
@router.post("/predict", response_model=List[PredictResponse])
|
||||
async def detection_predict(request: PredictRequest):
|
||||
version = "0.1.0"
|
||||
print("여기")
|
||||
|
||||
@ -141,8 +140,8 @@ async def predict(request: PredictRequest):
|
||||
await ws_client.close()
|
||||
|
||||
|
||||
@router.post("/detection/train")
|
||||
async def train(request: TrainRequest):
|
||||
@router.post("/train")
|
||||
async def detection_train(request: TrainRequest):
|
||||
# 데이터셋 루트 경로 얻기
|
||||
dataset_root_path = get_dataset_root_path(request.project_id)
|
||||
|
||||
@ -204,7 +203,3 @@ async def train(request: TrainRequest):
|
||||
finally:
|
||||
if ws_client.is_connected():
|
||||
await ws_client.close()
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from schemas.predict_request import PredictRequest
|
||||
from schemas.predict_response import PredictResponse, LabelData
|
||||
from app.services.load_model import load_segmentation_model
|
||||
from services.load_model import load_segmentation_model
|
||||
from typing import List
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/segmentation", response_model=List[PredictResponse])
|
||||
@router.post("/predict", response_model=List[PredictResponse])
|
||||
def predict(request: PredictRequest):
|
||||
version = "0.1.0"
|
||||
|
||||
|
@ -5,8 +5,8 @@ from api.yolo.segmentation import router as yolo_segmentation_router
|
||||
app = FastAPI()
|
||||
|
||||
# 각 기능별 라우터를 애플리케이션에 등록
|
||||
app.include_router(yolo_detection_router, prefix="/api")
|
||||
app.include_router(yolo_segmentation_router, prefix="/api")
|
||||
app.include_router(yolo_detection_router, prefix="/api/detection", tags=["Detection"])
|
||||
app.include_router(yolo_segmentation_router, prefix="/api/segmentation", tags=["Segmentation"])
|
||||
|
||||
# 애플리케이션 실행
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user