Feat: 이미지 오토레이블링 AI 연동
This commit is contained in:
parent
e8525b97b4
commit
336f02834e
@ -16,7 +16,7 @@ import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse;
|
|||||||
import com.worlabel.domain.model.entity.dto.DefaultResponse;
|
import com.worlabel.domain.model.entity.dto.DefaultResponse;
|
||||||
import com.worlabel.domain.model.repository.AiModelRepository;
|
import com.worlabel.domain.model.repository.AiModelRepository;
|
||||||
import com.worlabel.domain.participant.entity.PrivilegeType;
|
import com.worlabel.domain.participant.entity.PrivilegeType;
|
||||||
import com.worlabel.domain.project.dto.AiRequestDto;
|
import com.worlabel.domain.project.dto.AiDto;
|
||||||
import com.worlabel.domain.project.entity.Project;
|
import com.worlabel.domain.project.entity.Project;
|
||||||
import com.worlabel.domain.project.repository.ProjectRepository;
|
import com.worlabel.domain.project.repository.ProjectRepository;
|
||||||
import com.worlabel.global.annotation.CheckPrivilege;
|
import com.worlabel.global.annotation.CheckPrivilege;
|
||||||
@ -139,13 +139,13 @@ public class AiModelService {
|
|||||||
|
|
||||||
List<Image> images = imageRepository.findImagesByProjectId(projectId);
|
List<Image> images = imageRepository.findImagesByProjectId(projectId);
|
||||||
|
|
||||||
List<AiRequestDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
|
List<AiDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
|
||||||
.map(image -> new AiRequestDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
|
.map(image -> new AiDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
|
||||||
.toList();
|
.toList();
|
||||||
|
|
||||||
String endPoint = project.getProjectType().getValue() + "/train";
|
String endPoint = project.getProjectType().getValue() + "/train";
|
||||||
|
|
||||||
AiRequestDto.TrainRequest trainRequest = new AiRequestDto.TrainRequest();
|
AiDto.TrainRequest trainRequest = new AiDto.TrainRequest();
|
||||||
trainRequest.setProjectId(projectId);
|
trainRequest.setProjectId(projectId);
|
||||||
trainRequest.setCategoryId(categories);
|
trainRequest.setCategoryId(categories);
|
||||||
trainRequest.setData(data);
|
trainRequest.setData(data);
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
package com.worlabel.domain.project.dto;
|
package com.worlabel.domain.project.dto;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import com.google.gson.annotations.SerializedName;
|
||||||
import com.worlabel.domain.image.entity.Image;
|
import com.worlabel.domain.image.entity.Image;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class AiRequestDto {
|
public class AiDto {
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public static class TrainDataInfo {
|
public static class TrainDataInfo {
|
||||||
@ -51,6 +53,9 @@ public class AiRequestDto {
|
|||||||
@JsonProperty("model_key")
|
@JsonProperty("model_key")
|
||||||
private String modelKey;
|
private String modelKey;
|
||||||
|
|
||||||
|
@JsonProperty("label_map")
|
||||||
|
private HashMap<Integer, Integer> labelMap;
|
||||||
|
|
||||||
@JsonProperty("image_list")
|
@JsonProperty("image_list")
|
||||||
private List<AutoLabelingImageRequest> imageList;
|
private List<AutoLabelingImageRequest> imageList;
|
||||||
|
|
||||||
@ -60,8 +65,8 @@ public class AiRequestDto {
|
|||||||
@JsonProperty("iou_threshold")
|
@JsonProperty("iou_threshold")
|
||||||
private Double iouThreshold;
|
private Double iouThreshold;
|
||||||
|
|
||||||
public static AutoLabelingRequest of(final Integer projectId, final List<AutoLabelingImageRequest> imageList) {
|
public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap<Integer, Integer> labelMap, final List<AutoLabelingImageRequest> imageList) {
|
||||||
return new AutoLabelingRequest(projectId, null, imageList, 0.25, 0.45);
|
return new AutoLabelingRequest(projectId, modelKey, labelMap, imageList, 0.25, 0.45);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,4 +85,17 @@ public class AiRequestDto {
|
|||||||
return new AutoLabelingImageRequest(image.getId(), image.getImagePath());
|
return new AutoLabelingImageRequest(image.getId(), image.getImagePath());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
|
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
|
@Getter
|
||||||
|
@ToString
|
||||||
|
public static class AutoLabelingResult{
|
||||||
|
|
||||||
|
@SerializedName("image_id")
|
||||||
|
private Long imageId;
|
||||||
|
|
||||||
|
@SerializedName("data")
|
||||||
|
private String data;
|
||||||
|
}
|
||||||
}
|
}
|
@ -52,7 +52,7 @@ public class Project extends BaseEntity {
|
|||||||
* 프로젝트에 속한 카테고리
|
* 프로젝트에 속한 카테고리
|
||||||
*/
|
*/
|
||||||
@OneToMany(mappedBy = "project", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true)
|
@OneToMany(mappedBy = "project", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true)
|
||||||
private List<ProjectCategory> category = new ArrayList<>();
|
private List<ProjectCategory> categoryList = new ArrayList<>();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 프로젝트에 속한 모델
|
* 프로젝트에 속한 모델
|
||||||
|
@ -27,8 +27,4 @@ public interface ProjectRepository extends JpaRepository<Project, Integer> {
|
|||||||
@Param("memberId") Integer memberId,
|
@Param("memberId") Integer memberId,
|
||||||
@Param("lastProjectId") Integer lastProjectId,
|
@Param("lastProjectId") Integer lastProjectId,
|
||||||
@Param("pageSize") Integer pageSize);
|
@Param("pageSize") Integer pageSize);
|
||||||
|
|
||||||
// ProjectType을 가져오는 메서드 추가
|
|
||||||
@Query("SELECT p.projectType FROM Project p WHERE p.id = :projectId")
|
|
||||||
Optional<ProjectType> findProjectTypeById(@Param("projectId") Integer projectId);
|
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,25 @@
|
|||||||
package com.worlabel.domain.project.service;
|
package com.worlabel.domain.project.service;
|
||||||
|
|
||||||
|
import com.google.gson.Gson;
|
||||||
|
import com.google.gson.JsonSyntaxException;
|
||||||
|
import com.google.gson.reflect.TypeToken;
|
||||||
import com.worlabel.domain.image.entity.Image;
|
import com.worlabel.domain.image.entity.Image;
|
||||||
import com.worlabel.domain.image.repository.ImageRepository;
|
import com.worlabel.domain.image.repository.ImageRepository;
|
||||||
|
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
|
||||||
import com.worlabel.domain.member.entity.Member;
|
import com.worlabel.domain.member.entity.Member;
|
||||||
import com.worlabel.domain.member.repository.MemberRepository;
|
import com.worlabel.domain.member.repository.MemberRepository;
|
||||||
|
import com.worlabel.domain.model.entity.AiModel;
|
||||||
|
import com.worlabel.domain.model.repository.AiModelRepository;
|
||||||
import com.worlabel.domain.participant.entity.Participant;
|
import com.worlabel.domain.participant.entity.Participant;
|
||||||
import com.worlabel.domain.participant.entity.PrivilegeType;
|
import com.worlabel.domain.participant.entity.PrivilegeType;
|
||||||
import com.worlabel.domain.participant.entity.WorkspaceParticipant;
|
import com.worlabel.domain.participant.entity.WorkspaceParticipant;
|
||||||
import com.worlabel.domain.participant.entity.dto.ParticipantRequest;
|
import com.worlabel.domain.participant.entity.dto.ParticipantRequest;
|
||||||
import com.worlabel.domain.participant.repository.ParticipantRepository;
|
import com.worlabel.domain.participant.repository.ParticipantRepository;
|
||||||
import com.worlabel.domain.participant.repository.WorkspaceParticipantRepository;
|
import com.worlabel.domain.participant.repository.WorkspaceParticipantRepository;
|
||||||
import com.worlabel.domain.project.dto.AiRequestDto;
|
import com.worlabel.domain.project.dto.AiDto;
|
||||||
import com.worlabel.domain.project.dto.AiRequestDto.AutoLabelingImageRequest;
|
import com.worlabel.domain.project.dto.AiDto.AutoLabelingImageRequest;
|
||||||
import com.worlabel.domain.project.dto.AiRequestDto.AutoLabelingRequest;
|
import com.worlabel.domain.project.dto.AiDto.AutoLabelingRequest;
|
||||||
|
import com.worlabel.domain.project.dto.AiDto.AutoLabelingResult;
|
||||||
import com.worlabel.domain.project.dto.AutoModelRequest;
|
import com.worlabel.domain.project.dto.AutoModelRequest;
|
||||||
import com.worlabel.domain.project.entity.Project;
|
import com.worlabel.domain.project.entity.Project;
|
||||||
import com.worlabel.domain.project.entity.dto.ProjectMemberResponse;
|
import com.worlabel.domain.project.entity.dto.ProjectMemberResponse;
|
||||||
@ -30,6 +37,8 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
import java.lang.reflect.Type;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
@ -45,8 +54,9 @@ public class ProjectService {
|
|||||||
private final MemberRepository memberRepository;
|
private final MemberRepository memberRepository;
|
||||||
private final WorkspaceParticipantRepository workspaceParticipantRepository;
|
private final WorkspaceParticipantRepository workspaceParticipantRepository;
|
||||||
private final ImageRepository imageRepository;
|
private final ImageRepository imageRepository;
|
||||||
|
private final AiModelRepository aiModelRepository;
|
||||||
private final AiRequestService aiService;
|
private final AiRequestService aiService;
|
||||||
|
private final Gson gson;
|
||||||
|
|
||||||
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) {
|
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) {
|
||||||
Workspace workspace = getWorkspace(memberId, workspaceId);
|
Workspace workspace = getWorkspace(memberId, workspaceId);
|
||||||
@ -135,43 +145,64 @@ public class ProjectService {
|
|||||||
participantRepository.delete(participant);
|
participantRepository.delete(participant);
|
||||||
}
|
}
|
||||||
|
|
||||||
// @CheckPrivilege(PrivilegeType.EDITOR)
|
|
||||||
// public void train(final Integer memberId, final Integer projectId) {
|
|
||||||
// // TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정
|
|
||||||
// /*
|
|
||||||
// 없으면 redis 상태 테이블을 만든다. progressTable
|
|
||||||
// */
|
|
||||||
//
|
|
||||||
// // FastAPI 서버로 학습 요청을 전송
|
|
||||||
// Project project = getProject(projectId);
|
|
||||||
// String endPoint = project.getProjectType().getValue() + "/train";
|
|
||||||
//
|
|
||||||
// TrainRequest trainRequest = new TrainRequest();
|
|
||||||
// trainRequest.setProjectId(projectId);
|
|
||||||
// trainRequest.setData(List.of());
|
|
||||||
//
|
|
||||||
// // FastAPI 서버로 POST 요청 전송
|
|
||||||
// String modelKey = aiService.postRequest(endPoint, trainRequest, String.class, response -> response);
|
|
||||||
//
|
|
||||||
// // TODO: 모델 생성 후 Default 이름과 Key 값 설정
|
|
||||||
// }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 프로젝트별 오토 레이블링
|
* 프로젝트별 오토 레이블링
|
||||||
*/
|
*/
|
||||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||||
public void autoLabeling(final Integer memberId, final Integer projectId, final AutoModelRequest request) {
|
public void autoLabeling(final Integer memberId, final Integer projectId, final AutoModelRequest request) {
|
||||||
|
log.debug("project {}", projectId);
|
||||||
Project project = getProject(projectId);
|
Project project = getProject(projectId);
|
||||||
String endPoint = project.getProjectType().getValue() + "/predict";
|
String endPoint = project.getProjectType().getValue() + "/predict";
|
||||||
|
|
||||||
|
log.debug("이미지 조회");
|
||||||
List<Image> imageList = imageRepository.findImagesByProjectIdAndPendingOrInProgress(projectId);
|
List<Image> imageList = imageRepository.findImagesByProjectIdAndPendingOrInProgress(projectId);
|
||||||
List<AutoLabelingImageRequest> imageRequestList = imageList.stream()
|
List<AutoLabelingImageRequest> imageRequestList = imageList.stream()
|
||||||
.map(AutoLabelingImageRequest::of)
|
.map(AutoLabelingImageRequest::of)
|
||||||
.toList();
|
.toList();
|
||||||
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, imageRequestList);
|
|
||||||
|
log.debug("카테고리 조회 ");
|
||||||
|
HashMap<Integer, Integer> labelMap = getLabelMap(project);
|
||||||
|
|
||||||
|
log.debug("모델 조회");
|
||||||
|
AiModel aiModel = getAiModel(request);
|
||||||
|
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList);
|
||||||
|
|
||||||
// 응답없음
|
// 응답없음
|
||||||
aiService.postRequest(endPoint, autoLabelingRequest, Void.class, response -> null);
|
log.debug("요청");
|
||||||
|
List<AutoLabelingResult> list = aiService.postRequest(endPoint, autoLabelingRequest, List.class, this::converter);
|
||||||
|
log.debug("list: {}", list);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<AutoLabelingResult> converter(String data) {
|
||||||
|
try {
|
||||||
|
log.debug("data :{}", data);
|
||||||
|
// Gson에서 리스트 형태로 변환할 타입을 지정
|
||||||
|
Type listType = new TypeToken<List<AutoLabelingResult>>() {
|
||||||
|
}.getType();
|
||||||
|
|
||||||
|
// JSON 배열을 List<AutoLabelingResult>로 변환
|
||||||
|
return gson.fromJson(data, listType);
|
||||||
|
} catch (JsonSyntaxException e) {
|
||||||
|
// JSON 파싱 중 오류가 발생한 경우 처리
|
||||||
|
throw new CustomException(ErrorCode.AI_SERVER_ERROR);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private AiModel getAiModel(AutoModelRequest request) {
|
||||||
|
return aiModelRepository.findById(request.getModelId())
|
||||||
|
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||||
|
}
|
||||||
|
|
||||||
|
private HashMap<Integer, Integer> getLabelMap(Project project) {
|
||||||
|
HashMap<Integer, Integer> labelMap = new HashMap<>();
|
||||||
|
List<ProjectCategory> category = project.getCategoryList();
|
||||||
|
for (ProjectCategory projectCategory : category) {
|
||||||
|
int aiId = projectCategory.getLabelCategory().getAiCategoryId();
|
||||||
|
if (labelMap.containsKey(aiId)) continue;
|
||||||
|
|
||||||
|
labelMap.put(aiId, projectCategory.getId());
|
||||||
|
}
|
||||||
|
return labelMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Project getProject(final Integer projectId) {
|
private Project getProject(final Integer projectId) {
|
||||||
|
Loading…
Reference in New Issue
Block a user