Feat: 오토 레이블링 AI 연동

This commit is contained in:
김용수 2024-09-04 14:20:54 +09:00
parent 971c3fb6fd
commit bc8988296c
8 changed files with 142 additions and 60 deletions

View File

@ -5,11 +5,14 @@ import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
import java.util.List;
import java.util.Optional;
@Repository
public interface FolderRepository extends JpaRepository<Folder, Integer> {
List<Folder> findAllByProjectIdAndParentIsNull(Integer projectId);
Optional<Folder> findAllByProjectIdAndId(Integer projectId, Integer folderId);
boolean existsByIdAndProjectId(Integer folderId, Integer projectId);
}

View File

@ -33,7 +33,7 @@ public class FolderService {
Folder parent = null;
if (folderRequest.getParentId() != 0) {
parent = getFolder(folderRequest.getParentId());
parent = getFolder(folderRequest.getParentId(),projectId);
}
Folder folder = Folder.of(folderRequest.getTitle(), parent, project);
@ -53,7 +53,7 @@ public class FolderService {
if (folderId == 0) {
return FolderResponse.from(folderRepository.findAllByProjectIdAndParentIsNull(projectId));
} else {
return FolderResponse.from(getFolder(folderId));
return FolderResponse.from(getFolder(folderId,projectId));
}
}
@ -62,7 +62,7 @@ public class FolderService {
*/
public FolderResponse updateFolder(final Integer memberId, final Integer projectId, final Integer folderId, final FolderRequest updatedFolderRequest) {
checkUnauthorized(memberId, projectId);
Folder folder = getFolder(folderId);
Folder folder = getFolder(folderId,projectId);
Folder parentFolder = folderRepository.findById(updatedFolderRequest.getParentId())
.orElse(null);
@ -77,7 +77,7 @@ public class FolderService {
*/
public void deleteFolder(final Integer memberId, final Integer projectId, final Integer folderId) {
checkUnauthorized(memberId, projectId);
Folder folder = getFolder(folderId);
Folder folder = getFolder(folderId,projectId);
folderRepository.delete(folder);
}
@ -86,8 +86,8 @@ public class FolderService {
.orElseThrow(() -> new CustomException(ErrorCode.PROJECT_NOT_FOUND));
}
private Folder getFolder(final Integer folderId) {
return folderRepository.findById(folderId)
private Folder getFolder(final Integer folderId, final Integer projectId) {
return folderRepository.findAllByProjectIdAndId(projectId,folderId)
.orElseThrow(() -> new CustomException(ErrorCode.FOLDER_NOT_FOUND));
}

View File

@ -0,0 +1,26 @@
package com.worlabel.domain.label.entity.dto;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
import java.util.List;
@Getter
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class AutoLabelingRequest {
@JsonProperty("project_id")
private Integer projectId;
@JsonProperty("image_list")
private List<ImageRequest> imageList;
// private Double confThreshold
// private Double iouThreshold;
// List<?> classes
public static AutoLabelingRequest of(final Integer projectId, final List<ImageRequest> imageList) {
return new AutoLabelingRequest(projectId, imageList);
}
}

View File

@ -1,10 +1,17 @@
package com.worlabel.domain.label.entity.dto;
import lombok.Data;
import lombok.*;
@Data
@Getter
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class AutoLabelingResponse {
private String image_id;
private String title;
private Long imageId;
private String imageUrl;
private String data; // JSON 형식의 데이터를 그대로 저장
public static AutoLabelingResponse of(Long imageId, String title, String data) {
return new AutoLabelingResponse(imageId, title, data);
}
}

View File

@ -1,10 +1,9 @@
package com.worlabel.domain.label.entity.dto;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.project.entity.ProjectType;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
@ -18,17 +17,15 @@ public class ImageRequest {
@Schema(description = "이미지 PK", example = "2")
@NotEmpty(message = "이미지 PK를 입력하세요")
private Long id;
@JsonProperty("image_id")
private Long imageId;
@Schema(description = "이미지 url", example = "image.png")
@NotEmpty(message = "이미지 url을 입력하세요")
@JsonProperty("image_url")
private String imageUrl;
@Schema(description = "프로젝트 유형", example = "classification")
@NotNull(message = "카테고리를 입력하세요.")
private ProjectType projectType;
public static ImageRequest of(Image image, ProjectType projectType){
return new ImageRequest(image.getId(), image.getImageUrl(), projectType);
public static ImageRequest of(Image image){
return new ImageRequest(image.getId(), image.getImageUrl());
}
}

View File

@ -1,6 +1,8 @@
package com.worlabel.domain.label.service;
import com.google.gson.*;
import com.worlabel.domain.image.repository.ImageRepository;
import com.worlabel.domain.label.entity.dto.AutoLabelingRequest;
import com.worlabel.domain.label.entity.dto.AutoLabelingResponse;
import com.worlabel.domain.label.entity.dto.ImageRequest;
import com.worlabel.domain.participant.repository.ParticipantRepository;
@ -13,7 +15,6 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@ -21,7 +22,9 @@ import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
@Slf4j
@Service
@ -29,10 +32,13 @@ import java.util.List;
public class LabelService {
private final ParticipantRepository participantRepository;
private final RestTemplateBuilder restTemplateBuilder;
// private final RestTemplateBuilder restTemplateBuilder;
private final RestTemplate restTemplate;
private final ProjectRepository projectRepository;
private final S3UploadService s3UploadService;
private final ImageRepository imageRepository;
private final Gson gson;
/**
* AI SERVER 주소
@ -44,76 +50,109 @@ public class LabelService {
checkEditorExistParticipant(memberId, projectId);
ProjectType projectType = getType(projectId);
List<ImageRequest> imageRequestList = getImageRequestList(projectId, projectType);
log.debug("{}번 프로젝트 이미지 {} 진행 ", projectId, projectType);
List<AutoLabelingResponse> autoLabelingResponseList = sendRequestToApi(imageRequestList, projectType.getValue(), projectId);
List<ImageRequest> imageRequestList = getImageRequestList(projectId);
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, imageRequestList);
List<AutoLabelingResponse> autoLabelingResponseList = sendRequestToApi(autoLabelingRequest, projectType.getValue(), projectId);
}
private List<AutoLabelingResponse> sendRequestToApi(List<ImageRequest> imageRequestList, String apiEndpoint, int projectId) {
String url = aiServer + "/" + apiEndpoint;
// 요청 헤더 설정
HttpHeaders headers = new HttpHeaders();
headers.set("Content-Type", "application/json");
// 요청 본문 설정
HttpEntity<List<ImageRequest>> request = new HttpEntity<>(imageRequestList, headers);
private List<AutoLabelingResponse> sendRequestToApi(AutoLabelingRequest autoLabelingRequest, String apiEndpoint, int projectId) {
String url = createApiUrl("api/yolo/detection/predict");
// RestTemplate을 동적으로 생성하여 사용
RestTemplate restTemplate = restTemplateBuilder.build();
HttpHeaders headers = createJsonHeaders();
// 요청 본문 설정
HttpEntity<AutoLabelingRequest> request = new HttpEntity<>(autoLabelingRequest, headers);
try {
log.debug("요청 서버 : {}", url);
// AI 서버로 POST 요청
// TODO: 응답 추후 교체
ResponseEntity<List<AutoLabelingResponse>> response = restTemplate.exchange(
ResponseEntity<String> response = restTemplate.exchange(
url, // 요청을 보낼 URL
HttpMethod.POST, // HTTP 메서드 (POST)
request, // HTTP 요청 본문과 헤더가 포함된 객체
new ParameterizedTypeReference<List<AutoLabelingResponse>>() {
} // 응답 타입을 지정
String.class // 응답을 String 타입으로
);
String responseBody = Optional.ofNullable(response.getBody())
.orElseThrow(() -> new CustomException(ErrorCode.AI_SERVER_ERROR));
log.info("AI 서버 응답 -> {}", response.getBody());
// JSON 응답을 S3에 업로드
if(response.getBody() == null) {
throw new CustomException(ErrorCode.AI_SERVER_ERROR);
}
// if (response.getBody() != null) {
// for (AutoLabelingResponse autoLabelingResponse : response.getBody()) {
// String imageId = autoLabelingResponse.getImage_id();
// String jsonData = autoLabelingResponse.getData();
// String title = autoLabelingResponse.getTitle();
// if (imageId != null && jsonData != null) {
// // TODO: 받아온다면 s3에 업로드
// log.debug("구현 무리없이 잘 된 경우 :{}", autoLabelingResponse);
//// String jsonUrl = s3UploadService.uploadJson(jsonData, title, projectId);
//// DB에 저장해야한다. -> 레이블이 있다면 저장 없다면 생성 해야한다.
// }
// }
// // 곳에서 리턴 다른 곳에서 넣는게 코드가 깔끔해질 같다.
return response.getBody();
// }
return parseAutoLabelingResponseList(responseBody);
} catch (Exception e) {
log.error("AI 서버 요청 중 오류 발생: ", e);
throw new CustomException(ErrorCode.AI_SERVER_ERROR);
}
}
// TODO: N + 1문제 발생 추후 리팩토링해야합니다.
private List<ImageRequest> getImageRequestList(Integer projectId, ProjectType projectType) {
return imageRepository.findImagesByProjectId(projectId)
.stream().map(o -> ImageRequest.of(o, projectType)).toList();
private List<AutoLabelingResponse> parseAutoLabelingResponseList(String responseBody) {
JsonElement jsonElement = JsonParser.parseString(responseBody);
List<AutoLabelingResponse> autoLabelingResponseList = new ArrayList<>();
for (JsonElement element : jsonElement.getAsJsonArray()) {
AutoLabelingResponse response = parseAutoLabelingResponse(element);
autoLabelingResponseList.add(response);
}
return autoLabelingResponseList;
}
/**
* jsonElement -> AutoLabelingResponse
*/
private AutoLabelingResponse parseAutoLabelingResponse(JsonElement element) {
JsonObject jsonObject = element.getAsJsonObject();
Long imageId = jsonObject.get("image_id").getAsLong();
String imageUrl = jsonObject.get("image_url").getAsString();
JsonObject data = jsonObject.get("data").getAsJsonObject();
return AutoLabelingResponse.of(imageId,imageUrl, gson.toJson(data));
}
/**
* API URL 구성
*/
private String createApiUrl(String endPoint) {
return aiServer + "/" + endPoint;
}
/**
* 요청 헤더 설정
*/
private static HttpHeaders createJsonHeaders() {
HttpHeaders headers = new HttpHeaders();
headers.set("Content-Type", "application/json");
return headers;
}
// TODO: N + 1문제 발생 추후 리팩토링해야합니다.
private List<ImageRequest> getImageRequestList(Integer projectId) {
return imageRepository.findImagesByProjectId(projectId)
.stream().map(ImageRequest::of).toList();
}
/**
* 프로젝트 타입 조회
*/
private ProjectType getType(final Integer projectId) {
return projectRepository.findProjectTypeById(projectId)
.orElseThrow(() -> new CustomException(ErrorCode.PROJECT_NOT_FOUND));
}
/**
* 참여자(EDITOR, ADMIN) 검증 메서드
*/
private void checkEditorExistParticipant(final Integer memberId, final Integer projectId) {
if (participantRepository.doesParticipantUnauthorizedExistByMemberIdAndProjectId(memberId, projectId)) {
throw new CustomException(ErrorCode.PARTICIPANT_UNAUTHORIZED);
}
}
// TODO : 구현
public void save(final Integer imageId) {
}
}

View File

@ -64,8 +64,11 @@ public class SecurityConfig {
// 경로별 인가 작업
http
.authorizeHttpRequests(auth->auth
.requestMatchers("/swagger", "/swagger-ui.html", "/swagger-ui/**", "/api-docs", "/api-docs/**", "/v3/api-docs/**").permitAll()
.requestMatchers("/api/auth/reissue").permitAll()
.anyRequest().authenticated());
.anyRequest().authenticated()
// .anyRequest().permitAll()
);
// OAuth2
http

View File

@ -2,7 +2,9 @@ package com.worlabel.global.config;
import com.worlabel.global.resolver.CurrentUserArgumentResolver;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
@ -18,4 +20,9 @@ public class WebConfig implements WebMvcConfigurer {
public void addArgumentResolvers(List<HandlerMethodArgumentResolver> resolvers) {
resolvers.add(currentUserArgumentResolver);
}
@Bean
public RestTemplate restTemplate() {
return new RestTemplate();
}
}