Final: ai_final

This commit is contained in:
김진현 2024-10-11 09:11:52 +09:00
parent b1c6282f52
commit b89e6ec827
2 changed files with 14 additions and 10 deletions

View File

@ -73,13 +73,16 @@ def get_classes(label_map:dict[str: int], model_names: dict[int, str]):
def run_predictions(model, image, request, classes): def run_predictions(model, image, request, classes):
try: try:
with torch.no_grad(): with torch.no_grad():
results = []
for img in image:
result = model.predict( result = model.predict(
source=image, source=[img],
iou=request.iou_threshold, iou=request.iou_threshold,
conf=request.conf_threshold, conf=request.conf_threshold,
classes=classes classes=classes
) )
return result results += result
return results
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e)) raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e))
@ -122,7 +125,7 @@ def get_random_color():
random_number = random.randint(0, 0xFFFFFF) random_number = random.randint(0, 0xFFFFFF)
return f"#{random_number:06X}" return f"#{random_number:06X}"
@router.post("/train") @router.post("/train", response_model=TrainResponse)
async def detection_train(request: TrainRequest): async def detection_train(request: TrainRequest):
send_slack_message(f"Detection train 요청 projectId : {request.project_id}, 이미지 개수:{len(request.data)}", status="success") send_slack_message(f"Detection train 요청 projectId : {request.project_id}, 이미지 개수:{len(request.data)}", status="success")

View File

@ -24,15 +24,16 @@ async def resource_cleaner_middleware(request: Request, call_next):
start_time = time.time() start_time = time.time()
try: try:
response = await call_next(request) response = await call_next(request)
return response
except Exception as exc: except Exception as exc:
raise exc raise exc
finally: finally:
process_time = time.time() - start_time process_time = time.time() - start_time
if request.method != "GET": if request.method != "GET":
send_slack_message(f"처리 시간: {process_time}") send_slack_message(f"처리 시간: {process_time}")
gc.collect() # gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
return response
# 예외 처리기 # 예외 처리기
@app.exception_handler(HTTPException) @app.exception_handler(HTTPException)