Feat: 이미지 오토레이블링 AI 연동

This commit is contained in:
김용수 2024-09-23 16:18:13 +09:00
parent e8525b97b4
commit 336f02834e
5 changed files with 85 additions and 40 deletions

View File

@ -16,7 +16,7 @@ import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse;
import com.worlabel.domain.model.entity.dto.DefaultResponse;
import com.worlabel.domain.model.repository.AiModelRepository;
import com.worlabel.domain.participant.entity.PrivilegeType;
import com.worlabel.domain.project.dto.AiRequestDto;
import com.worlabel.domain.project.dto.AiDto;
import com.worlabel.domain.project.entity.Project;
import com.worlabel.domain.project.repository.ProjectRepository;
import com.worlabel.global.annotation.CheckPrivilege;
@ -139,13 +139,13 @@ public class AiModelService {
List<Image> images = imageRepository.findImagesByProjectId(projectId);
List<AiRequestDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
.map(image -> new AiRequestDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
List<AiDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
.map(image -> new AiDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
.toList();
String endPoint = project.getProjectType().getValue() + "/train";
AiRequestDto.TrainRequest trainRequest = new AiRequestDto.TrainRequest();
AiDto.TrainRequest trainRequest = new AiDto.TrainRequest();
trainRequest.setProjectId(projectId);
trainRequest.setCategoryId(categories);
trainRequest.setData(data);

View File

@ -1,12 +1,14 @@
package com.worlabel.domain.project.dto;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.gson.annotations.SerializedName;
import com.worlabel.domain.image.entity.Image;
import lombok.*;
import java.util.HashMap;
import java.util.List;
public class AiRequestDto {
public class AiDto {
@Data
public static class TrainDataInfo {
@ -51,6 +53,9 @@ public class AiRequestDto {
@JsonProperty("model_key")
private String modelKey;
@JsonProperty("label_map")
private HashMap<Integer, Integer> labelMap;
@JsonProperty("image_list")
private List<AutoLabelingImageRequest> imageList;
@ -60,8 +65,8 @@ public class AiRequestDto {
@JsonProperty("iou_threshold")
private Double iouThreshold;
public static AutoLabelingRequest of(final Integer projectId, final List<AutoLabelingImageRequest> imageList) {
return new AutoLabelingRequest(projectId, null, imageList, 0.25, 0.45);
public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap<Integer, Integer> labelMap, final List<AutoLabelingImageRequest> imageList) {
return new AutoLabelingRequest(projectId, modelKey, labelMap, imageList, 0.25, 0.45);
}
}
@ -80,4 +85,17 @@ public class AiRequestDto {
return new AutoLabelingImageRequest(image.getId(), image.getImagePath());
}
}
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@Getter
@ToString
public static class AutoLabelingResult{
@SerializedName("image_id")
private Long imageId;
@SerializedName("data")
private String data;
}
}

View File

@ -52,7 +52,7 @@ public class Project extends BaseEntity {
* 프로젝트에 속한 카테고리
*/
@OneToMany(mappedBy = "project", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true)
private List<ProjectCategory> category = new ArrayList<>();
private List<ProjectCategory> categoryList = new ArrayList<>();
/**
* 프로젝트에 속한 모델

View File

@ -27,8 +27,4 @@ public interface ProjectRepository extends JpaRepository<Project, Integer> {
@Param("memberId") Integer memberId,
@Param("lastProjectId") Integer lastProjectId,
@Param("pageSize") Integer pageSize);
// ProjectType을 가져오는 메서드 추가
@Query("SELECT p.projectType FROM Project p WHERE p.id = :projectId")
Optional<ProjectType> findProjectTypeById(@Param("projectId") Integer projectId);
}

View File

@ -1,18 +1,25 @@
package com.worlabel.domain.project.service;
import com.google.gson.Gson;
import com.google.gson.JsonSyntaxException;
import com.google.gson.reflect.TypeToken;
import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.image.repository.ImageRepository;
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
import com.worlabel.domain.member.entity.Member;
import com.worlabel.domain.member.repository.MemberRepository;
import com.worlabel.domain.model.entity.AiModel;
import com.worlabel.domain.model.repository.AiModelRepository;
import com.worlabel.domain.participant.entity.Participant;
import com.worlabel.domain.participant.entity.PrivilegeType;
import com.worlabel.domain.participant.entity.WorkspaceParticipant;
import com.worlabel.domain.participant.entity.dto.ParticipantRequest;
import com.worlabel.domain.participant.repository.ParticipantRepository;
import com.worlabel.domain.participant.repository.WorkspaceParticipantRepository;
import com.worlabel.domain.project.dto.AiRequestDto;
import com.worlabel.domain.project.dto.AiRequestDto.AutoLabelingImageRequest;
import com.worlabel.domain.project.dto.AiRequestDto.AutoLabelingRequest;
import com.worlabel.domain.project.dto.AiDto;
import com.worlabel.domain.project.dto.AiDto.AutoLabelingImageRequest;
import com.worlabel.domain.project.dto.AiDto.AutoLabelingRequest;
import com.worlabel.domain.project.dto.AiDto.AutoLabelingResult;
import com.worlabel.domain.project.dto.AutoModelRequest;
import com.worlabel.domain.project.entity.Project;
import com.worlabel.domain.project.entity.dto.ProjectMemberResponse;
@ -30,6 +37,8 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.lang.reflect.Type;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
@ -45,8 +54,9 @@ public class ProjectService {
private final MemberRepository memberRepository;
private final WorkspaceParticipantRepository workspaceParticipantRepository;
private final ImageRepository imageRepository;
private final AiModelRepository aiModelRepository;
private final AiRequestService aiService;
private final Gson gson;
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) {
Workspace workspace = getWorkspace(memberId, workspaceId);
@ -135,43 +145,64 @@ public class ProjectService {
participantRepository.delete(participant);
}
// @CheckPrivilege(PrivilegeType.EDITOR)
// public void train(final Integer memberId, final Integer projectId) {
// // TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정
// /*
// 없으면 redis 상태 테이블을 만든다. progressTable
// */
//
// // FastAPI 서버로 학습 요청을 전송
// Project project = getProject(projectId);
// String endPoint = project.getProjectType().getValue() + "/train";
//
// TrainRequest trainRequest = new TrainRequest();
// trainRequest.setProjectId(projectId);
// trainRequest.setData(List.of());
//
// // FastAPI 서버로 POST 요청 전송
// String modelKey = aiService.postRequest(endPoint, trainRequest, String.class, response -> response);
//
// // TODO: 모델 생성 Default 이름과 Key 설정
// }
/**
* 프로젝트별 오토 레이블링
*/
@CheckPrivilege(PrivilegeType.EDITOR)
public void autoLabeling(final Integer memberId, final Integer projectId, final AutoModelRequest request) {
log.debug("project {}", projectId);
Project project = getProject(projectId);
String endPoint = project.getProjectType().getValue() + "/predict";
log.debug("이미지 조회");
List<Image> imageList = imageRepository.findImagesByProjectIdAndPendingOrInProgress(projectId);
List<AutoLabelingImageRequest> imageRequestList = imageList.stream()
.map(AutoLabelingImageRequest::of)
.toList();
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, imageRequestList);
log.debug("카테고리 조회 ");
HashMap<Integer, Integer> labelMap = getLabelMap(project);
log.debug("모델 조회");
AiModel aiModel = getAiModel(request);
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList);
// 응답없음
aiService.postRequest(endPoint, autoLabelingRequest, Void.class, response -> null);
log.debug("요청");
List<AutoLabelingResult> list = aiService.postRequest(endPoint, autoLabelingRequest, List.class, this::converter);
log.debug("list: {}", list);
}
public List<AutoLabelingResult> converter(String data) {
try {
log.debug("data :{}", data);
// Gson에서 리스트 형태로 변환할 타입을 지정
Type listType = new TypeToken<List<AutoLabelingResult>>() {
}.getType();
// JSON 배열을 List<AutoLabelingResult> 변환
return gson.fromJson(data, listType);
} catch (JsonSyntaxException e) {
// JSON 파싱 오류가 발생한 경우 처리
throw new CustomException(ErrorCode.AI_SERVER_ERROR);
}
}
private AiModel getAiModel(AutoModelRequest request) {
return aiModelRepository.findById(request.getModelId())
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
private HashMap<Integer, Integer> getLabelMap(Project project) {
HashMap<Integer, Integer> labelMap = new HashMap<>();
List<ProjectCategory> category = project.getCategoryList();
for (ProjectCategory projectCategory : category) {
int aiId = projectCategory.getLabelCategory().getAiCategoryId();
if (labelMap.containsKey(aiId)) continue;
labelMap.put(aiId, projectCategory.getId());
}
return labelMap;
}
private Project getProject(final Integer projectId) {