From 9ba7e677bcd01138be9d26cdf88efed328a854de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=A0=95=ED=98=84=EC=A1=B0?= Date: Mon, 30 Sep 2024 08:47:37 +0900 Subject: [PATCH] =?UTF-8?q?Refactor:=20=EB=AA=A8=EB=8D=B8=20=ED=95=99?= =?UTF-8?q?=EC=8A=B5=20=EB=A6=AC=ED=8C=A9=ED=86=A0=EB=A7=81=20=EC=A4=91,?= =?UTF-8?q?=20=ED=85=8C=EC=8A=A4=ED=8A=B8=20=ED=95=84=EC=9A=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ModelManage/TrainingSettings.tsx | 16 +++----- .../components/ModelManage/TrainingTab.tsx | 37 +++++++++++++++++-- .../src/queries/models/useTrainModelQuery.ts | 11 +++++- 3 files changed, 48 insertions(+), 16 deletions(-) diff --git a/frontend/src/components/ModelManage/TrainingSettings.tsx b/frontend/src/components/ModelManage/TrainingSettings.tsx index 1333638..016c4f6 100644 --- a/frontend/src/components/ModelManage/TrainingSettings.tsx +++ b/frontend/src/components/ModelManage/TrainingSettings.tsx @@ -1,11 +1,10 @@ -import { useState, useEffect } from 'react'; +import { useState } from 'react'; import { Button } from '@/components/ui/button'; import SelectWithLabel from './SelectWithLabel'; import InputWithLabel from './InputWithLabel'; import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery'; import { ModelTrainRequest, ModelResponse } from '@/types'; import { cn } from '@/lib/utils'; -import { useQueryClient } from '@tanstack/react-query'; interface TrainingSettingsProps { projectId: number | null; @@ -13,6 +12,7 @@ interface TrainingSettingsProps { setSelectedModel: (model: ModelResponse | null) => void; handleTrainingStart: (trainData: ModelTrainRequest) => void; handleTrainingStop: () => void; + isPolling: boolean; className?: string; } @@ -22,6 +22,7 @@ export default function TrainingSettings({ setSelectedModel, handleTrainingStart, handleTrainingStop, + isPolling, className, }: TrainingSettingsProps) { const { data: models } = useProjectModelsQuery(projectId ?? 0); @@ -31,13 +32,6 @@ export default function TrainingSettings({ const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO'); const [lr0, setLr0] = useState(0.01); const [lrf, setLrf] = useState(0.001); - const queryClient = useQueryClient(); - - useEffect(() => { - if (selectedModel?.isTrain) { - queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] }); - } - }, [selectedModel?.isTrain, queryClient, projectId]); const handleSubmit = () => { if (selectedModel) { @@ -136,9 +130,9 @@ export default function TrainingSettings({ variant="outlinePrimary" size="lg" onClick={handleSubmit} - disabled={!selectedModel} + disabled={!selectedModel || isPolling} > - 학습 시작 + {isPolling ? '대기 중...' : '학습 시작'} )} diff --git a/frontend/src/components/ModelManage/TrainingTab.tsx b/frontend/src/components/ModelManage/TrainingTab.tsx index c1813aa..97a0928 100644 --- a/frontend/src/components/ModelManage/TrainingTab.tsx +++ b/frontend/src/components/ModelManage/TrainingTab.tsx @@ -1,8 +1,9 @@ -import { useState } from 'react'; +import { useState, useEffect } from 'react'; import TrainingSettings from './TrainingSettings'; import TrainingGraph from './TrainingGraph'; import useTrainModelQuery from '@/queries/models/useTrainModelQuery'; import { ModelTrainRequest, ModelResponse } from '@/types'; +import { useQueryClient } from '@tanstack/react-query'; interface TrainingTabProps { projectId: number | null; @@ -11,15 +12,43 @@ interface TrainingTabProps { export default function TrainingTab({ projectId }: TrainingTabProps) { const numericProjectId = projectId !== null ? Number(projectId) : null; const [selectedModel, setSelectedModel] = useState(null); + const [isPolling, setIsPolling] = useState(false); + const queryClient = useQueryClient(); - const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number); + const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number, { + onSuccess: () => { + setIsPolling(true); + }, + onError: () => { + alert('학습 요청 실패'); + setIsPolling(false); + }, + }); const handleTrainingStart = (trainData: ModelTrainRequest) => { startTraining(trainData); }; + useEffect(() => { + if (!selectedModel || !numericProjectId || !isPolling) return; + + const intervalId = setInterval(() => { + queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] }); + }, 2000); + + const timeoutId = setTimeout(() => { + clearInterval(intervalId); + setIsPolling(false); + }, 30000); + + return () => { + clearInterval(intervalId); + clearTimeout(timeoutId); + }; + }, [selectedModel, numericProjectId, queryClient, isPolling]); + const handleTrainingStop = () => { - // Todo: 학습 중단 로직 + setIsPolling(false); }; return ( @@ -30,9 +59,9 @@ export default function TrainingTab({ projectId }: TrainingTabProps) { setSelectedModel={setSelectedModel} handleTrainingStart={handleTrainingStart} handleTrainingStop={handleTrainingStop} + isPolling={isPolling} className="h-full" /> - void; + onError?: (error: unknown) => void; +} + +export default function useTrainModelQuery(projectId: number, options?: UseTrainModelOptions) { return useMutation({ mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData), onSuccess: () => { queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] }); + options?.onSuccess?.(); + }, + onError: (error) => { + options?.onError?.(error); }, }); }