Fix: 학습 진행 상황 보고

This commit is contained in:
김용수 2024-09-30 15:31:34 +09:00
parent 41002cedfb
commit 8dad7d4f1f
5 changed files with 27 additions and 13 deletions

View File

@ -148,7 +148,11 @@ public class AiModelService {
// 레이블 만들기
private Map<String, Integer> getLabelMap(final Project project) {
return project.getCategoryList().stream()
List<ProjectCategory> categoryList = project.getCategoryList();
if(categoryList.isEmpty()){
throw new CustomException(ErrorCode.BAD_REQUEST, "카테고리가 존재하지 않습니다. 학습이 불가합니다.");
}
return categoryList.stream()
.collect(Collectors.toMap(
ProjectCategory::getLabelName,
ProjectCategory::getId
@ -157,7 +161,11 @@ public class AiModelService {
@Transactional(readOnly = true)
public List<TrainDataInfo> getTrainDataInfoList(final Integer projectId) {
return imageRepository.findImagesByProjectIdAndCompleted(projectId)
List<Image> completedImageList = imageRepository.findImagesByProjectIdAndCompleted(projectId);
if(completedImageList.size() < 2){
throw new CustomException(ErrorCode.BAD_REQUEST, "Completed Image가 2개 이상부터 학습 가능합니다.");
}
return completedImageList
.stream()
.map(TrainDataInfo::of)
.toList();

View File

@ -76,10 +76,13 @@ public class ProgressCacheRepository {
* 현재 학습 진행 여부 확인 메서드 (단일 사용)
*/
public boolean trainModelProgressCheck(final int projectId, final int modelId) {
String key = CacheKey.trainKey(projectId, modelId);
Long listSize = redisTemplate.opsForList().size(key);
return listSize != null && listSize > 0;
String key = CacheKey.trainProgressKey();
Object result = redisTemplate.opsForHash().get(key, String.valueOf(projectId));
if(result == null){
return false;
}
int progressModelId = Integer.parseInt((String) result);
return modelId == progressModelId;
}
/**

View File

@ -27,7 +27,7 @@ public class ReportController {
@SwaggerApiSuccess(description = "완성된 모델 리포트를 조회합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@GetMapping("/models/{model_id}")
public ReportResponse getReportsByModelId(@PathVariable("model_id") final Integer modelId) {
public List<ReportResponse> getReportsByModelId(@PathVariable("model_id") final Integer modelId) {
return reportService.getReportsByModelId(modelId);
}

View File

@ -2,11 +2,14 @@ package com.worlabel.domain.report.repository;
import com.worlabel.domain.report.entity.Report;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import java.util.List;
import java.util.Optional;
public interface ReportRepository extends JpaRepository<Report, Integer> {
Optional<Report> findByAiModelId(Integer modelId);
@Query("SELECT r FROM Report r " +
"WHERE r.aiModel.id =:modelId ")
List<Report> findByAiModelId(@Param("modelId") Integer modelId);
}

View File

@ -27,10 +27,10 @@ public class ReportService {
private final ReportRepository reportRepository;
private final ProgressService progressService;
public ReportResponse getReportsByModelId(final Integer modelId) {
Report report = reportRepository.findByAiModelId(modelId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
return ReportResponse.from(report);
public List<ReportResponse> getReportsByModelId(final Integer modelId) {
return reportRepository.findByAiModelId(modelId).stream()
.map(ReportResponse::from)
.toList();
}
public void addReportByModelId(final Integer projectId, final Integer modelId, final ReportRequest reportRequest) {