From 24dd9b1d1b7eca7c5144a1b9c261e5f158e5bdda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=A0=95=ED=98=84=EC=A1=B0?= Date: Thu, 26 Sep 2024 17:18:17 +0900 Subject: [PATCH] =?UTF-8?q?Refactor:=20=ED=95=99=EC=8A=B5=20=EB=B6=80?= =?UTF-8?q?=EB=B6=84=20=EB=A6=AC=ED=8C=A9=ED=86=A0=EB=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/ModelManage/TrainingGraph.tsx | 83 ++++++++++++------- .../components/ModelManage/TrainingTab.tsx | 15 ++-- frontend/src/stores/useModelStore.ts | 38 +++++++-- 3 files changed, 91 insertions(+), 45 deletions(-) diff --git a/frontend/src/components/ModelManage/TrainingGraph.tsx b/frontend/src/components/ModelManage/TrainingGraph.tsx index 2072916..3c3e6e8 100644 --- a/frontend/src/components/ModelManage/TrainingGraph.tsx +++ b/frontend/src/components/ModelManage/TrainingGraph.tsx @@ -10,50 +10,75 @@ interface TrainingGraphProps { } export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) { - const { isTrainingByProject, setIsTraining, saveTrainingData, resetTrainingData, trainingDataByProject } = - useModelStore((state) => ({ - isTrainingByProject: state.isTrainingByProject, - setIsTraining: state.setIsTraining, - saveTrainingData: state.saveTrainingData, - resetTrainingData: state.resetTrainingData, - trainingDataByProject: state.trainingDataByProject, - })); + const projectKey = projectId?.toString() || ''; - const isTraining = isTrainingByProject[projectId?.toString() || ''] || false; + const { + isTrainingByProject, + isTrainingCompleteByProject, + setIsTraining, + setIsTrainingComplete, + saveTrainingData, + resetTrainingData, + trainingDataByProject, + selectModel, + } = useModelStore((state) => ({ + isTrainingByProject: state.isTrainingByProject, + isTrainingCompleteByProject: state.isTrainingCompleteByProject, + setIsTraining: state.setIsTraining, + setIsTrainingComplete: state.setIsTrainingComplete, + saveTrainingData: state.saveTrainingData, + resetTrainingData: state.resetTrainingData, + trainingDataByProject: state.trainingDataByProject, + selectModel: state.selectModel, + })); + + const isTraining = isTrainingByProject[projectKey] || false; + const isTrainingComplete = isTrainingCompleteByProject[projectKey] || false; + + useEffect(() => { + if (projectId !== null) { + selectModel(projectKey, selectedModel); + } + }, [selectedModel, projectId, projectKey, selectModel]); const { data: fetchedTrainingDataList } = usePollingModelReportsQuery( projectId as number, - selectedModel ?? 0, + selectedModel as number, isTraining && !!projectId && !!selectedModel ); const trainingDataList = useMemo(() => { - return trainingDataByProject[projectId?.toString() || ''] || fetchedTrainingDataList || []; - }, [projectId, trainingDataByProject, fetchedTrainingDataList]); + if (!isTraining) { + return []; + } + return trainingDataByProject[projectKey] || fetchedTrainingDataList || []; + }, [isTraining, projectKey, trainingDataByProject, fetchedTrainingDataList]); useEffect(() => { if (fetchedTrainingDataList) { - saveTrainingData(projectId?.toString() || '', fetchedTrainingDataList); + saveTrainingData(projectKey, fetchedTrainingDataList); } - }, [fetchedTrainingDataList, projectId, saveTrainingData]); - - const latestData = useMemo(() => { - return ( - trainingDataList?.[trainingDataList.length - 1] || { - epoch: 0, - totalEpochs: 0, - leftSecond: 0, - } - ); - }, [trainingDataList]); + }, [fetchedTrainingDataList, projectKey, saveTrainingData]); useEffect(() => { - if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) { - alert('학습이 완료되었습니다!'); - setIsTraining(projectId?.toString() || '', false); - resetTrainingData(projectId?.toString() || ''); + if (isTraining && trainingDataList.length > 0) { + const latestData = trainingDataList[trainingDataList.length - 1]; + if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) { + setIsTrainingComplete(projectKey, true); + } else { + setIsTrainingComplete(projectKey, false); + } } - }, [latestData.epoch, latestData.totalEpochs, setIsTraining, resetTrainingData, projectId]); + }, [trainingDataList, setIsTrainingComplete, projectKey, isTraining]); + + useEffect(() => { + if (isTrainingComplete) { + alert('학습이 완료되었습니다!'); + setIsTraining(projectKey, false); + resetTrainingData(projectKey); + setIsTrainingComplete(projectKey, false); + } + }, [isTrainingComplete, setIsTraining, resetTrainingData, setIsTrainingComplete, projectKey]); return ( { if (!isTraining && selectedModel !== null) { - setIsTraining(numericProjectId?.toString() || '', true); + setIsTraining(projectKey, true); startTraining(trainData); } }; const handleTrainingStop = () => { if (isTraining) { - setIsTraining(numericProjectId?.toString() || '', false); - resetTrainingData(numericProjectId?.toString() || ''); + setIsTraining(projectKey, false); + resetTrainingData(projectKey); } }; @@ -43,7 +44,7 @@ export default function TrainingTab({ projectId }: TrainingTabProps) { setSelectedModel(numericProjectId?.toString() || '', modelId)} + setSelectedModel={(modelId) => setSelectedModel(projectKey, modelId)} handleTrainingStart={handleTrainingStart} handleTrainingStop={handleTrainingStop} className="h-full" diff --git a/frontend/src/stores/useModelStore.ts b/frontend/src/stores/useModelStore.ts index f6ba8a5..a5b24c4 100644 --- a/frontend/src/stores/useModelStore.ts +++ b/frontend/src/stores/useModelStore.ts @@ -4,17 +4,22 @@ import { ReportResponse } from '@/types'; interface ModelStoreState { trainingDataByProject: Record; isTrainingByProject: Record; + isTrainingCompleteByProject: Record; selectedModelByProject: Record; + setIsTraining: (projectId: string, status: boolean) => void; + setIsTrainingComplete: (projectId: string, status: boolean) => void; saveTrainingData: (projectId: string, data: ReportResponse[]) => void; - setSelectedModel: (projectId: string, modelId: number | null) => void; resetTrainingData: (projectId: string) => void; + selectModel: (projectId: string, modelId: number | null) => void; } const useModelStore = create((set) => ({ trainingDataByProject: {}, isTrainingByProject: {}, + isTrainingCompleteByProject: {}, selectedModelByProject: {}, + setIsTraining: (projectId, status) => set((state) => ({ isTrainingByProject: { @@ -22,6 +27,15 @@ const useModelStore = create((set) => ({ [projectId]: status, }, })), + + setIsTrainingComplete: (projectId, status) => + set((state) => ({ + isTrainingCompleteByProject: { + ...state.isTrainingCompleteByProject, + [projectId]: status, + }, + })), + saveTrainingData: (projectId, data) => set((state) => ({ trainingDataByProject: { @@ -29,27 +43,33 @@ const useModelStore = create((set) => ({ [projectId]: data, }, })), - setSelectedModel: (projectId, modelId) => - set((state) => ({ - selectedModelByProject: { - ...state.selectedModelByProject, - [projectId]: modelId, - }, - })), + resetTrainingData: (projectId) => set((state) => ({ trainingDataByProject: { ...state.trainingDataByProject, [projectId]: [], }, + })), + + selectModel: (projectId, modelId) => + set((state) => ({ selectedModelByProject: { ...state.selectedModelByProject, - [projectId]: null, + [projectId]: modelId, + }, + trainingDataByProject: { + ...state.trainingDataByProject, + [projectId]: [], }, isTrainingByProject: { ...state.isTrainingByProject, [projectId]: false, }, + isTrainingCompleteByProject: { + ...state.isTrainingCompleteByProject, + [projectId]: false, + }, })), }));