Feat: 이미지 오토레이블링 AI 연동
This commit is contained in:
parent
e8525b97b4
commit
336f02834e
@ -16,7 +16,7 @@ import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse;
|
||||
import com.worlabel.domain.model.entity.dto.DefaultResponse;
|
||||
import com.worlabel.domain.model.repository.AiModelRepository;
|
||||
import com.worlabel.domain.participant.entity.PrivilegeType;
|
||||
import com.worlabel.domain.project.dto.AiRequestDto;
|
||||
import com.worlabel.domain.project.dto.AiDto;
|
||||
import com.worlabel.domain.project.entity.Project;
|
||||
import com.worlabel.domain.project.repository.ProjectRepository;
|
||||
import com.worlabel.global.annotation.CheckPrivilege;
|
||||
@ -139,13 +139,13 @@ public class AiModelService {
|
||||
|
||||
List<Image> images = imageRepository.findImagesByProjectId(projectId);
|
||||
|
||||
List<AiRequestDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
|
||||
.map(image -> new AiRequestDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
|
||||
List<AiDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
|
||||
.map(image -> new AiDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
|
||||
.toList();
|
||||
|
||||
String endPoint = project.getProjectType().getValue() + "/train";
|
||||
|
||||
AiRequestDto.TrainRequest trainRequest = new AiRequestDto.TrainRequest();
|
||||
AiDto.TrainRequest trainRequest = new AiDto.TrainRequest();
|
||||
trainRequest.setProjectId(projectId);
|
||||
trainRequest.setCategoryId(categories);
|
||||
trainRequest.setData(data);
|
||||
|
@ -1,12 +1,14 @@
|
||||
package com.worlabel.domain.project.dto;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.google.gson.annotations.SerializedName;
|
||||
import com.worlabel.domain.image.entity.Image;
|
||||
import lombok.*;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
public class AiRequestDto {
|
||||
public class AiDto {
|
||||
|
||||
@Data
|
||||
public static class TrainDataInfo {
|
||||
@ -51,6 +53,9 @@ public class AiRequestDto {
|
||||
@JsonProperty("model_key")
|
||||
private String modelKey;
|
||||
|
||||
@JsonProperty("label_map")
|
||||
private HashMap<Integer, Integer> labelMap;
|
||||
|
||||
@JsonProperty("image_list")
|
||||
private List<AutoLabelingImageRequest> imageList;
|
||||
|
||||
@ -60,8 +65,8 @@ public class AiRequestDto {
|
||||
@JsonProperty("iou_threshold")
|
||||
private Double iouThreshold;
|
||||
|
||||
public static AutoLabelingRequest of(final Integer projectId, final List<AutoLabelingImageRequest> imageList) {
|
||||
return new AutoLabelingRequest(projectId, null, imageList, 0.25, 0.45);
|
||||
public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap<Integer, Integer> labelMap, final List<AutoLabelingImageRequest> imageList) {
|
||||
return new AutoLabelingRequest(projectId, modelKey, labelMap, imageList, 0.25, 0.45);
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,8 +81,21 @@ public class AiRequestDto {
|
||||
@JsonProperty("image_url")
|
||||
private String imageUrl;
|
||||
|
||||
public static AutoLabelingImageRequest of(Image image){
|
||||
public static AutoLabelingImageRequest of(Image image) {
|
||||
return new AutoLabelingImageRequest(image.getId(), image.getImagePath());
|
||||
}
|
||||
}
|
||||
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@Getter
|
||||
@ToString
|
||||
public static class AutoLabelingResult{
|
||||
|
||||
@SerializedName("image_id")
|
||||
private Long imageId;
|
||||
|
||||
@SerializedName("data")
|
||||
private String data;
|
||||
}
|
||||
}
|
@ -52,7 +52,7 @@ public class Project extends BaseEntity {
|
||||
* 프로젝트에 속한 카테고리
|
||||
*/
|
||||
@OneToMany(mappedBy = "project", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true)
|
||||
private List<ProjectCategory> category = new ArrayList<>();
|
||||
private List<ProjectCategory> categoryList = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* 프로젝트에 속한 모델
|
||||
|
@ -27,8 +27,4 @@ public interface ProjectRepository extends JpaRepository<Project, Integer> {
|
||||
@Param("memberId") Integer memberId,
|
||||
@Param("lastProjectId") Integer lastProjectId,
|
||||
@Param("pageSize") Integer pageSize);
|
||||
|
||||
// ProjectType을 가져오는 메서드 추가
|
||||
@Query("SELECT p.projectType FROM Project p WHERE p.id = :projectId")
|
||||
Optional<ProjectType> findProjectTypeById(@Param("projectId") Integer projectId);
|
||||
}
|
||||
|
@ -1,18 +1,25 @@
|
||||
package com.worlabel.domain.project.service;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.JsonSyntaxException;
|
||||
import com.google.gson.reflect.TypeToken;
|
||||
import com.worlabel.domain.image.entity.Image;
|
||||
import com.worlabel.domain.image.repository.ImageRepository;
|
||||
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
|
||||
import com.worlabel.domain.member.entity.Member;
|
||||
import com.worlabel.domain.member.repository.MemberRepository;
|
||||
import com.worlabel.domain.model.entity.AiModel;
|
||||
import com.worlabel.domain.model.repository.AiModelRepository;
|
||||
import com.worlabel.domain.participant.entity.Participant;
|
||||
import com.worlabel.domain.participant.entity.PrivilegeType;
|
||||
import com.worlabel.domain.participant.entity.WorkspaceParticipant;
|
||||
import com.worlabel.domain.participant.entity.dto.ParticipantRequest;
|
||||
import com.worlabel.domain.participant.repository.ParticipantRepository;
|
||||
import com.worlabel.domain.participant.repository.WorkspaceParticipantRepository;
|
||||
import com.worlabel.domain.project.dto.AiRequestDto;
|
||||
import com.worlabel.domain.project.dto.AiRequestDto.AutoLabelingImageRequest;
|
||||
import com.worlabel.domain.project.dto.AiRequestDto.AutoLabelingRequest;
|
||||
import com.worlabel.domain.project.dto.AiDto;
|
||||
import com.worlabel.domain.project.dto.AiDto.AutoLabelingImageRequest;
|
||||
import com.worlabel.domain.project.dto.AiDto.AutoLabelingRequest;
|
||||
import com.worlabel.domain.project.dto.AiDto.AutoLabelingResult;
|
||||
import com.worlabel.domain.project.dto.AutoModelRequest;
|
||||
import com.worlabel.domain.project.entity.Project;
|
||||
import com.worlabel.domain.project.entity.dto.ProjectMemberResponse;
|
||||
@ -30,6 +37,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
@ -45,8 +54,9 @@ public class ProjectService {
|
||||
private final MemberRepository memberRepository;
|
||||
private final WorkspaceParticipantRepository workspaceParticipantRepository;
|
||||
private final ImageRepository imageRepository;
|
||||
|
||||
private final AiModelRepository aiModelRepository;
|
||||
private final AiRequestService aiService;
|
||||
private final Gson gson;
|
||||
|
||||
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) {
|
||||
Workspace workspace = getWorkspace(memberId, workspaceId);
|
||||
@ -135,43 +145,64 @@ public class ProjectService {
|
||||
participantRepository.delete(participant);
|
||||
}
|
||||
|
||||
// @CheckPrivilege(PrivilegeType.EDITOR)
|
||||
// public void train(final Integer memberId, final Integer projectId) {
|
||||
// // TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정
|
||||
// /*
|
||||
// 없으면 redis 상태 테이블을 만든다. progressTable
|
||||
// */
|
||||
//
|
||||
// // FastAPI 서버로 학습 요청을 전송
|
||||
// Project project = getProject(projectId);
|
||||
// String endPoint = project.getProjectType().getValue() + "/train";
|
||||
//
|
||||
// TrainRequest trainRequest = new TrainRequest();
|
||||
// trainRequest.setProjectId(projectId);
|
||||
// trainRequest.setData(List.of());
|
||||
//
|
||||
// // FastAPI 서버로 POST 요청 전송
|
||||
// String modelKey = aiService.postRequest(endPoint, trainRequest, String.class, response -> response);
|
||||
//
|
||||
// // TODO: 모델 생성 후 Default 이름과 Key 값 설정
|
||||
// }
|
||||
|
||||
/**
|
||||
* 프로젝트별 오토 레이블링
|
||||
*/
|
||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||
public void autoLabeling(final Integer memberId, final Integer projectId, final AutoModelRequest request) {
|
||||
log.debug("project {}", projectId);
|
||||
Project project = getProject(projectId);
|
||||
String endPoint = project.getProjectType().getValue() + "/predict";
|
||||
|
||||
log.debug("이미지 조회");
|
||||
List<Image> imageList = imageRepository.findImagesByProjectIdAndPendingOrInProgress(projectId);
|
||||
List<AutoLabelingImageRequest> imageRequestList = imageList.stream()
|
||||
.map(AutoLabelingImageRequest::of)
|
||||
.toList();
|
||||
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, imageRequestList);
|
||||
|
||||
log.debug("카테고리 조회 ");
|
||||
HashMap<Integer, Integer> labelMap = getLabelMap(project);
|
||||
|
||||
log.debug("모델 조회");
|
||||
AiModel aiModel = getAiModel(request);
|
||||
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList);
|
||||
|
||||
// 응답없음
|
||||
aiService.postRequest(endPoint, autoLabelingRequest, Void.class, response -> null);
|
||||
log.debug("요청");
|
||||
List<AutoLabelingResult> list = aiService.postRequest(endPoint, autoLabelingRequest, List.class, this::converter);
|
||||
log.debug("list: {}", list);
|
||||
}
|
||||
|
||||
public List<AutoLabelingResult> converter(String data) {
|
||||
try {
|
||||
log.debug("data :{}", data);
|
||||
// Gson에서 리스트 형태로 변환할 타입을 지정
|
||||
Type listType = new TypeToken<List<AutoLabelingResult>>() {
|
||||
}.getType();
|
||||
|
||||
// JSON 배열을 List<AutoLabelingResult>로 변환
|
||||
return gson.fromJson(data, listType);
|
||||
} catch (JsonSyntaxException e) {
|
||||
// JSON 파싱 중 오류가 발생한 경우 처리
|
||||
throw new CustomException(ErrorCode.AI_SERVER_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
private AiModel getAiModel(AutoModelRequest request) {
|
||||
return aiModelRepository.findById(request.getModelId())
|
||||
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||
}
|
||||
|
||||
private HashMap<Integer, Integer> getLabelMap(Project project) {
|
||||
HashMap<Integer, Integer> labelMap = new HashMap<>();
|
||||
List<ProjectCategory> category = project.getCategoryList();
|
||||
for (ProjectCategory projectCategory : category) {
|
||||
int aiId = projectCategory.getLabelCategory().getAiCategoryId();
|
||||
if (labelMap.containsKey(aiId)) continue;
|
||||
|
||||
labelMap.put(aiId, projectCategory.getId());
|
||||
}
|
||||
return labelMap;
|
||||
}
|
||||
|
||||
private Project getProject(final Integer projectId) {
|
||||
|
Loading…
Reference in New Issue
Block a user