Refactor: 학습 부분 리팩토링

This commit is contained in:
정현조 2024-09-26 17:18:17 +09:00
parent 7788e49217
commit 24dd9b1d1b
3 changed files with 91 additions and 45 deletions

View File

@ -10,50 +10,75 @@ interface TrainingGraphProps {
} }
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) { export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
const { isTrainingByProject, setIsTraining, saveTrainingData, resetTrainingData, trainingDataByProject } = const projectKey = projectId?.toString() || '';
useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
setIsTraining: state.setIsTraining,
saveTrainingData: state.saveTrainingData,
resetTrainingData: state.resetTrainingData,
trainingDataByProject: state.trainingDataByProject,
}));
const isTraining = isTrainingByProject[projectId?.toString() || ''] || false; const {
isTrainingByProject,
isTrainingCompleteByProject,
setIsTraining,
setIsTrainingComplete,
saveTrainingData,
resetTrainingData,
trainingDataByProject,
selectModel,
} = useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
isTrainingCompleteByProject: state.isTrainingCompleteByProject,
setIsTraining: state.setIsTraining,
setIsTrainingComplete: state.setIsTrainingComplete,
saveTrainingData: state.saveTrainingData,
resetTrainingData: state.resetTrainingData,
trainingDataByProject: state.trainingDataByProject,
selectModel: state.selectModel,
}));
const isTraining = isTrainingByProject[projectKey] || false;
const isTrainingComplete = isTrainingCompleteByProject[projectKey] || false;
useEffect(() => {
if (projectId !== null) {
selectModel(projectKey, selectedModel);
}
}, [selectedModel, projectId, projectKey, selectModel]);
const { data: fetchedTrainingDataList } = usePollingModelReportsQuery( const { data: fetchedTrainingDataList } = usePollingModelReportsQuery(
projectId as number, projectId as number,
selectedModel ?? 0, selectedModel as number,
isTraining && !!projectId && !!selectedModel isTraining && !!projectId && !!selectedModel
); );
const trainingDataList = useMemo(() => { const trainingDataList = useMemo(() => {
return trainingDataByProject[projectId?.toString() || ''] || fetchedTrainingDataList || []; if (!isTraining) {
}, [projectId, trainingDataByProject, fetchedTrainingDataList]); return [];
}
return trainingDataByProject[projectKey] || fetchedTrainingDataList || [];
}, [isTraining, projectKey, trainingDataByProject, fetchedTrainingDataList]);
useEffect(() => { useEffect(() => {
if (fetchedTrainingDataList) { if (fetchedTrainingDataList) {
saveTrainingData(projectId?.toString() || '', fetchedTrainingDataList); saveTrainingData(projectKey, fetchedTrainingDataList);
} }
}, [fetchedTrainingDataList, projectId, saveTrainingData]); }, [fetchedTrainingDataList, projectKey, saveTrainingData]);
const latestData = useMemo(() => {
return (
trainingDataList?.[trainingDataList.length - 1] || {
epoch: 0,
totalEpochs: 0,
leftSecond: 0,
}
);
}, [trainingDataList]);
useEffect(() => { useEffect(() => {
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) { if (isTraining && trainingDataList.length > 0) {
alert('학습이 완료되었습니다!'); const latestData = trainingDataList[trainingDataList.length - 1];
setIsTraining(projectId?.toString() || '', false); if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
resetTrainingData(projectId?.toString() || ''); setIsTrainingComplete(projectKey, true);
} else {
setIsTrainingComplete(projectKey, false);
}
} }
}, [latestData.epoch, latestData.totalEpochs, setIsTraining, resetTrainingData, projectId]); }, [trainingDataList, setIsTrainingComplete, projectKey, isTraining]);
useEffect(() => {
if (isTrainingComplete) {
alert('학습이 완료되었습니다!');
setIsTraining(projectKey, false);
resetTrainingData(projectKey);
setIsTrainingComplete(projectKey, false);
}
}, [isTrainingComplete, setIsTraining, resetTrainingData, setIsTrainingComplete, projectKey]);
return ( return (
<ModelLineChart <ModelLineChart

View File

@ -15,26 +15,27 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
isTrainingByProject: state.isTrainingByProject, isTrainingByProject: state.isTrainingByProject,
setIsTraining: state.setIsTraining, setIsTraining: state.setIsTraining,
selectedModelByProject: state.selectedModelByProject, selectedModelByProject: state.selectedModelByProject,
setSelectedModel: state.setSelectedModel, setSelectedModel: state.selectModel,
resetTrainingData: state.resetTrainingData, resetTrainingData: state.resetTrainingData,
})); }));
const isTraining = isTrainingByProject[numericProjectId?.toString() || ''] || false; const projectKey = numericProjectId?.toString() || '';
const selectedModel = selectedModelByProject[numericProjectId?.toString() || '']; const isTraining = isTrainingByProject[projectKey] || false;
const selectedModel = selectedModelByProject[projectKey];
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number); const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
const handleTrainingStart = (trainData: ModelTrainRequest) => { const handleTrainingStart = (trainData: ModelTrainRequest) => {
if (!isTraining && selectedModel !== null) { if (!isTraining && selectedModel !== null) {
setIsTraining(numericProjectId?.toString() || '', true); setIsTraining(projectKey, true);
startTraining(trainData); startTraining(trainData);
} }
}; };
const handleTrainingStop = () => { const handleTrainingStop = () => {
if (isTraining) { if (isTraining) {
setIsTraining(numericProjectId?.toString() || '', false); setIsTraining(projectKey, false);
resetTrainingData(numericProjectId?.toString() || ''); resetTrainingData(projectKey);
} }
}; };
@ -43,7 +44,7 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
<TrainingSettings <TrainingSettings
projectId={numericProjectId} projectId={numericProjectId}
selectedModel={selectedModel} selectedModel={selectedModel}
setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)} setSelectedModel={(modelId) => setSelectedModel(projectKey, modelId)}
handleTrainingStart={handleTrainingStart} handleTrainingStart={handleTrainingStart}
handleTrainingStop={handleTrainingStop} handleTrainingStop={handleTrainingStop}
className="h-full" className="h-full"

