From 7d310367d2c03a60870c83157a06a6ca3018b307 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=A0=95=ED=98=84=EC=A1=B0?= Date: Wed, 25 Sep 2024 03:44:31 +0900 Subject: [PATCH 1/3] =?UTF-8?q?Refactor:=20=EB=B3=80=EA=B2=BD=EB=90=9C=20a?= =?UTF-8?q?pi=EC=97=90=20=EB=A7=9E=EA=B2=8C=20=EB=B3=80=EA=B2=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/src/api/modelApi.ts | 22 ++++++- .../{index.tsx => ModelBarChart.tsx} | 0 .../{index.tsx => ModelLineChart.tsx} | 0 .../queries/models/useModelReportsQuery.ts | 10 +++ .../queries/models/useModelResultsQuery.ts | 10 +++ .../src/queries/models/useTrainModelQuery.ts | 3 +- frontend/src/types/index.ts | 66 +++++++++++++++---- 7 files changed, 93 insertions(+), 18 deletions(-) rename frontend/src/components/ModelBarChart/{index.tsx => ModelBarChart.tsx} (100%) rename frontend/src/components/ModelLineChart/{index.tsx => ModelLineChart.tsx} (100%) create mode 100644 frontend/src/queries/models/useModelReportsQuery.ts create mode 100644 frontend/src/queries/models/useModelResultsQuery.ts diff --git a/frontend/src/api/modelApi.ts b/frontend/src/api/modelApi.ts index 106b735..0f7e99e 100644 --- a/frontend/src/api/modelApi.ts +++ b/frontend/src/api/modelApi.ts @@ -1,12 +1,20 @@ import api from '@/api/axiosConfig'; -import { ModelRequest, ModelResponse, ProjectModelsResponse, ModelCategoryResponse } from '@/types'; +import { + ModelRequest, + ModelResponse, + ProjectModelsResponse, + ModelCategoryResponse, + ModelTrainRequest, + ResultResponse, + ReportResponse, +} from '@/types'; export async function updateModelName(projectId: number, modelId: number, modelData: ModelRequest) { return api.put(`/projects/${projectId}/models/${modelId}`, modelData).then(({ data }) => data); } -export async function trainModel(projectId: number) { - return api.post(`/projects/${projectId}/train`).then(({ data }) => data); +export async function trainModel(projectId: number, trainData: ModelTrainRequest) { + return api.post(`/projects/${projectId}/train`, trainData).then(({ data }) => data); } export async function getProjectModels(projectId: number) { @@ -20,3 +28,11 @@ export async function addProjectModel(projectId: number, modelData: ModelRequest export async function getModelCategories(modelId: number) { return api.get(`/models/${modelId}/categories`).then(({ data }) => data); } + +export async function getModelResults(modelId: number) { + return api.get(`/results/model/${modelId}`).then(({ data }) => data); +} + +export async function getModelReports(projectId: number, modelId: number) { + return api.get(`/projects/${projectId}/reports/model/${modelId}`).then(({ data }) => data); +} diff --git a/frontend/src/components/ModelBarChart/index.tsx b/frontend/src/components/ModelBarChart/ModelBarChart.tsx similarity index 100% rename from frontend/src/components/ModelBarChart/index.tsx rename to frontend/src/components/ModelBarChart/ModelBarChart.tsx diff --git a/frontend/src/components/ModelLineChart/index.tsx b/frontend/src/components/ModelLineChart/ModelLineChart.tsx similarity index 100% rename from frontend/src/components/ModelLineChart/index.tsx rename to frontend/src/components/ModelLineChart/ModelLineChart.tsx diff --git a/frontend/src/queries/models/useModelReportsQuery.ts b/frontend/src/queries/models/useModelReportsQuery.ts new file mode 100644 index 0000000..81f3d31 --- /dev/null +++ b/frontend/src/queries/models/useModelReportsQuery.ts @@ -0,0 +1,10 @@ +import { useSuspenseQuery } from '@tanstack/react-query'; +import { getModelReports } from '@/api/modelApi'; +import { ReportResponse } from '@/types'; + +export default function useModelReportsQuery(projectId: number, modelId: number) { + return useSuspenseQuery({ + queryKey: ['modelReports', projectId, modelId], + queryFn: () => getModelReports(projectId, modelId), + }); +} diff --git a/frontend/src/queries/models/useModelResultsQuery.ts b/frontend/src/queries/models/useModelResultsQuery.ts new file mode 100644 index 0000000..124520e --- /dev/null +++ b/frontend/src/queries/models/useModelResultsQuery.ts @@ -0,0 +1,10 @@ +import { useSuspenseQuery } from '@tanstack/react-query'; +import { getModelResults } from '@/api/modelApi'; +import { ResultResponse } from '@/types'; + +export default function useModelResultsQuery(modelId: number) { + return useSuspenseQuery({ + queryKey: ['modelResults', modelId], + queryFn: () => getModelResults(modelId), + }); +} diff --git a/frontend/src/queries/models/useTrainModelQuery.ts b/frontend/src/queries/models/useTrainModelQuery.ts index 822c7e3..4a44d22 100644 --- a/frontend/src/queries/models/useTrainModelQuery.ts +++ b/frontend/src/queries/models/useTrainModelQuery.ts @@ -1,8 +1,9 @@ import { useMutation } from '@tanstack/react-query'; import { trainModel } from '@/api/modelApi'; +import { ModelTrainRequest } from '@/types'; export default function useTrainModelQuery(projectId: number) { return useMutation({ - mutationFn: () => trainModel(projectId), + mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData), }); } diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index c0bc387..2ddbd9e 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -280,6 +280,25 @@ export interface ImageFolderRequest { parentId: number; files: File[]; } +export interface LabelCategoryResponse { + id: number; + name: string; +} +// 카테고리 요청 DTO +export interface LabelCategoryRequest { + labelCategoryList: number[]; +} + +// 카테고리 응답 DTO +export interface LabelCategoryResponse { + id: number; + name: string; +} +// 모델 카테고리 응답 DTO +export interface ModelCategoryResponse { + id: number; + name: string; +} // 모델 요청 DTO (API로 전달할 데이터 타입) export interface ModelRequest { @@ -292,22 +311,41 @@ export interface ModelResponse { name: string; } -// 모델 카테고리 응답 DTO -export interface ModelCategoryResponse { - id: number; - name: string; -} - // 프로젝트 모델 리스트 응답 DTO export interface ProjectModelsResponse extends Array {} - -// 카테고리 요청 DTO -export interface LabelCategoryRequest { - labelCategoryList: number[]; +// 모델 훈련 요청 DTO +export interface ModelTrainRequest { + modelId: number; + ratio: number; + epochs: number; + batch: number; + lr0: number; + lrf: number; + optimizer: 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'; } - -// 카테고리 응답 DTO -export interface LabelCategoryResponse { +export interface ResultResponse { id: number; - name: string; + precision: number; + recall: number; + fitness: number; + ratio: number; + epochs: number; + batch: number; + lr0: number; + lrf: number; + optimizer: 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'; + map50: number; + map5095: number; +} + +export interface ReportResponse { + modelId: number; + totalEpochs: number; + epoch: number; + boxLoss: number; + clsLoss: number; + dflLoss: number; + fitness: number; + epochTime: number; + leftSecond: number; } From c40cd0741d698ebcb1c8f3d6c2b854358effd370 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=A0=95=ED=98=84=EC=A1=B0?= Date: Wed, 25 Sep 2024 06:43:01 +0900 Subject: [PATCH 2/3] =?UTF-8?q?Refactor:=20=EB=AA=A8=EB=8D=B8=20api=20?= =?UTF-8?q?=EC=97=B0=EA=B2=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/ModelManage/EvaluationTab.tsx | 133 ++++++++---- .../components/ModelManage/InputWithLabel.tsx | 24 +++ .../ModelBarChart.tsx | 0 .../ModelLineChart.tsx | 0 .../ModelManage/SelectWithLabel.tsx | 42 ++++ .../components/ModelManage/SettingsForm.tsx | 189 ------------------ .../components/ModelManage/TrainingGraph.tsx | 25 +++ .../ModelManage/TrainingSettings.tsx | 134 +++++++++++++ .../components/ModelManage/TrainingTab.tsx | 85 ++++---- frontend/src/components/ModelManage/index.tsx | 33 +-- frontend/src/hooks/useTrainPolling.ts | 49 ----- .../models/usePollingModelReportsQuery.ts | 12 ++ frontend/src/stores/useModelStore.ts | 40 ++++ frontend/src/stores/useTrainStore.ts | 40 ---- 14 files changed, 427 insertions(+), 379 deletions(-) create mode 100644 frontend/src/components/ModelManage/InputWithLabel.tsx rename frontend/src/components/{ModelBarChart => ModelManage}/ModelBarChart.tsx (100%) rename frontend/src/components/{ModelLineChart => ModelManage}/ModelLineChart.tsx (100%) create mode 100644 frontend/src/components/ModelManage/SelectWithLabel.tsx delete mode 100644 frontend/src/components/ModelManage/SettingsForm.tsx create mode 100644 frontend/src/components/ModelManage/TrainingGraph.tsx create mode 100644 frontend/src/components/ModelManage/TrainingSettings.tsx delete mode 100644 frontend/src/hooks/useTrainPolling.ts create mode 100644 frontend/src/queries/models/usePollingModelReportsQuery.ts create mode 100644 frontend/src/stores/useModelStore.ts delete mode 100644 frontend/src/stores/useTrainStore.ts diff --git a/frontend/src/components/ModelManage/EvaluationTab.tsx b/frontend/src/components/ModelManage/EvaluationTab.tsx index 5824787..830c9a9 100644 --- a/frontend/src/components/ModelManage/EvaluationTab.tsx +++ b/frontend/src/components/ModelManage/EvaluationTab.tsx @@ -1,55 +1,116 @@ import { Label } from '@/components/ui/label'; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; -import ModelBarChart from '@/components/ModelBarChart'; +import { useState } from 'react'; +import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery'; +import useModelReportsQuery from '@/queries/models/useModelReportsQuery'; +import useModelResultsQuery from '@/queries/models/useModelResultsQuery'; +import ModelBarChart from './ModelBarChart'; +import ModelLineChart from './ModelLineChart'; interface EvaluationTabProps { - selectedModel: string | null; - setSelectedModel: (model: string | null) => void; + projectId: number | null; } -export default function EvaluationTab({ selectedModel, setSelectedModel }: EvaluationTabProps) { +export default function EvaluationTab({ projectId }: EvaluationTabProps) { + const [selectedModel, setSelectedModel] = useState(null); + + const { data: models } = useProjectModelsQuery(projectId ?? 0); + return (
-
- - -
+ {selectedModel && ( -
-
- -
-
- -
-
+ )}
); } -function LabelingPreview() { +interface ModelSelectionProps { + models: Array<{ id: number; name: string }> | undefined; + setSelectedModel: (modelId: number) => void; +} + +function ModelSelection({ models, setSelectedModel }: ModelSelectionProps) { return ( -
-

레이블링 프리뷰

+
+ +
); } + +interface ModelEvaluationProps { + projectId: number; + selectedModel: number; +} + +function ModelEvaluation({ projectId, selectedModel }: ModelEvaluationProps) { + const { data: reportData } = useModelReportsQuery(projectId, selectedModel); + const { data: resultData } = useModelResultsQuery(selectedModel); + + if (!reportData || !resultData) { + return null; + } + + return ( +
+
+ +
+ +
+ ({ + epoch: report.epoch.toString(), + loss1: report.boxLoss, + loss2: report.clsLoss, + loss3: report.dflLoss, + fitness: report.fitness, + }))} + /> +
+ + {/*
+ +
*/} +
+ ); +} + +// function LabelingPreview() { +// return ( +//
+//

