From c40cd0741d698ebcb1c8f3d6c2b854358effd370 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=A0=95=ED=98=84=EC=A1=B0?= Date: Wed, 25 Sep 2024 06:43:01 +0900 Subject: [PATCH] =?UTF-8?q?Refactor:=20=EB=AA=A8=EB=8D=B8=20api=20?= =?UTF-8?q?=EC=97=B0=EA=B2=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/ModelManage/EvaluationTab.tsx | 133 ++++++++---- .../components/ModelManage/InputWithLabel.tsx | 24 +++ .../ModelBarChart.tsx | 0 .../ModelLineChart.tsx | 0 .../ModelManage/SelectWithLabel.tsx | 42 ++++ .../components/ModelManage/SettingsForm.tsx | 189 ------------------ .../components/ModelManage/TrainingGraph.tsx | 25 +++ .../ModelManage/TrainingSettings.tsx | 134 +++++++++++++ .../components/ModelManage/TrainingTab.tsx | 85 ++++---- frontend/src/components/ModelManage/index.tsx | 33 +-- frontend/src/hooks/useTrainPolling.ts | 49 ----- .../models/usePollingModelReportsQuery.ts | 12 ++ frontend/src/stores/useModelStore.ts | 40 ++++ frontend/src/stores/useTrainStore.ts | 40 ---- 14 files changed, 427 insertions(+), 379 deletions(-) create mode 100644 frontend/src/components/ModelManage/InputWithLabel.tsx rename frontend/src/components/{ModelBarChart => ModelManage}/ModelBarChart.tsx (100%) rename frontend/src/components/{ModelLineChart => ModelManage}/ModelLineChart.tsx (100%) create mode 100644 frontend/src/components/ModelManage/SelectWithLabel.tsx delete mode 100644 frontend/src/components/ModelManage/SettingsForm.tsx create mode 100644 frontend/src/components/ModelManage/TrainingGraph.tsx create mode 100644 frontend/src/components/ModelManage/TrainingSettings.tsx delete mode 100644 frontend/src/hooks/useTrainPolling.ts create mode 100644 frontend/src/queries/models/usePollingModelReportsQuery.ts create mode 100644 frontend/src/stores/useModelStore.ts delete mode 100644 frontend/src/stores/useTrainStore.ts diff --git a/frontend/src/components/ModelManage/EvaluationTab.tsx b/frontend/src/components/ModelManage/EvaluationTab.tsx index 5824787..830c9a9 100644 --- a/frontend/src/components/ModelManage/EvaluationTab.tsx +++ b/frontend/src/components/ModelManage/EvaluationTab.tsx @@ -1,55 +1,116 @@ import { Label } from '@/components/ui/label'; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; -import ModelBarChart from '@/components/ModelBarChart'; +import { useState } from 'react'; +import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery'; +import useModelReportsQuery from '@/queries/models/useModelReportsQuery'; +import useModelResultsQuery from '@/queries/models/useModelResultsQuery'; +import ModelBarChart from './ModelBarChart'; +import ModelLineChart from './ModelLineChart'; interface EvaluationTabProps { - selectedModel: string | null; - setSelectedModel: (model: string | null) => void; + projectId: number | null; } -export default function EvaluationTab({ selectedModel, setSelectedModel }: EvaluationTabProps) { +export default function EvaluationTab({ projectId }: EvaluationTabProps) { + const [selectedModel, setSelectedModel] = useState(null); + + const { data: models } = useProjectModelsQuery(projectId ?? 0); + return (
-
- - -
+ {selectedModel && ( -
-
- -
-
- -
-
+ )}
); } -function LabelingPreview() { +interface ModelSelectionProps { + models: Array<{ id: number; name: string }> | undefined; + setSelectedModel: (modelId: number) => void; +} + +function ModelSelection({ models, setSelectedModel }: ModelSelectionProps) { return ( -
-

레이블링 프리뷰

+
+ +
); } + +interface ModelEvaluationProps { + projectId: number; + selectedModel: number; +} + +function ModelEvaluation({ projectId, selectedModel }: ModelEvaluationProps) { + const { data: reportData } = useModelReportsQuery(projectId, selectedModel); + const { data: resultData } = useModelResultsQuery(selectedModel); + + if (!reportData || !resultData) { + return null; + } + + return ( +
+
+ +
+ +
+ ({ + epoch: report.epoch.toString(), + loss1: report.boxLoss, + loss2: report.clsLoss, + loss3: report.dflLoss, + fitness: report.fitness, + }))} + /> +
+ + {/*
+ +
*/} +
+ ); +} + +// function LabelingPreview() { +// return ( +//
+//