View File

@ -4,17 +4,22 @@ import { ReportResponse } from '@/types';
interface ModelStoreState { interface ModelStoreState {
trainingDataByProject: Record<string, ReportResponse[]>; trainingDataByProject: Record<string, ReportResponse[]>;
isTrainingByProject: Record<string, boolean>; isTrainingByProject: Record<string, boolean>;
isTrainingCompleteByProject: Record<string, boolean>;
selectedModelByProject: Record<string, number | null>; selectedModelByProject: Record<string, number | null>;
setIsTraining: (projectId: string, status: boolean) => void; setIsTraining: (projectId: string, status: boolean) => void;
setIsTrainingComplete: (projectId: string, status: boolean) => void;
saveTrainingData: (projectId: string, data: ReportResponse[]) => void; saveTrainingData: (projectId: string, data: ReportResponse[]) => void;
setSelectedModel: (projectId: string, modelId: number | null) => void;
resetTrainingData: (projectId: string) => void; resetTrainingData: (projectId: string) => void;
selectModel: (projectId: string, modelId: number | null) => void;
} }
const useModelStore = create<ModelStoreState>((set) => ({ const useModelStore = create<ModelStoreState>((set) => ({
trainingDataByProject: {}, trainingDataByProject: {},
isTrainingByProject: {}, isTrainingByProject: {},
isTrainingCompleteByProject: {},
selectedModelByProject: {}, selectedModelByProject: {},
setIsTraining: (projectId, status) => setIsTraining: (projectId, status) =>
set((state) => ({ set((state) => ({
isTrainingByProject: { isTrainingByProject: {
@ -22,6 +27,15 @@ const useModelStore = create<ModelStoreState>((set) => ({
[projectId]: status, [projectId]: status,
}, },
})), })),
setIsTrainingComplete: (projectId, status) =>
set((state) => ({
isTrainingCompleteByProject: {
...state.isTrainingCompleteByProject,
[projectId]: status,
},
})),
saveTrainingData: (projectId, data) => saveTrainingData: (projectId, data) =>
set((state) => ({ set((state) => ({
trainingDataByProject: { trainingDataByProject: {
@ -29,27 +43,33 @@ const useModelStore = create<ModelStoreState>((set) => ({
[projectId]: data, [projectId]: data,
}, },
})), })),
setSelectedModel: (projectId, modelId) =>
set((state) => ({
selectedModelByProject: {
...state.selectedModelByProject,
[projectId]: modelId,
},
})),
resetTrainingData: (projectId) => resetTrainingData: (projectId) =>
set((state) => ({ set((state) => ({
trainingDataByProject: { trainingDataByProject: {
...state.trainingDataByProject, ...state.trainingDataByProject,
[projectId]: [], [projectId]: [],
}, },
})),
selectModel: (projectId, modelId) =>
set((state) => ({
selectedModelByProject: { selectedModelByProject: {
...state.selectedModelByProject, ...state.selectedModelByProject,
[projectId]: null, [projectId]: modelId,
},
trainingDataByProject: {
...state.trainingDataByProject,
[projectId]: [],
}, },
isTrainingByProject: { isTrainingByProject: {
...state.isTrainingByProject, ...state.isTrainingByProject,
[projectId]: false, [projectId]: false,
}, },
isTrainingCompleteByProject: {
...state.isTrainingCompleteByProject,
[projectId]: false,
},
})), })),
})); }));