Final: ai_final

This commit is contained in:
김진현 2024-10-11 09:11:52 +09:00
parent b1c6282f52
commit b89e6ec827
2 changed files with 14 additions and 10 deletions

View File

@ -73,13 +73,16 @@ def get_classes(label_map:dict[str: int], model_names: dict[int, str]):
def run_predictions(model, image, request, classes):
try:
with torch.no_grad():
result = model.predict(
source=image,
iou=request.iou_threshold,
conf=request.conf_threshold,
classes=classes
)
return result
results = []
for img in image:
result = model.predict(
source=[img],
iou=request.iou_threshold,
conf=request.conf_threshold,
classes=classes
)
results += result
return results
except Exception as e:
raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e))
@ -122,7 +125,7 @@ def get_random_color():
random_number = random.randint(0, 0xFFFFFF)
return f"#{random_number:06X}"
@router.post("/train")
@router.post("/train", response_model=TrainResponse)
async def detection_train(request: TrainRequest):
send_slack_message(f"Detection train 요청 projectId : {request.project_id}, 이미지 개수:{len(request.data)}", status="success")

View File

@ -24,15 +24,16 @@ async def resource_cleaner_middleware(request: Request, call_next):
start_time = time.time()
try:
response = await call_next(request)
return response
except Exception as exc:
raise exc
finally:
process_time = time.time() - start_time
if request.method != "GET":
send_slack_message(f"처리 시간: {process_time}")
gc.collect()
# gc.collect()
torch.cuda.empty_cache()
return response
# 예외 처리기
@app.exception_handler(HTTPException)