Refactor: 학습 부분 리팩토링

This commit is contained in:
정현조 2024-09-26 17:18:17 +09:00
parent 7788e49217
commit 24dd9b1d1b
3 changed files with 91 additions and 45 deletions

View File

@ -10,50 +10,75 @@ interface TrainingGraphProps {
}
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
const { isTrainingByProject, setIsTraining, saveTrainingData, resetTrainingData, trainingDataByProject } =
useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
setIsTraining: state.setIsTraining,
saveTrainingData: state.saveTrainingData,
resetTrainingData: state.resetTrainingData,
trainingDataByProject: state.trainingDataByProject,
}));
const projectKey = projectId?.toString() || '';
const isTraining = isTrainingByProject[projectId?.toString() || ''] || false;
const {
isTrainingByProject,
isTrainingCompleteByProject,
setIsTraining,
setIsTrainingComplete,
saveTrainingData,
resetTrainingData,
trainingDataByProject,
selectModel,
} = useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
isTrainingCompleteByProject: state.isTrainingCompleteByProject,
setIsTraining: state.setIsTraining,
setIsTrainingComplete: state.setIsTrainingComplete,
saveTrainingData: state.saveTrainingData,
resetTrainingData: state.resetTrainingData,
trainingDataByProject: state.trainingDataByProject,
selectModel: state.selectModel,
}));
const isTraining = isTrainingByProject[projectKey] || false;
const isTrainingComplete = isTrainingCompleteByProject[projectKey] || false;
useEffect(() => {
if (projectId !== null) {
selectModel(projectKey, selectedModel);
}
}, [selectedModel, projectId, projectKey, selectModel]);
const { data: fetchedTrainingDataList } = usePollingModelReportsQuery(
projectId as number,
selectedModel ?? 0,
selectedModel as number,
isTraining && !!projectId && !!selectedModel
);
const trainingDataList = useMemo(() => {
return trainingDataByProject[projectId?.toString() || ''] || fetchedTrainingDataList || [];
}, [projectId, trainingDataByProject, fetchedTrainingDataList]);
if (!isTraining) {
return [];
}
return trainingDataByProject[projectKey] || fetchedTrainingDataList || [];
}, [isTraining, projectKey, trainingDataByProject, fetchedTrainingDataList]);
useEffect(() => {
if (fetchedTrainingDataList) {
saveTrainingData(projectId?.toString() || '', fetchedTrainingDataList);
saveTrainingData(projectKey, fetchedTrainingDataList);
}
}, [fetchedTrainingDataList, projectId, saveTrainingData]);
const latestData = useMemo(() => {
return (
trainingDataList?.[trainingDataList.length - 1] || {
epoch: 0,
totalEpochs: 0,
leftSecond: 0,
}
);
}, [trainingDataList]);
}, [fetchedTrainingDataList, projectKey, saveTrainingData]);
useEffect(() => {
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
alert('학습이 완료되었습니다!');
setIsTraining(projectId?.toString() || '', false);
resetTrainingData(projectId?.toString() || '');
if (isTraining && trainingDataList.length > 0) {
const latestData = trainingDataList[trainingDataList.length - 1];
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
setIsTrainingComplete(projectKey, true);
} else {
setIsTrainingComplete(projectKey, false);
}
}
}, [latestData.epoch, latestData.totalEpochs, setIsTraining, resetTrainingData, projectId]);
}, [trainingDataList, setIsTrainingComplete, projectKey, isTraining]);
useEffect(() => {
if (isTrainingComplete) {
alert('학습이 완료되었습니다!');
setIsTraining(projectKey, false);
resetTrainingData(projectKey);
setIsTrainingComplete(projectKey, false);
}
}, [isTrainingComplete, setIsTraining, resetTrainingData, setIsTrainingComplete, projectKey]);
return (
<ModelLineChart

View File

@ -15,26 +15,27 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
isTrainingByProject: state.isTrainingByProject,
setIsTraining: state.setIsTraining,
selectedModelByProject: state.selectedModelByProject,
setSelectedModel: state.setSelectedModel,
setSelectedModel: state.selectModel,
resetTrainingData: state.resetTrainingData,
}));
const isTraining = isTrainingByProject[numericProjectId?.toString() || ''] || false;
const selectedModel = selectedModelByProject[numericProjectId?.toString() || ''];
const projectKey = numericProjectId?.toString() || '';
const isTraining = isTrainingByProject[projectKey] || false;
const selectedModel = selectedModelByProject[projectKey];
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
const handleTrainingStart = (trainData: ModelTrainRequest) => {
if (!isTraining && selectedModel !== null) {
setIsTraining(numericProjectId?.toString() || '', true);
setIsTraining(projectKey, true);
startTraining(trainData);
}
};
const handleTrainingStop = () => {
if (isTraining) {
setIsTraining(numericProjectId?.toString() || '', false);
resetTrainingData(numericProjectId?.toString() || '');
setIsTraining(projectKey, false);
resetTrainingData(projectKey);
}
};
@ -43,7 +44,7 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
<TrainingSettings
projectId={numericProjectId}
selectedModel={selectedModel}
setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)}
setSelectedModel={(modelId) => setSelectedModel(projectKey, modelId)}
handleTrainingStart={handleTrainingStart}
handleTrainingStop={handleTrainingStop}
className="h-full"

View File

@ -4,17 +4,22 @@ import { ReportResponse } from '@/types';
interface ModelStoreState {
trainingDataByProject: Record<string, ReportResponse[]>;
isTrainingByProject: Record<string, boolean>;
isTrainingCompleteByProject: Record<string, boolean>;
selectedModelByProject: Record<string, number | null>;
setIsTraining: (projectId: string, status: boolean) => void;
setIsTrainingComplete: (projectId: string, status: boolean) => void;
saveTrainingData: (projectId: string, data: ReportResponse[]) => void;
setSelectedModel: (projectId: string, modelId: number | null) => void;
resetTrainingData: (projectId: string) => void;
selectModel: (projectId: string, modelId: number | null) => void;
}
const useModelStore = create<ModelStoreState>((set) => ({
trainingDataByProject: {},
isTrainingByProject: {},
isTrainingCompleteByProject: {},
selectedModelByProject: {},
setIsTraining: (projectId, status) =>
set((state) => ({
isTrainingByProject: {
@ -22,6 +27,15 @@ const useModelStore = create<ModelStoreState>((set) => ({
[projectId]: status,
},
})),
setIsTrainingComplete: (projectId, status) =>
set((state) => ({
isTrainingCompleteByProject: {
...state.isTrainingCompleteByProject,
[projectId]: status,
},
})),
saveTrainingData: (projectId, data) =>
set((state) => ({
trainingDataByProject: {
@ -29,27 +43,33 @@ const useModelStore = create<ModelStoreState>((set) => ({
[projectId]: data,
},
})),
setSelectedModel: (projectId, modelId) =>
set((state) => ({
selectedModelByProject: {
...state.selectedModelByProject,
[projectId]: modelId,
},
})),
resetTrainingData: (projectId) =>
set((state) => ({
trainingDataByProject: {
...state.trainingDataByProject,
[projectId]: [],
},
})),
selectModel: (projectId, modelId) =>
set((state) => ({
selectedModelByProject: {
...state.selectedModelByProject,
[projectId]: null,
[projectId]: modelId,
},
trainingDataByProject: {
...state.trainingDataByProject,
[projectId]: [],
},
isTrainingByProject: {
...state.isTrainingByProject,
[projectId]: false,
},
isTrainingCompleteByProject: {
...state.isTrainingCompleteByProject,
[projectId]: false,
},
})),
}));