From f5e00d5b4217132b2ae885294113347ff0732ace Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=A0=95=ED=98=84=EC=A1=B0?= Date: Sun, 29 Sep 2024 23:50:42 +0900 Subject: [PATCH] =?UTF-8?q?Refactor:=20=EB=AA=A8=EB=8D=B8=20=EB=A6=AC?= =?UTF-8?q?=ED=8C=A9=ED=86=A0=EB=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/ModelManage/TrainingGraph.tsx | 23 +++++---- .../ModelManage/TrainingSettings.tsx | 51 +++++++------------ .../components/ModelManage/TrainingTab.tsx | 10 ++-- .../src/queries/models/useTrainModelQuery.ts | 6 +++ .../reports/usePollingModelReportsQuery.ts | 2 +- 5 files changed, 45 insertions(+), 47 deletions(-) diff --git a/frontend/src/components/ModelManage/TrainingGraph.tsx b/frontend/src/components/ModelManage/TrainingGraph.tsx index cb9c473..8f8ca15 100644 --- a/frontend/src/components/ModelManage/TrainingGraph.tsx +++ b/frontend/src/components/ModelManage/TrainingGraph.tsx @@ -1,6 +1,7 @@ -import { useEffect, useState } from 'react'; +import { useEffect } from 'react'; import ModelLineChart from './ModelLineChart'; -import usePollingModelReportsQuery from '@/queries/reports/usePollingModelReportsQuery'; +import usePollingTrainingModelReport from '@/queries/reports/usePollingModelReportsQuery'; +import { useQueryClient } from '@tanstack/react-query'; import { ModelResponse } from '@/types'; interface TrainingGraphProps { @@ -10,20 +11,22 @@ interface TrainingGraphProps { } export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) { - const [isPolling, setIsPolling] = useState(false); - const { data: trainingDataList } = usePollingModelReportsQuery( + const queryClient = useQueryClient(); + + const { data: trainingDataList } = usePollingTrainingModelReport( projectId as number, selectedModel?.id as number, - isPolling + selectedModel?.isTrain || false ); useEffect(() => { - if (selectedModel) { - setIsPolling(true); - } else { - setIsPolling(false); + if (!selectedModel || !selectedModel.isTrain) { + queryClient.resetQueries({ + queryKey: [{ type: 'modelReports', projectId, modelId: selectedModel?.id }], + exact: true, + }); } - }, [selectedModel]); + }, [selectedModel, queryClient, projectId]); return ( ('AUTO'); const [lr0, setLr0] = useState(0.01); const [lrf, setLrf] = useState(0.001); - const [isSubmitting, setIsSubmitting] = useState(false); const queryClient = useQueryClient(); - const intervalRef = useRef(null); + + useEffect(() => { + if (selectedModel?.isTrain) { + queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] }); + } + }, [selectedModel?.isTrain, queryClient, projectId]); const handleSubmit = () => { - if (selectedModel?.isTrain) { - handleTrainingStop(); - } else if (selectedModel) { + if (selectedModel) { const trainData: ModelTrainRequest = { modelId: selectedModel.id, ratio, @@ -48,34 +50,10 @@ export default function TrainingSettings({ lr0, lrf, }; - setIsSubmitting(true); handleTrainingStart(trainData); } }; - useEffect(() => { - if (isSubmitting) { - intervalRef.current = setInterval(() => { - queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] }); - }, 1000); - } else if (intervalRef.current) { - clearInterval(intervalRef.current); - intervalRef.current = null; - } - - return () => { - if (intervalRef.current) { - clearInterval(intervalRef.current); - } - }; - }, [isSubmitting, queryClient, projectId]); - - useEffect(() => { - if (selectedModel?.isTrain) { - setIsSubmitting(false); - } - }, [selectedModel]); - return (
모델 설정 @@ -158,12 +136,21 @@ export default function TrainingSettings({ variant="outlinePrimary" size="lg" onClick={handleSubmit} - disabled={!selectedModel || isSubmitting} + disabled={!selectedModel} > - {isSubmitting ? '기다리는 중...' : '학습 시작'} + 학습 시작 )} + {selectedModel?.isTrain && ( + + )}
); } diff --git a/frontend/src/components/ModelManage/TrainingTab.tsx b/frontend/src/components/ModelManage/TrainingTab.tsx index c3feb96..c1813aa 100644 --- a/frontend/src/components/ModelManage/TrainingTab.tsx +++ b/frontend/src/components/ModelManage/TrainingTab.tsx @@ -1,15 +1,15 @@ -import useTrainModelQuery from '@/queries/models/useTrainModelQuery'; +import { useState } from 'react'; import TrainingSettings from './TrainingSettings'; import TrainingGraph from './TrainingGraph'; +import useTrainModelQuery from '@/queries/models/useTrainModelQuery'; import { ModelTrainRequest, ModelResponse } from '@/types'; -import { useState } from 'react'; interface TrainingTabProps { projectId: number | null; } export default function TrainingTab({ projectId }: TrainingTabProps) { - const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null; + const numericProjectId = projectId !== null ? Number(projectId) : null; const [selectedModel, setSelectedModel] = useState(null); const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number); @@ -18,7 +18,9 @@ export default function TrainingTab({ projectId }: TrainingTabProps) { startTraining(trainData); }; - const handleTrainingStop = () => {}; + const handleTrainingStop = () => { + // Todo: 학습 중단 로직 + }; return (
diff --git a/frontend/src/queries/models/useTrainModelQuery.ts b/frontend/src/queries/models/useTrainModelQuery.ts index 4a44d22..036bf64 100644 --- a/frontend/src/queries/models/useTrainModelQuery.ts +++ b/frontend/src/queries/models/useTrainModelQuery.ts @@ -1,9 +1,15 @@ import { useMutation } from '@tanstack/react-query'; import { trainModel } from '@/api/modelApi'; import { ModelTrainRequest } from '@/types'; +import { QueryClient } from '@tanstack/react-query'; + +const queryClient = new QueryClient(); export default function useTrainModelQuery(projectId: number) { return useMutation({ mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] }); + }, }); } diff --git a/frontend/src/queries/reports/usePollingModelReportsQuery.ts b/frontend/src/queries/reports/usePollingModelReportsQuery.ts index 16499e0..962968e 100644 --- a/frontend/src/queries/reports/usePollingModelReportsQuery.ts +++ b/frontend/src/queries/reports/usePollingModelReportsQuery.ts @@ -6,7 +6,7 @@ export default function usePollingTrainingModelReport(projectId: number, modelId return useQuery({ queryKey: ['modelReports', projectId, modelId], queryFn: () => getTrainingModelReport(projectId, modelId), - refetchInterval: 5000, + refetchInterval: enabled ? 5000 : false, enabled, }); }