레이블링 프리뷰

+//
+// ); +// } diff --git a/frontend/src/components/ModelManage/InputWithLabel.tsx b/frontend/src/components/ModelManage/InputWithLabel.tsx new file mode 100644 index 0000000..48bfe9b --- /dev/null +++ b/frontend/src/components/ModelManage/InputWithLabel.tsx @@ -0,0 +1,24 @@ +import { Label } from '@/components/ui/label'; +import { Input } from '../ui/input'; +interface InputWithLabelProps { + label: string; + id: string; + placeholder: string; + value: number; + onChange: (e: React.ChangeEvent) => void; +} + +export default function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) { + return ( +
+ + +
+ ); +} diff --git a/frontend/src/components/ModelBarChart/ModelBarChart.tsx b/frontend/src/components/ModelManage/ModelBarChart.tsx similarity index 100% rename from frontend/src/components/ModelBarChart/ModelBarChart.tsx rename to frontend/src/components/ModelManage/ModelBarChart.tsx diff --git a/frontend/src/components/ModelLineChart/ModelLineChart.tsx b/frontend/src/components/ModelManage/ModelLineChart.tsx similarity index 100% rename from frontend/src/components/ModelLineChart/ModelLineChart.tsx rename to frontend/src/components/ModelManage/ModelLineChart.tsx diff --git a/frontend/src/components/ModelManage/SelectWithLabel.tsx b/frontend/src/components/ModelManage/SelectWithLabel.tsx new file mode 100644 index 0000000..da860f6 --- /dev/null +++ b/frontend/src/components/ModelManage/SelectWithLabel.tsx @@ -0,0 +1,42 @@ +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; +import { Label } from '@/components/ui/label'; + +interface SelectWithLabelOption { + label: string; + value: string; +} + +interface SelectWithLabelProps { + label: string; + id: string; + options: SelectWithLabelOption[]; + placeholder: string; + value: string; + onChange: (value: string) => void; +} + +export default function SelectWithLabel({ label, id, options, placeholder, value, onChange }: SelectWithLabelProps) { + return ( +
+ + +
+ ); +} diff --git a/frontend/src/components/ModelManage/SettingsForm.tsx b/frontend/src/components/ModelManage/SettingsForm.tsx deleted file mode 100644 index 5fdf721..0000000 --- a/frontend/src/components/ModelManage/SettingsForm.tsx +++ /dev/null @@ -1,189 +0,0 @@ -import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; -import { Input } from '@/components/ui/input'; -import { Label } from '@/components/ui/label'; -import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery'; -import { useState } from 'react'; - -interface SettingsFormProps { - projectId: string | null; // projectId를 프랍으로 받음 - onSubmit?: (data: SettingsFormData) => void; -} - -export interface SettingsFormData { - projectId: number | null; - selectedModel: string | null; - ratio: number; - epochs: number; - batchSize: number; - optimizer: string; - lr0: number; - lrf: number; -} - -export default function SettingsForm({ projectId, onSubmit }: SettingsFormProps) { - const numericProjectId = projectId ? parseInt(projectId, 10) : null; - - const { data: models } = useProjectModelsQuery(numericProjectId ?? 0); - const [selectedModel, setSelectedModel] = useState(null); - const [ratio, setRatio] = useState(0.8); - const [epochs, setEpochs] = useState(50); - const [batchSize, setBatchSize] = useState(32); - const [optimizer, setOptimizer] = useState('SGD'); - const [lr0, setLr0] = useState(0.01); - const [lrf, setLrf] = useState(0.001); - - const handleSubmit = () => { - if (onSubmit) { - onSubmit({ - projectId: numericProjectId, - selectedModel, - ratio, - epochs, - batchSize, - optimizer, - lr0, - lrf, - }); - } - }; - - return ( -
-
- 모델 설정 - - {/* 모델 선택 */} -
- - -
- - {/* 훈련/검증 비율 및 학습 파라미터 */} -
- setRatio(parseFloat(e.target.value))} - /> - setEpochs(parseInt(e.target.value, 10))} - /> - setBatchSize(parseInt(e.target.value, 10))} - /> - - setLr0(parseFloat(e.target.value))} - /> - setLrf(parseFloat(e.target.value))} - /> -
- - -
-
- ); -} - -interface InputWithLabelProps { - label: string; - id: string; - placeholder: string; - value: number; - onChange: (e: React.ChangeEvent) => void; -} - -function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) { - return ( -
- - -
- ); -} - -interface SelectWithLabelProps { - label: string; - id: string; - options: string[]; - placeholder: string; - value: string; - onChange: (value: string) => void; -} - -function SelectWithLabel({ label, id, options, placeholder, onChange }: SelectWithLabelProps) { - return ( -
- - -
- ); -} diff --git a/frontend/src/components/ModelManage/TrainingGraph.tsx b/frontend/src/components/ModelManage/TrainingGraph.tsx new file mode 100644 index 0000000..96169e5 --- /dev/null +++ b/frontend/src/components/ModelManage/TrainingGraph.tsx @@ -0,0 +1,25 @@ +import ModelLineChart from './ModelLineChart'; +import usePollingModelReportsQuery from '@/queries/models/usePollingModelReportsQuery'; + +interface TrainingGraphProps { + projectId: number | null; + selectedModel: number | null; +} + +export default function TrainingGraph({ projectId, selectedModel }: TrainingGraphProps) { + const { data: trainingDataList } = usePollingModelReportsQuery(projectId as number, selectedModel ?? 0); + + return ( + ({ + epoch: data.epoch.toString(), + loss1: data.boxLoss, + loss2: data.clsLoss, + loss3: data.dflLoss, + fitness: data.fitness, + })) || [] + } + /> + ); +} diff --git a/frontend/src/components/ModelManage/TrainingSettings.tsx b/frontend/src/components/ModelManage/TrainingSettings.tsx new file mode 100644 index 0000000..785b1dc --- /dev/null +++ b/frontend/src/components/ModelManage/TrainingSettings.tsx @@ -0,0 +1,134 @@ +import SelectWithLabel from './SelectWithLabel'; +import InputWithLabel from './InputWithLabel'; +import { Button } from '@/components/ui/button'; +import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery'; +import { ModelTrainRequest } from '@/types'; +import { useState } from 'react'; + +interface TrainingSettingsProps { + projectId: number | null; + selectedModel: number | null; + setSelectedModel: (model: number | null) => void; + handleTrainingStart: (trainData: ModelTrainRequest) => void; + isTraining: boolean; +} + +export default function TrainingSettings({ + projectId, + selectedModel, + setSelectedModel, + handleTrainingStart, + isTraining, +}: TrainingSettingsProps) { + const { data: models } = useProjectModelsQuery(projectId ?? 0); + + const [ratio, setRatio] = useState(0.8); + const [epochs, setEpochs] = useState(50); + const [batchSize, setBatchSize] = useState(32); + const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO'); + const [lr0, setLr0] = useState(0.01); + const [lrf, setLrf] = useState(0.001); + + const handleSubmit = () => { + if (selectedModel !== null) { + const trainData: ModelTrainRequest = { + modelId: selectedModel, + ratio, + epochs, + batch: batchSize, + optimizer, + lr0, + lrf, + }; + handleTrainingStart(trainData); + } + }; + + return ( +
+ 모델 설정 + +
+ ({ + label: model.name, + value: model.id.toString(), + })) || [] + } + placeholder="모델을 선택하세요" + value={selectedModel ? selectedModel.toString() : ''} + onChange={(value) => setSelectedModel(parseInt(value, 10))} + /> +
+ +
+ setRatio(parseFloat(e.target.value))} + /> + setEpochs(parseInt(e.target.value, 10))} + /> + setBatchSize(parseInt(e.target.value, 10))} + /> + setOptimizer(value as 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP')} + /> + setLr0(parseFloat(e.target.value))} + /> + setLrf(parseFloat(e.target.value))} + /> +
+ + +
+ ); +} diff --git a/frontend/src/components/ModelManage/TrainingTab.tsx b/frontend/src/components/ModelManage/TrainingTab.tsx index 15866df..6bba7e0 100644 --- a/frontend/src/components/ModelManage/TrainingTab.tsx +++ b/frontend/src/components/ModelManage/TrainingTab.tsx @@ -1,45 +1,60 @@ -import { Button } from '@/components/ui/button'; -import ModelLineChart from '@/components/ModelLineChart'; -import SettingsForm from './SettingsForm'; +import useTrainModelQuery from '@/queries/models/useTrainModelQuery'; +import useModelStore from '@/stores/useModelStore'; +import TrainingSettings from './TrainingSettings'; +import TrainingGraph from './TrainingGraph'; +import { ModelTrainRequest } from '@/types'; interface TrainingTabProps { - training: boolean; - handleTrainingToggle: () => void; - trainingDataList: { - epoch: number; - box_loss: number; - cls_loss: number; - dfl_loss: number; - fitness: number; - }[]; - projectId: string | null; // projectId를 프랍으로 받음 + projectId: number | null; } -export default function TrainingTab({ training, handleTrainingToggle, trainingDataList, projectId }: TrainingTabProps) { +export default function TrainingTab({ projectId }: TrainingTabProps) { + const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null; + const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, trainingDataByProject } = + useModelStore((state) => ({ + isTrainingByProject: state.isTrainingByProject, + setIsTraining: state.setIsTraining, + selectedModelByProject: state.selectedModelByProject, + setSelectedModel: state.setSelectedModel, + trainingDataByProject: state.trainingDataByProject, + })); + + const isTraining = isTrainingByProject[numericProjectId?.toString() || ''] || false; + const selectedModel = selectedModelByProject[numericProjectId?.toString() || '']; + + const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number); + + const handleTrainingStart = (trainData: ModelTrainRequest) => { + if (!isTraining && selectedModel !== null) { + setIsTraining(numericProjectId?.toString() || '', true); + startTraining(trainData); + } + }; + + const trainingData = trainingDataByProject[numericProjectId?.toString() || '']; + return (
-
- - -
+ setSelectedModel(numericProjectId?.toString() || '', modelId)} + handleTrainingStart={handleTrainingStart} + isTraining={isTraining} + /> -
- ({ - epoch: data.epoch.toString(), - loss1: data.box_loss, - loss2: data.cls_loss, - loss3: data.dfl_loss, - fitness: data.fitness, - }))} - /> -
+ + + {trainingData && ( +
+

현재 에포크: {trainingData[trainingData.length - 1]?.epoch}

+

총 에포크: {trainingData[trainingData.length - 1]?.totalEpochs}

+

예상 남은시간: {trainingData[trainingData.length - 1]?.leftSecond}

+
+ )}
); } diff --git a/frontend/src/components/ModelManage/index.tsx b/frontend/src/components/ModelManage/index.tsx index e1e5938..144645a 100644 --- a/frontend/src/components/ModelManage/index.tsx +++ b/frontend/src/components/ModelManage/index.tsx @@ -1,28 +1,11 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; -import { useState } from 'react'; import { useParams } from 'react-router-dom'; - -import useTrainWebSocket from '@/hooks/useTrainPolling'; -import useTrainStore from '@/stores/useTrainStore'; import TrainingTab from './TrainingTab'; import EvaluationTab from './EvaluationTab'; export default function ModelManage() { const { projectId } = useParams<{ projectId?: string }>(); - const [training, setTraining] = useState(false); - const [selectedModel, setSelectedModel] = useState(null); - - const numericProjectId = projectId ?? null; - - useTrainWebSocket(training, numericProjectId); - - const { trainingDataList } = useTrainStore((state) => ({ - trainingDataList: numericProjectId ? state.trainingDataByProject[numericProjectId] || [] : [], - })); - - const handleTrainingToggle = () => { - setTraining((prev) => !prev); - }; + const numericProjectId = projectId ? parseInt(projectId, 10) : null; return (
@@ -41,22 +24,12 @@ export default function ModelManage() { 모델 평가 - {/* 학습 탭 */} - + - {/* 평가 탭 */} - + diff --git a/frontend/src/hooks/useTrainPolling.ts b/frontend/src/hooks/useTrainPolling.ts deleted file mode 100644 index 4fd4a36..0000000 --- a/frontend/src/hooks/useTrainPolling.ts +++ /dev/null @@ -1,49 +0,0 @@ -// 임시 가짜 훅 -import { useEffect, useRef, useCallback } from 'react'; -import axios from 'axios'; -import useTrainStore from '@/stores/useTrainStore'; - -export default function useTrainPolling(start: boolean, projectId?: string | null) { - const { addTrainingData, resetTrainingData } = useTrainStore((state) => ({ - addTrainingData: state.addTrainingData, - resetTrainingData: state.resetTrainingData, - })); - - const intervalIdRef = useRef(null); - // 함수 api 후 교체 예정 - const fetchTrainingData = useCallback(async () => { - if (projectId) { - try { - const response = await axios.get(`/api/바보=${projectId}`); - const data = response.data; - - addTrainingData(projectId, { - epoch: data.epoch, - total_epochs: data.total_epochs, - box_loss: data.box_loss, - cls_loss: data.cls_loss, - dfl_loss: data.dfl_loss, - fitness: data.fitness, - epoch_time: data.epoch_time, - left_second: data.left_second, - }); - } catch (error) { - console.error('Fetching error:', error); - } - } - }, [projectId, addTrainingData]); - - useEffect(() => { - if (start && projectId) { - resetTrainingData(projectId); - intervalIdRef.current = window.setInterval(fetchTrainingData, 5000); - } - - return () => { - if (intervalIdRef.current) { - clearInterval(intervalIdRef.current); - intervalIdRef.current = null; - } - }; - }, [start, projectId, fetchTrainingData, resetTrainingData]); -} diff --git a/frontend/src/queries/models/usePollingModelReportsQuery.ts b/frontend/src/queries/models/usePollingModelReportsQuery.ts new file mode 100644 index 0000000..e1d20bb --- /dev/null +++ b/frontend/src/queries/models/usePollingModelReportsQuery.ts @@ -0,0 +1,12 @@ +import { useQuery } from '@tanstack/react-query'; +import { getModelReports } from '@/api/modelApi'; +import { ReportResponse } from '@/types'; + +export default function usePollingModelReportsQuery(projectId: number, modelId: number) { + return useQuery({ + queryKey: ['pollingModelReports', projectId, modelId], + queryFn: () => getModelReports(projectId, modelId), + refetchInterval: 5000, + enabled: !!projectId && !!modelId, + }); +} diff --git a/frontend/src/stores/useModelStore.ts b/frontend/src/stores/useModelStore.ts new file mode 100644 index 0000000..353c6c0 --- /dev/null +++ b/frontend/src/stores/useModelStore.ts @@ -0,0 +1,40 @@ +import { create } from 'zustand'; +import { ReportResponse } from '@/types'; + +interface ModelStoreState { + trainingDataByProject: Record; + isTrainingByProject: Record; + selectedModelByProject: Record; + setIsTraining: (projectId: string, status: boolean) => void; + saveTrainingData: (projectId: string, data: ReportResponse[]) => void; + setSelectedModel: (projectId: string, modelId: number | null) => void; +} + +const useModelStore = create((set) => ({ + trainingDataByProject: {}, + isTrainingByProject: {}, + selectedModelByProject: {}, + setIsTraining: (projectId, status) => + set((state) => ({ + isTrainingByProject: { + ...state.isTrainingByProject, + [projectId]: status, + }, + })), + saveTrainingData: (projectId, data) => + set((state) => ({ + trainingDataByProject: { + ...state.trainingDataByProject, + [projectId]: data, + }, + })), + setSelectedModel: (projectId, modelId) => + set((state) => ({ + selectedModelByProject: { + ...state.selectedModelByProject, + [projectId]: modelId, + }, + })), +})); + +export default useModelStore; diff --git a/frontend/src/stores/useTrainStore.ts b/frontend/src/stores/useTrainStore.ts deleted file mode 100644 index 6aeffd9..0000000 --- a/frontend/src/stores/useTrainStore.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { create } from 'zustand'; - -interface TrainingData { - epoch: number; - total_epochs: number; - box_loss: number; - cls_loss: number; - dfl_loss: number; - fitness: number; - epoch_time: number; - left_second: number; -} - -interface StoreState { - trainingDataByProject: { [projectId: string]: TrainingData[] }; - addTrainingData: (projectId: string, data: TrainingData) => void; - resetTrainingData: (projectId: string) => void; -} - -const useTrainStore = create((set) => ({ - trainingDataByProject: {}, - - addTrainingData: (projectId: string, data: TrainingData) => - set((state) => ({ - trainingDataByProject: { - ...state.trainingDataByProject, - [projectId]: [...(state.trainingDataByProject[projectId] || []), data], - }, - })), - - resetTrainingData: (projectId: string) => - set((state) => ({ - trainingDataByProject: { - ...state.trainingDataByProject, - [projectId]: [], - }, - })), -})); - -export default useTrainStore; From fbf4c7a6a51bf8c0bf4f29372eacff3dd9b3f427 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=A0=95=ED=98=84=EC=A1=B0?= Date: Wed, 25 Sep 2024 07:31:33 +0900 Subject: [PATCH 3/3] =?UTF-8?q?Feat:=20=ED=95=99=EC=8A=B5=20=EC=A4=91?= =?UTF-8?q?=EB=8B=A8=20=EB=93=B1=20=EB=A1=9C=EC=A7=81=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/ModelManage/ModelLineChart.tsx | 71 ++++++++++--------- .../components/ModelManage/TrainingGraph.tsx | 36 +++++++++- .../ModelManage/TrainingSettings.tsx | 8 ++- .../components/ModelManage/TrainingTab.tsx | 20 +++--- .../models/usePollingModelReportsQuery.ts | 4 +- frontend/src/stores/useModelStore.ts | 16 +++++ 6 files changed, 107 insertions(+), 48 deletions(-) diff --git a/frontend/src/components/ModelManage/ModelLineChart.tsx b/frontend/src/components/ModelManage/ModelLineChart.tsx index 2627611..d2dfa23 100644 --- a/frontend/src/components/ModelManage/ModelLineChart.tsx +++ b/frontend/src/components/ModelManage/ModelLineChart.tsx @@ -1,54 +1,74 @@ 'use client'; -import { TrendingUp } from 'lucide-react'; -import { CartesianGrid, Line, LineChart, XAxis } from 'recharts'; - -import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from '@/components/ui/card'; -import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart'; +import { CartesianGrid, Line, LineChart, XAxis, YAxis, Tooltip, Legend } from 'recharts'; +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; +import { ChartConfig, ChartContainer } from '@/components/ui/chart'; interface MetricData { epoch: string; - loss1: number; - loss2: number; - loss3: number; - fitness: number; + loss1?: number; + loss2?: number; + loss3?: number; + fitness?: number; } interface ModelLineChartProps { data: MetricData[]; + currentEpoch?: number; + totalEpochs?: number; + remainingTime?: number; } const chartConfig = { loss1: { label: 'Loss 1', - color: '#FF6347', // 토마토색 + color: '#FF6347', }, loss2: { label: 'Loss 2', - color: '#1E90FF', // 다저블루색 + color: '#1E90FF', }, loss3: { label: 'Loss 3', - color: '#32CD32', // 라임색 + color: '#32CD32', }, fitness: { label: 'Fitness', - color: '#FFD700', // 골드색 + color: '#FFD700', }, } satisfies ChartConfig; -export default function ModelLineChart({ data }: ModelLineChartProps) { +export default function ModelLineChart({ data, currentEpoch, totalEpochs, remainingTime }: ModelLineChartProps) { + const emptyData = Array.from({ length: totalEpochs || 0 }, (_, i) => ({ + epoch: (i + 1).toString(), + loss1: null, + loss2: null, + loss3: null, + fitness: null, + })); + + const filledData = emptyData.map((item, index) => ({ + ...item, + ...(data[index] || {}), + })); + return ( Model Training Metrics - Loss and Fitness over Epochs + {currentEpoch !== undefined && totalEpochs !== undefined && remainingTime !== undefined && ( +
+

현재 에포크: {currentEpoch}

+

총 에포크: {totalEpochs}

+

예상 남은시간: {remainingTime}

+
+ )} `Epoch ${value}`} /> - } - /> + + +
- -
-
-
- Trending up by 5.2% this epoch -
-
- Showing training loss and fitness for the current model -
-
-
-
); } diff --git a/frontend/src/components/ModelManage/TrainingGraph.tsx b/frontend/src/components/ModelManage/TrainingGraph.tsx index 96169e5..e73cf9c 100644 --- a/frontend/src/components/ModelManage/TrainingGraph.tsx +++ b/frontend/src/components/ModelManage/TrainingGraph.tsx @@ -1,5 +1,7 @@ +import { useEffect, useMemo } from 'react'; import ModelLineChart from './ModelLineChart'; import usePollingModelReportsQuery from '@/queries/models/usePollingModelReportsQuery'; +import useModelStore from '@/stores/useModelStore'; interface TrainingGraphProps { projectId: number | null; @@ -7,7 +9,36 @@ interface TrainingGraphProps { } export default function TrainingGraph({ projectId, selectedModel }: TrainingGraphProps) { - const { data: trainingDataList } = usePollingModelReportsQuery(projectId as number, selectedModel ?? 0); + const { isTrainingByProject, setIsTraining, resetTrainingData } = useModelStore((state) => ({ + isTrainingByProject: state.isTrainingByProject, + setIsTraining: state.setIsTraining, + resetTrainingData: state.resetTrainingData, + })); + + const isTraining = isTrainingByProject[projectId?.toString() || ''] || false; + + const { data: trainingDataList } = usePollingModelReportsQuery( + projectId as number, + selectedModel ?? 0, + isTraining && !!projectId && !!selectedModel + ); + + const latestData = useMemo(() => { + return ( + trainingDataList?.[trainingDataList.length - 1] || { + epoch: 0, + totalEpochs: 0, + leftSecond: 0, + } + ); + }, [trainingDataList]); + + useEffect(() => { + if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) { + setIsTraining(projectId?.toString() || '', false); + resetTrainingData(projectId?.toString() || ''); + } + }, [latestData.epoch, latestData.totalEpochs, setIsTraining, resetTrainingData, projectId]); return ( ); } diff --git a/frontend/src/components/ModelManage/TrainingSettings.tsx b/frontend/src/components/ModelManage/TrainingSettings.tsx index 785b1dc..63660be 100644 --- a/frontend/src/components/ModelManage/TrainingSettings.tsx +++ b/frontend/src/components/ModelManage/TrainingSettings.tsx @@ -10,6 +10,7 @@ interface TrainingSettingsProps { selectedModel: number | null; setSelectedModel: (model: number | null) => void; handleTrainingStart: (trainData: ModelTrainRequest) => void; + handleTrainingStop: () => void; isTraining: boolean; } @@ -18,6 +19,7 @@ export default function TrainingSettings({ selectedModel, setSelectedModel, handleTrainingStart, + handleTrainingStop, isTraining, }: TrainingSettingsProps) { const { data: models } = useProjectModelsQuery(projectId ?? 0); @@ -30,7 +32,9 @@ export default function TrainingSettings({ const [lrf, setLrf] = useState(0.001); const handleSubmit = () => { - if (selectedModel !== null) { + if (isTraining) { + handleTrainingStop(); + } else if (selectedModel !== null) { const trainData: ModelTrainRequest = { modelId: selectedModel, ratio, @@ -127,7 +131,7 @@ export default function TrainingSettings({ onClick={handleSubmit} disabled={!selectedModel || isTraining} > - 학습 시작 + {isTraining ? '학습 중단' : '학습 시작'} ); diff --git a/frontend/src/components/ModelManage/TrainingTab.tsx b/frontend/src/components/ModelManage/TrainingTab.tsx index 6bba7e0..16ddfb6 100644 --- a/frontend/src/components/ModelManage/TrainingTab.tsx +++ b/frontend/src/components/ModelManage/TrainingTab.tsx @@ -10,13 +10,13 @@ interface TrainingTabProps { export default function TrainingTab({ projectId }: TrainingTabProps) { const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null; - const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, trainingDataByProject } = + const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, resetTrainingData } = useModelStore((state) => ({ isTrainingByProject: state.isTrainingByProject, setIsTraining: state.setIsTraining, selectedModelByProject: state.selectedModelByProject, setSelectedModel: state.setSelectedModel, - trainingDataByProject: state.trainingDataByProject, + resetTrainingData: state.resetTrainingData, })); const isTraining = isTrainingByProject[numericProjectId?.toString() || ''] || false; @@ -31,7 +31,12 @@ export default function TrainingTab({ projectId }: TrainingTabProps) { } }; - const trainingData = trainingDataByProject[numericProjectId?.toString() || '']; + const handleTrainingStop = () => { + if (isTraining) { + setIsTraining(numericProjectId?.toString() || '', false); + resetTrainingData(numericProjectId?.toString() || ''); + } + }; return (
@@ -40,6 +45,7 @@ export default function TrainingTab({ projectId }: TrainingTabProps) { selectedModel={selectedModel} setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)} handleTrainingStart={handleTrainingStart} + handleTrainingStop={handleTrainingStop} isTraining={isTraining} /> @@ -47,14 +53,6 @@ export default function TrainingTab({ projectId }: TrainingTabProps) { projectId={numericProjectId} selectedModel={selectedModel} /> - - {trainingData && ( -
-

현재 에포크: {trainingData[trainingData.length - 1]?.epoch}

-

총 에포크: {trainingData[trainingData.length - 1]?.totalEpochs}

-

예상 남은시간: {trainingData[trainingData.length - 1]?.leftSecond}

-
- )}
); } diff --git a/frontend/src/queries/models/usePollingModelReportsQuery.ts b/frontend/src/queries/models/usePollingModelReportsQuery.ts index e1d20bb..85a64ee 100644 --- a/frontend/src/queries/models/usePollingModelReportsQuery.ts +++ b/frontend/src/queries/models/usePollingModelReportsQuery.ts @@ -2,11 +2,11 @@ import { useQuery } from '@tanstack/react-query'; import { getModelReports } from '@/api/modelApi'; import { ReportResponse } from '@/types'; -export default function usePollingModelReportsQuery(projectId: number, modelId: number) { +export default function usePollingModelReportsQuery(projectId: number, modelId: number, enabled: boolean) { return useQuery({ queryKey: ['pollingModelReports', projectId, modelId], queryFn: () => getModelReports(projectId, modelId), refetchInterval: 5000, - enabled: !!projectId && !!modelId, + enabled, }); } diff --git a/frontend/src/stores/useModelStore.ts b/frontend/src/stores/useModelStore.ts index 353c6c0..f6ba8a5 100644 --- a/frontend/src/stores/useModelStore.ts +++ b/frontend/src/stores/useModelStore.ts @@ -8,6 +8,7 @@ interface ModelStoreState { setIsTraining: (projectId: string, status: boolean) => void; saveTrainingData: (projectId: string, data: ReportResponse[]) => void; setSelectedModel: (projectId: string, modelId: number | null) => void; + resetTrainingData: (projectId: string) => void; } const useModelStore = create((set) => ({ @@ -35,6 +36,21 @@ const useModelStore = create((set) => ({ [projectId]: modelId, }, })), + resetTrainingData: (projectId) => + set((state) => ({ + trainingDataByProject: { + ...state.trainingDataByProject, + [projectId]: [], + }, + selectedModelByProject: { + ...state.selectedModelByProject, + [projectId]: null, + }, + isTrainingByProject: { + ...state.isTrainingByProject, + [projectId]: false, + }, + })), })); export default useModelStore;