Fix: 학습 진행 상황 보고
This commit is contained in:
parent
41002cedfb
commit
8dad7d4f1f
@ -148,7 +148,11 @@ public class AiModelService {
|
||||
|
||||
// 레이블 맵 만들기
|
||||
private Map<String, Integer> getLabelMap(final Project project) {
|
||||
return project.getCategoryList().stream()
|
||||
List<ProjectCategory> categoryList = project.getCategoryList();
|
||||
if(categoryList.isEmpty()){
|
||||
throw new CustomException(ErrorCode.BAD_REQUEST, "카테고리가 존재하지 않습니다. 학습이 불가합니다.");
|
||||
}
|
||||
return categoryList.stream()
|
||||
.collect(Collectors.toMap(
|
||||
ProjectCategory::getLabelName,
|
||||
ProjectCategory::getId
|
||||
@ -157,7 +161,11 @@ public class AiModelService {
|
||||
|
||||
@Transactional(readOnly = true)
|
||||
public List<TrainDataInfo> getTrainDataInfoList(final Integer projectId) {
|
||||
return imageRepository.findImagesByProjectIdAndCompleted(projectId)
|
||||
List<Image> completedImageList = imageRepository.findImagesByProjectIdAndCompleted(projectId);
|
||||
if(completedImageList.size() < 2){
|
||||
throw new CustomException(ErrorCode.BAD_REQUEST, "Completed Image가 2개 이상부터 학습 가능합니다.");
|
||||
}
|
||||
return completedImageList
|
||||
.stream()
|
||||
.map(TrainDataInfo::of)
|
||||
.toList();
|
||||
|
@ -76,10 +76,13 @@ public class ProgressCacheRepository {
|
||||
* 현재 학습 진행 여부 확인 메서드 (단일 키 사용)
|
||||
*/
|
||||
public boolean trainModelProgressCheck(final int projectId, final int modelId) {
|
||||
String key = CacheKey.trainKey(projectId, modelId);
|
||||
Long listSize = redisTemplate.opsForList().size(key);
|
||||
|
||||
return listSize != null && listSize > 0;
|
||||
String key = CacheKey.trainProgressKey();
|
||||
Object result = redisTemplate.opsForHash().get(key, String.valueOf(projectId));
|
||||
if(result == null){
|
||||
return false;
|
||||
}
|
||||
int progressModelId = Integer.parseInt((String) result);
|
||||
return modelId == progressModelId;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -27,7 +27,7 @@ public class ReportController {
|
||||
@SwaggerApiSuccess(description = "완성된 모델 리포트를 조회합니다.")
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@GetMapping("/models/{model_id}")
|
||||
public ReportResponse getReportsByModelId(@PathVariable("model_id") final Integer modelId) {
|
||||
public List<ReportResponse> getReportsByModelId(@PathVariable("model_id") final Integer modelId) {
|
||||
return reportService.getReportsByModelId(modelId);
|
||||
}
|
||||
|
||||
|
@ -2,11 +2,14 @@ package com.worlabel.domain.report.repository;
|
||||
|
||||
import com.worlabel.domain.report.entity.Report;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.data.jpa.repository.Query;
|
||||
import org.springframework.data.repository.query.Param;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public interface ReportRepository extends JpaRepository<Report, Integer> {
|
||||
|
||||
Optional<Report> findByAiModelId(Integer modelId);
|
||||
@Query("SELECT r FROM Report r " +
|
||||
"WHERE r.aiModel.id =:modelId ")
|
||||
List<Report> findByAiModelId(@Param("modelId") Integer modelId);
|
||||
}
|
||||
|
@ -27,10 +27,10 @@ public class ReportService {
|
||||
private final ReportRepository reportRepository;
|
||||
private final ProgressService progressService;
|
||||
|
||||
public ReportResponse getReportsByModelId(final Integer modelId) {
|
||||
Report report = reportRepository.findByAiModelId(modelId)
|
||||
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||
return ReportResponse.from(report);
|
||||
public List<ReportResponse> getReportsByModelId(final Integer modelId) {
|
||||
return reportRepository.findByAiModelId(modelId).stream()
|
||||
.map(ReportResponse::from)
|
||||
.toList();
|
||||
}
|
||||
|
||||
public void addReportByModelId(final Integer projectId, final Integer modelId, final ReportRequest reportRequest) {
|
||||
|
Loading…
Reference in New Issue
Block a user