diff --git a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java index b2d64fd..aaba16c 100644 --- a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java +++ b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java @@ -16,7 +16,7 @@ import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse; import com.worlabel.domain.model.entity.dto.DefaultResponse; import com.worlabel.domain.model.repository.AiModelRepository; import com.worlabel.domain.participant.entity.PrivilegeType; -import com.worlabel.domain.project.dto.AiRequestDto; +import com.worlabel.domain.project.dto.AiDto; import com.worlabel.domain.project.entity.Project; import com.worlabel.domain.project.repository.ProjectRepository; import com.worlabel.global.annotation.CheckPrivilege; @@ -139,13 +139,13 @@ public class AiModelService { List images = imageRepository.findImagesByProjectId(projectId); - List data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED) - .map(image -> new AiRequestDto.TrainDataInfo(image.getImagePath(), image.getDataPath())) + List data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED) + .map(image -> new AiDto.TrainDataInfo(image.getImagePath(), image.getDataPath())) .toList(); String endPoint = project.getProjectType().getValue() + "/train"; - AiRequestDto.TrainRequest trainRequest = new AiRequestDto.TrainRequest(); + AiDto.TrainRequest trainRequest = new AiDto.TrainRequest(); trainRequest.setProjectId(projectId); trainRequest.setCategoryId(categories); trainRequest.setData(data); diff --git a/backend/src/main/java/com/worlabel/domain/project/dto/AiRequestDto.java b/backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java similarity index 72% rename from backend/src/main/java/com/worlabel/domain/project/dto/AiRequestDto.java rename to backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java index a5ae80a..a7dd979 100644 --- a/backend/src/main/java/com/worlabel/domain/project/dto/AiRequestDto.java +++ b/backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java @@ -1,12 +1,14 @@ package com.worlabel.domain.project.dto; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.gson.annotations.SerializedName; import com.worlabel.domain.image.entity.Image; import lombok.*; +import java.util.HashMap; import java.util.List; -public class AiRequestDto { +public class AiDto { @Data public static class TrainDataInfo { @@ -51,6 +53,9 @@ public class AiRequestDto { @JsonProperty("model_key") private String modelKey; + @JsonProperty("label_map") + private HashMap labelMap; + @JsonProperty("image_list") private List imageList; @@ -60,8 +65,8 @@ public class AiRequestDto { @JsonProperty("iou_threshold") private Double iouThreshold; - public static AutoLabelingRequest of(final Integer projectId, final List imageList) { - return new AutoLabelingRequest(projectId, null, imageList, 0.25, 0.45); + public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap labelMap, final List imageList) { + return new AutoLabelingRequest(projectId, modelKey, labelMap, imageList, 0.25, 0.45); } } @@ -76,8 +81,21 @@ public class AiRequestDto { @JsonProperty("image_url") private String imageUrl; - public static AutoLabelingImageRequest of(Image image){ + public static AutoLabelingImageRequest of(Image image) { return new AutoLabelingImageRequest(image.getId(), image.getImagePath()); } } + + @NoArgsConstructor(access = AccessLevel.PRIVATE) + @AllArgsConstructor(access = AccessLevel.PRIVATE) + @Getter + @ToString + public static class AutoLabelingResult{ + + @SerializedName("image_id") + private Long imageId; + + @SerializedName("data") + private String data; + } } diff --git a/backend/src/main/java/com/worlabel/domain/project/entity/Project.java b/backend/src/main/java/com/worlabel/domain/project/entity/Project.java index 42caf27..1d8ab5c 100644 --- a/backend/src/main/java/com/worlabel/domain/project/entity/Project.java +++ b/backend/src/main/java/com/worlabel/domain/project/entity/Project.java @@ -52,7 +52,7 @@ public class Project extends BaseEntity { * 프로젝트에 속한 카테고리 */ @OneToMany(mappedBy = "project", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true) - private List category = new ArrayList<>(); + private List categoryList = new ArrayList<>(); /** * 프로젝트에 속한 모델 diff --git a/backend/src/main/java/com/worlabel/domain/project/repository/ProjectRepository.java b/backend/src/main/java/com/worlabel/domain/project/repository/ProjectRepository.java index a83db99..9a93a84 100644 --- a/backend/src/main/java/com/worlabel/domain/project/repository/ProjectRepository.java +++ b/backend/src/main/java/com/worlabel/domain/project/repository/ProjectRepository.java @@ -27,8 +27,4 @@ public interface ProjectRepository extends JpaRepository { @Param("memberId") Integer memberId, @Param("lastProjectId") Integer lastProjectId, @Param("pageSize") Integer pageSize); - - // ProjectType을 가져오는 메서드 추가 - @Query("SELECT p.projectType FROM Project p WHERE p.id = :projectId") - Optional findProjectTypeById(@Param("projectId") Integer projectId); } diff --git a/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java b/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java index 1fb6243..36d2a2b 100644 --- a/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java +++ b/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java @@ -1,18 +1,25 @@ package com.worlabel.domain.project.service; +import com.google.gson.Gson; +import com.google.gson.JsonSyntaxException; +import com.google.gson.reflect.TypeToken; import com.worlabel.domain.image.entity.Image; import com.worlabel.domain.image.repository.ImageRepository; +import com.worlabel.domain.labelcategory.entity.ProjectCategory; import com.worlabel.domain.member.entity.Member; import com.worlabel.domain.member.repository.MemberRepository; +import com.worlabel.domain.model.entity.AiModel; +import com.worlabel.domain.model.repository.AiModelRepository; import com.worlabel.domain.participant.entity.Participant; import com.worlabel.domain.participant.entity.PrivilegeType; import com.worlabel.domain.participant.entity.WorkspaceParticipant; import com.worlabel.domain.participant.entity.dto.ParticipantRequest; import com.worlabel.domain.participant.repository.ParticipantRepository; import com.worlabel.domain.participant.repository.WorkspaceParticipantRepository; -import com.worlabel.domain.project.dto.AiRequestDto; -import com.worlabel.domain.project.dto.AiRequestDto.AutoLabelingImageRequest; -import com.worlabel.domain.project.dto.AiRequestDto.AutoLabelingRequest; +import com.worlabel.domain.project.dto.AiDto; +import com.worlabel.domain.project.dto.AiDto.AutoLabelingImageRequest; +import com.worlabel.domain.project.dto.AiDto.AutoLabelingRequest; +import com.worlabel.domain.project.dto.AiDto.AutoLabelingResult; import com.worlabel.domain.project.dto.AutoModelRequest; import com.worlabel.domain.project.entity.Project; import com.worlabel.domain.project.entity.dto.ProjectMemberResponse; @@ -30,6 +37,8 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import java.lang.reflect.Type; +import java.util.HashMap; import java.util.List; import java.util.Objects; @@ -45,8 +54,9 @@ public class ProjectService { private final MemberRepository memberRepository; private final WorkspaceParticipantRepository workspaceParticipantRepository; private final ImageRepository imageRepository; - + private final AiModelRepository aiModelRepository; private final AiRequestService aiService; + private final Gson gson; public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) { Workspace workspace = getWorkspace(memberId, workspaceId); @@ -135,43 +145,64 @@ public class ProjectService { participantRepository.delete(participant); } -// @CheckPrivilege(PrivilegeType.EDITOR) -// public void train(final Integer memberId, final Integer projectId) { -// // TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정 -// /* -// 없으면 redis 상태 테이블을 만든다. progressTable -// */ -// -// // FastAPI 서버로 학습 요청을 전송 -// Project project = getProject(projectId); -// String endPoint = project.getProjectType().getValue() + "/train"; -// -// TrainRequest trainRequest = new TrainRequest(); -// trainRequest.setProjectId(projectId); -// trainRequest.setData(List.of()); -// -// // FastAPI 서버로 POST 요청 전송 -// String modelKey = aiService.postRequest(endPoint, trainRequest, String.class, response -> response); -// -// // TODO: 모델 생성 후 Default 이름과 Key 값 설정 -// } - /** * 프로젝트별 오토 레이블링 */ @CheckPrivilege(PrivilegeType.EDITOR) public void autoLabeling(final Integer memberId, final Integer projectId, final AutoModelRequest request) { + log.debug("project {}", projectId); Project project = getProject(projectId); String endPoint = project.getProjectType().getValue() + "/predict"; + log.debug("이미지 조회"); List imageList = imageRepository.findImagesByProjectIdAndPendingOrInProgress(projectId); List imageRequestList = imageList.stream() .map(AutoLabelingImageRequest::of) .toList(); - AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, imageRequestList); + + log.debug("카테고리 조회 "); + HashMap labelMap = getLabelMap(project); + + log.debug("모델 조회"); + AiModel aiModel = getAiModel(request); + AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList); // 응답없음 - aiService.postRequest(endPoint, autoLabelingRequest, Void.class, response -> null); + log.debug("요청"); + List list = aiService.postRequest(endPoint, autoLabelingRequest, List.class, this::converter); + log.debug("list: {}", list); + } + + public List converter(String data) { + try { + log.debug("data :{}", data); + // Gson에서 리스트 형태로 변환할 타입을 지정 + Type listType = new TypeToken>() { + }.getType(); + + // JSON 배열을 List로 변환 + return gson.fromJson(data, listType); + } catch (JsonSyntaxException e) { + // JSON 파싱 중 오류가 발생한 경우 처리 + throw new CustomException(ErrorCode.AI_SERVER_ERROR); + } + } + + private AiModel getAiModel(AutoModelRequest request) { + return aiModelRepository.findById(request.getModelId()) + .orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND)); + } + + private HashMap getLabelMap(Project project) { + HashMap labelMap = new HashMap<>(); + List category = project.getCategoryList(); + for (ProjectCategory projectCategory : category) { + int aiId = projectCategory.getLabelCategory().getAiCategoryId(); + if (labelMap.containsKey(aiId)) continue; + + labelMap.put(aiId, projectCategory.getId()); + } + return labelMap; } private Project getProject(final Integer projectId) {