레이블링 프리뷰

+//
+// ); +// } diff --git a/frontend/src/components/ModelManage/InputWithLabel.tsx b/frontend/src/components/ModelManage/InputWithLabel.tsx new file mode 100644 index 0000000..48bfe9b --- /dev/null +++ b/frontend/src/components/ModelManage/InputWithLabel.tsx @@ -0,0 +1,24 @@ +import { Label } from '@/components/ui/label'; +import { Input } from '../ui/input'; +interface InputWithLabelProps { + label: string; + id: string; + placeholder: string; + value: number; + onChange: (e: React.ChangeEvent) => void; +} + +export default function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) { + return ( +
+ + +
+ ); +} diff --git a/frontend/src/components/ModelBarChart/ModelBarChart.tsx b/frontend/src/components/ModelManage/ModelBarChart.tsx similarity index 100% rename from frontend/src/components/ModelBarChart/ModelBarChart.tsx rename to frontend/src/components/ModelManage/ModelBarChart.tsx diff --git a/frontend/src/components/ModelLineChart/ModelLineChart.tsx b/frontend/src/components/ModelManage/ModelLineChart.tsx similarity index 100% rename from frontend/src/components/ModelLineChart/ModelLineChart.tsx rename to frontend/src/components/ModelManage/ModelLineChart.tsx diff --git a/frontend/src/components/ModelManage/SelectWithLabel.tsx b/frontend/src/components/ModelManage/SelectWithLabel.tsx new file mode 100644 index 0000000..da860f6 --- /dev/null +++ b/frontend/src/components/ModelManage/SelectWithLabel.tsx @@ -0,0 +1,42 @@ +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; +import { Label } from '@/components/ui/label'; + +interface SelectWithLabelOption { + label: string; + value: string; +} + +interface SelectWithLabelProps { + label: string; + id: string; + options: SelectWithLabelOption[]; + placeholder: string; + value: string; + onChange: (value: string) => void; +} + +export default function SelectWithLabel({ label, id, options, placeholder, value, onChange }: SelectWithLabelProps) { + return ( +
+ + +
+ ); +} diff --git a/frontend/src/components/ModelManage/SettingsForm.tsx b/frontend/src/components/ModelManage/SettingsForm.tsx deleted file mode 100644 index 5fdf721..0000000 --- a/frontend/src/components/ModelManage/SettingsForm.tsx +++ /dev/null @@ -1,189 +0,0 @@ -import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; -import { Input } from '@/components/ui/input'; -import { Label } from '@/components/ui/label'; -import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery'; -import { useState } from 'react'; - -interface SettingsFormProps { - projectId: string | null; // projectId를 프랍으로 받음 - onSubmit?: (data: SettingsFormData) => void; -} - -export interface SettingsFormData { - projectId: number | null; - selectedModel: string | null; - ratio: number; - epochs: number; - batchSize: number; - optimizer: string; - lr0: number; - lrf: number; -} - -export default function SettingsForm({ projectId, onSubmit }: SettingsFormProps) { - const numericProjectId = projectId ? parseInt(projectId, 10) : null; - - const { data: models } = useProjectModelsQuery(numericProjectId ?? 0); - const [selectedModel, setSelectedModel] = useState(null); - const [ratio, setRatio] = useState(0.8); - const [epochs, setEpochs] = useState(50); - const [batchSize, setBatchSize] = useState(32); - const [optimizer, setOptimizer] = useState('SGD'); - const [lr0, setLr0] = useState(0.01); - const [lrf, setLrf] = useState(0.001); - - const handleSubmit = () => { - if (onSubmit) { - onSubmit({ - projectId: numericProjectId, - selectedModel, - ratio, - epochs, - batchSize, - optimizer, - lr0, - lrf, - }); - } - }; - - return ( -
-
- 모델 설정 - - {/* 모델 선택 */} -
- - -
- - {/* 훈련/검증 비율 및 학습 파라미터 */} -
- setRatio(parseFloat(e.target.value))} - /> - setEpochs(parseInt(e.target.value, 10))} - /> - setBatchSize(parseInt(e.target.value, 10))} - /> - - setLr0(parseFloat(e.target.value))} - /> - setLrf(parseFloat(e.target.value))} - /> -
- - -
-
- ); -} - -interface InputWithLabelProps { - label: string; - id: string; - placeholder: string; - value: number; - onChange: (e: React.ChangeEvent) => void; -} - -function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) { - return ( -
- - -
- ); -} - -interface SelectWithLabelProps { - label: string; - id: string; - options: string[]; - placeholder: string; - value: string; - onChange: (value: string) => void; -} - -function SelectWithLabel({ label, id, options, placeholder, onChange }: SelectWithLabelProps) { - return ( -
- - -
- ); -} diff --git a/frontend/src/components/ModelManage/TrainingGraph.tsx b/frontend/src/components/ModelManage/TrainingGraph.tsx new file mode 100644 index 0000000..96169e5 --- /dev/null +++ b/frontend/src/components/ModelManage/TrainingGraph.tsx @@ -0,0 +1,25 @@ +import ModelLineChart from './ModelLineChart'; +import usePollingModelReportsQuery from '@/queries/models/usePollingModelReportsQuery'; + +interface TrainingGraphProps { + projectId: number | null; + selectedModel: number | null; +} + +export default function TrainingGraph({ projectId, selectedModel }: TrainingGraphProps) { + const { data: trainingDataList } = usePollingModelReportsQuery(projectId as number, selectedModel ?? 0); + + return ( + ({ + epoch: data.epoch.toString(), + loss1: data.boxLoss, + loss2: data.clsLoss, + loss3: data.dflLoss, + fitness: data.fitness, + })) || [] + } + /> + ); +} diff --git a/frontend/src/components/ModelManage/TrainingSettings.tsx b/frontend/src/components/ModelManage/TrainingSettings.tsx new file mode 100644 index 0000000..785b1dc --- /dev/null +++ b/frontend/src/components/ModelManage/TrainingSettings.tsx @@ -0,0 +1,134 @@ +import SelectWithLabel from './SelectWithLabel'; +import InputWithLabel from './InputWithLabel'; +import { Button } from '@/components/ui/button'; +import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery'; +import { ModelTrainRequest } from '@/types'; +import { useState } from 'react'; + +interface TrainingSettingsProps { + projectId: number | null; + selectedModel: number | null; + setSelectedModel: (model: number | null) => void; + handleTrainingStart: (trainData: ModelTrainRequest) => void; + isTraining: boolean; +} + +export default function TrainingSettings({ + projectId, + selectedModel, + setSelectedModel, + handleTrainingStart, + isTraining, +}: TrainingSettingsProps) { + const { data: models } = useProjectModelsQuery(projectId ?? 0); + + const [ratio, setRatio] = useState(0.8); + const [epochs, setEpochs] = useState(50); + const [batchSize, setBatchSize] = useState(32); + const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO'); + const [lr0, setLr0] = useState(0.01); + const [lrf, setLrf] = useState(0.001); + + const handleSubmit = () => { + if (selectedModel !== null) { + const trainData: ModelTrainRequest = { + modelId: selectedModel, + ratio, + epochs, + batch: batchSize, + optimizer, + lr0, + lrf, + }; + handleTrainingStart(trainData); + } + }; + + return ( +
+ 모델 설정 + +
+ ({ + label: model.name, + value: model.id.toString(), + })) || [] + } + placeholder="모델을 선택하세요" + value={selectedModel ? selectedModel.toString() : ''} + onChange={(value) => setSelectedModel(parseInt(value, 10))} + /> +
+ +
+ setRatio(parseFloat(e.target.value))} + /> + setEpochs(parseInt(e.target.value, 10))} + /> + setBatchSize(parseInt(e.target.value, 10))} + /> + setOptimizer(value as 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP')} + /> + setLr0(parseFloat(e.target.value))} + /> + setLrf(parseFloat(e.target.value))} + /> +
+ + +
+ ); +} diff --git a/frontend/src/components/ModelManage/TrainingTab.tsx b/frontend/src/components/ModelManage/TrainingTab.tsx index 15866df..6bba7e0 100644 --- a/frontend/src/components/ModelManage/TrainingTab.tsx +++ b/frontend/src/components/ModelManage/TrainingTab.tsx @@ -1,45 +1,60 @@ -import { Button } from '@/components/ui/button'; -import ModelLineChart from '@/components/ModelLineChart'; -import SettingsForm from './SettingsForm'; +import useTrainModelQuery from '@/queries/models/useTrainModelQuery'; +import useModelStore from '@/stores/useModelStore'; +import TrainingSettings from './TrainingSettings'; +import TrainingGraph from './TrainingGraph'; +import { ModelTrainRequest } from '@/types'; interface TrainingTabProps { - training: boolean; - handleTrainingToggle: () => void; - trainingDataList: { - epoch: number; - box_loss: number; - cls_loss: number; - dfl_loss: number; - fitness: number; - }[]; - projectId: string | null; // projectId를 프랍으로 받음 + projectId: number | null; } -export default function TrainingTab({ training, handleTrainingToggle, trainingDataList, projectId }: TrainingTabProps) { +export default function TrainingTab({ projectId }: TrainingTabProps) { + const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null; + const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, trainingDataByProject } = + useModelStore((state) => ({ + isTrainingByProject: state.isTrainingByProject, + setIsTraining: state.setIsTraining, + selectedModelByProject: state.selectedModelByProject, + setSelectedModel: state.setSelectedModel, + trainingDataByProject: state.trainingDataByProject, + })); + + const isTraining = isTrainingByProject[numericProjectId?.toString() || ''] || false; + const selectedModel = selectedModelByProject[numericProjectId?.toString() || '']; + + const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number); + + const handleTrainingStart = (trainData: ModelTrainRequest) => { + if (!isTraining && selectedModel !== null) { + setIsTraining(numericProjectId?.toString() || '', true); + startTraining(trainData); + } + }; + + const trainingData = trainingDataByProject[numericProjectId?.toString() || '']; + return (
-
- - -
+ setSelectedModel(numericProjectId?.toString() || '', modelId)} + handleTrainingStart={handleTrainingStart} + isTraining={isTraining} + /> -
- ({ - epoch: data.epoch.toString(), - loss1: data.box_loss, - loss2: data.cls_loss, - loss3: data.dfl_loss, - fitness: data.fitness, - }))} - /> -
+ + + {trainingData && ( +
+

현재 에포크: {trainingData[trainingData.length - 1]?.epoch}

+

총 에포크: {trainingData[trainingData.length - 1]?.totalEpochs}

+

예상 남은시간: {trainingData[trainingData.length - 1]?.leftSecond}

+
+ )}
); } diff --git a/frontend/src/components/ModelManage/index.tsx b/frontend/src/components/ModelManage/index.tsx index e1e5938..144645a 100644 --- a/frontend/src/components/ModelManage/index.tsx +++ b/frontend/src/components/ModelManage/index.tsx @@ -1,28 +1,11 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; -import { useState } from 'react'; import { useParams } from 'react-router-dom'; - -import useTrainWebSocket from '@/hooks/useTrainPolling'; -import useTrainStore from '@/stores/useTrainStore'; import TrainingTab from './TrainingTab'; import EvaluationTab from './EvaluationTab'; export default function ModelManage() { const { projectId } = useParams<{ projectId?: string }>(); - const [training, setTraining] = useState(false); - const [selectedModel, setSelectedModel] = useState(null); - - const numericProjectId = projectId ?? null; - - useTrainWebSocket(training, numericProjectId); - - const { trainingDataList } = useTrainStore((state) => ({ - trainingDataList: numericProjectId ? state.trainingDataByProject[numericProjectId] || [] : [], - })); - - const handleTrainingToggle = () => { - setTraining((prev) => !prev); - }; + const numericProjectId = projectId ? parseInt(projectId, 10) : null; return (
@@ -41,22 +24,12 @@ export default function ModelManage() { 모델 평가 - {/* 학습 탭 */} - + - {/* 평가 탭 */} - + diff --git a/frontend/src/hooks/useTrainPolling.ts b/frontend/src/hooks/useTrainPolling.ts deleted file mode 100644 index 4fd4a36..0000000 --- a/frontend/src/hooks/useTrainPolling.ts +++ /dev/null @@ -1,49 +0,0 @@ -// 임시 가짜 훅 -import { useEffect, useRef, useCallback } from 'react'; -import axios from 'axios'; -import useTrainStore from '@/stores/useTrainStore'; - -export default function useTrainPolling(start: boolean, projectId?: string | null) { - const { addTrainingData, resetTrainingData } = useTrainStore((state) => ({ - addTrainingData: state.addTrainingData, - resetTrainingData: state.resetTrainingData, - })); - - const intervalIdRef = useRef(null); - // 함수 api 후 교체 예정 - const fetchTrainingData = useCallback(async () => { - if (projectId) { - try { - const response = await axios.get(`/api/바보=${projectId}`); - const data = response.data; - - addTrainingData(projectId, { - epoch: data.epoch, - total_epochs: data.total_epochs, - box_loss: data.box_loss, - cls_loss: data.cls_loss, - dfl_loss: data.dfl_loss, - fitness: data.fitness, - epoch_time: data.epoch_time, - left_second: data.left_second, - }); - } catch (error) { - console.error('Fetching error:', error); - } - } - }, [projectId, addTrainingData]); - - useEffect(() => { - if (start && projectId) { - resetTrainingData(projectId); - intervalIdRef.current = window.setInterval(fetchTrainingData, 5000); - } - - return () => { - if (intervalIdRef.current) { - clearInterval(intervalIdRef.current); - intervalIdRef.current = null; - } - }; - }, [start, projectId, fetchTrainingData, resetTrainingData]); -} diff --git a/frontend/src/queries/models/usePollingModelReportsQuery.ts b/frontend/src/queries/models/usePollingModelReportsQuery.ts new file mode 100644 index 0000000..e1d20bb --- /dev/null +++ b/frontend/src/queries/models/usePollingModelReportsQuery.ts @@ -0,0 +1,12 @@ +import { useQuery } from '@tanstack/react-query'; +import { getModelReports } from '@/api/modelApi'; +import { ReportResponse } from '@/types'; + +export default function usePollingModelReportsQuery(projectId: number, modelId: number) { + return useQuery({ + queryKey: ['pollingModelReports', projectId, modelId], + queryFn: () => getModelReports(projectId, modelId), + refetchInterval: 5000, + enabled: !!projectId && !!modelId, + }); +} diff --git a/frontend/src/stores/useModelStore.ts b/frontend/src/stores/useModelStore.ts new file mode 100644 index 0000000..353c6c0 --- /dev/null +++ b/frontend/src/stores/useModelStore.ts @@ -0,0 +1,40 @@ +import { create } from 'zustand'; +import { ReportResponse } from '@/types'; + +interface ModelStoreState { + trainingDataByProject: Record; + isTrainingByProject: Record; + selectedModelByProject: Record; + setIsTraining: (projectId: string, status: boolean) => void; + saveTrainingData: (projectId: string, data: ReportResponse[]) => void; + setSelectedModel: (projectId: string, modelId: number | null) => void; +} + +const useModelStore = create((set) => ({ + trainingDataByProject: {}, + isTrainingByProject: {}, + selectedModelByProject: {}, + setIsTraining: (projectId, status) => + set((state) => ({ + isTrainingByProject: { + ...state.isTrainingByProject, + [projectId]: status, + }, + })), + saveTrainingData: (projectId, data) => + set((state) => ({ + trainingDataByProject: { + ...state.trainingDataByProject, + [projectId]: data, + }, + })), + setSelectedModel: (projectId, modelId) => + set((state) => ({ + selectedModelByProject: { + ...state.selectedModelByProject, + [projectId]: modelId, + }, + })), +})); + +export default useModelStore; diff --git a/frontend/src/stores/useTrainStore.ts b/frontend/src/stores/useTrainStore.ts deleted file mode 100644 index 6aeffd9..0000000 --- a/frontend/src/stores/useTrainStore.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { create } from 'zustand'; - -interface TrainingData { - epoch: number; - total_epochs: number; - box_loss: number; - cls_loss: number; - dfl_loss: number; - fitness: number; - epoch_time: number; - left_second: number; -} - -interface StoreState { - trainingDataByProject: { [projectId: string]: TrainingData[] }; - addTrainingData: (projectId: string, data: TrainingData) => void; - resetTrainingData: (projectId: string) => void; -} - -const useTrainStore = create((set) => ({ - trainingDataByProject: {}, - - addTrainingData: (projectId: string, data: TrainingData) => - set((state) => ({ - trainingDataByProject: { - ...state.trainingDataByProject, - [projectId]: [...(state.trainingDataByProject[projectId] || []), data], - }, - })), - - resetTrainingData: (projectId: string) => - set((state) => ({ - trainingDataByProject: { - ...state.trainingDataByProject, - [projectId]: [], - }, - })), -})); - -export default useTrainStore;