Refactor: 학습 관련 리팩토링

This commit is contained in:
정현조 2024-09-27 12:28:15 +09:00
parent ba639d0c75
commit 8f39909a28
10 changed files with 127 additions and 193 deletions

View File

@ -1,8 +1,8 @@
import { Label } from '@/components/ui/label';
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import useModelReportsQuery from '@/queries/models/useModelReportsQuery';
import useModelResultsQuery from '@/queries/models/useModelResultsQuery';
import useCompletedModelReport from '@/queries/reports/useCompletedModelReport';
import useModelResultsQuery from '@/queries/results/useModelResultQuery';
import ModelBarChart from './ModelBarChart';
import ModelLineChart from './ModelLineChart';
import { useState } from 'react';
@ -68,7 +68,7 @@ interface ModelEvaluationProps {
}
function ModelEvaluation({ projectId, selectedModel }: ModelEvaluationProps) {
const { data: reportData } = useModelReportsQuery(projectId, selectedModel);
const { data: reportData } = useCompletedModelReport(projectId, selectedModel);
const { data: resultData } = useModelResultsQuery(selectedModel);
if (!reportData || !resultData) return null;

View File

@ -1,84 +1,29 @@
import { useEffect, useMemo } from 'react';
import { useMemo } from 'react';
import ModelLineChart from './ModelLineChart';
import usePollingModelReportsQuery from '@/queries/models/usePollingModelReportsQuery';
import useModelStore from '@/stores/useModelStore';
import usePollingTrainingModelReport from '@/queries/reports/usePollingModelReportsQuery';
import { ModelResponse } from '@/types';
interface TrainingGraphProps {
projectId: number | null;
selectedModel: number | null;
selectedModel: ModelResponse | null;
className?: string;
}
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
const projectKey = projectId?.toString() || '';
const isTraining = selectedModel?.isTrain || false;
const {
isTrainingByProject,
isTrainingCompleteByProject,
setIsTraining,
setIsTrainingComplete,
saveTrainingData,
resetTrainingData,
trainingDataByProject,
selectModel,
} = useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
isTrainingCompleteByProject: state.isTrainingCompleteByProject,
setIsTraining: state.setIsTraining,
setIsTrainingComplete: state.setIsTrainingComplete,
saveTrainingData: state.saveTrainingData,
resetTrainingData: state.resetTrainingData,
trainingDataByProject: state.trainingDataByProject,
selectModel: state.selectModel,
}));
const isTraining = isTrainingByProject[projectKey] || false;
const isTrainingComplete = isTrainingCompleteByProject[projectKey] || false;
useEffect(() => {
if (projectId !== null) {
selectModel(projectKey, selectedModel);
}
}, [selectedModel, projectId, projectKey, selectModel]);
const { data: fetchedTrainingDataList } = usePollingModelReportsQuery(
const { data: fetchedTrainingDataList } = usePollingTrainingModelReport(
projectId as number,
selectedModel as number,
isTraining && !!projectId && !!selectedModel
selectedModel?.id as number,
isTraining
);
const trainingDataList = useMemo(() => {
if (!isTraining) {
return [];
}
return trainingDataByProject[projectKey] || fetchedTrainingDataList || [];
}, [isTraining, projectKey, trainingDataByProject, fetchedTrainingDataList]);
useEffect(() => {
if (fetchedTrainingDataList) {
saveTrainingData(projectKey, fetchedTrainingDataList);
}
}, [fetchedTrainingDataList, projectKey, saveTrainingData]);
useEffect(() => {
if (isTraining && trainingDataList.length > 0) {
const latestData = trainingDataList[trainingDataList.length - 1];
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
setIsTrainingComplete(projectKey, true);
} else {
setIsTrainingComplete(projectKey, false);
}
}
}, [trainingDataList, setIsTrainingComplete, projectKey, isTraining]);
useEffect(() => {
if (isTrainingComplete) {
alert('학습이 완료되었습니다!');
setIsTraining(projectKey, false);
resetTrainingData(projectKey);
setIsTrainingComplete(projectKey, false);
}
}, [isTrainingComplete, setIsTraining, resetTrainingData, setIsTrainingComplete, projectKey]);
return fetchedTrainingDataList || [];
}, [isTraining, fetchedTrainingDataList]);
return (
<ModelLineChart

View File

@ -2,15 +2,14 @@ import SelectWithLabel from './SelectWithLabel';
import InputWithLabel from './InputWithLabel';
import { Button } from '@/components/ui/button';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import useModelStore from '@/stores/useModelStore';
import { ModelTrainRequest } from '@/types';
import { ModelTrainRequest, ModelResponse } from '@/types';
import { useState } from 'react';
import { cn } from '@/lib/utils';
interface TrainingSettingsProps {
projectId: number | null;
selectedModel: number | null;
setSelectedModel: (model: number | null) => void;
selectedModel: ModelResponse | null;
setSelectedModel: (model: ModelResponse | null) => void;
handleTrainingStart: (trainData: ModelTrainRequest) => void;
handleTrainingStop: () => void;
className?: string;
@ -25,9 +24,6 @@ export default function TrainingSettings({
className,
}: TrainingSettingsProps) {
const { data: models } = useProjectModelsQuery(projectId ?? 0);
const isTraining = useModelStore((state) => state.isTrainingByProject[projectId?.toString() || ''] || false);
const [ratio, setRatio] = useState<number>(0.8);
const [epochs, setEpochs] = useState<number>(50);
const [batchSize, setBatchSize] = useState<number>(32);
@ -36,11 +32,11 @@ export default function TrainingSettings({
const [lrf, setLrf] = useState<number>(0.001);
const handleSubmit = () => {
if (isTraining) {
if (selectedModel?.isTrain) {
handleTrainingStop();
} else if (selectedModel !== null) {
} else if (selectedModel) {
const trainData: ModelTrainRequest = {
modelId: selectedModel,
modelId: selectedModel.id,
ratio,
epochs,
batch: batchSize,
@ -54,7 +50,6 @@ export default function TrainingSettings({
return (
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
{' '}
<legend className="-ml-1 px-1 text-sm font-medium"> </legend>
<div className="grid gap-3">
<SelectWithLabel
@ -62,83 +57,85 @@ export default function TrainingSettings({
id="model"
options={
models?.map((model) => ({
label: model.name,
label: `${model.name}${model.isTrain ? ' (학습 중)' : ''}${model.isDefault ? ' (기본)' : ''}`,
value: model.id.toString(),
})) || []
}
placeholder="모델을 선택하세요"
value={selectedModel ? selectedModel.toString() : ''}
onChange={(value) => setSelectedModel(parseInt(value, 10))}
disabled={isTraining}
value={selectedModel ? selectedModel.id.toString() : ''}
onChange={(value) => {
const selected = models?.find((model) => model.id === parseInt(value, 10));
setSelectedModel(selected || null);
}}
/>
</div>
<div className="grid grid-cols-2 gap-4">
<InputWithLabel
label="훈련/검증 비율"
placeholder="예: 0.8 (80% 훈련, 20% 검증)"
id="ratio"
value={ratio}
onChange={(e) => setRatio(parseFloat(e.target.value))}
disabled={isTraining}
/>
<InputWithLabel
label="에포크 수"
placeholder="예: 50 (총 반복 횟수)"
id="epochs"
value={epochs}
onChange={(e) => setEpochs(parseInt(e.target.value, 10))}
disabled={isTraining}
/>
<InputWithLabel
label="Batch 크기"
placeholder="예: 32 (한번에 처리할 샘플 수)"
id="batch"
value={batchSize}
onChange={(e) => setBatchSize(parseInt(e.target.value, 10))}
disabled={isTraining}
/>
<SelectWithLabel
label="옵티마이저"
id="optimizer"
options={[
{ label: 'AUTO', value: 'AUTO' },
{ label: 'SGD', value: 'SGD' },
{ label: 'ADAM', value: 'ADAM' },
{ label: 'ADAMW', value: 'ADAMW' },
{ label: 'NADAM', value: 'NADAM' },
{ label: 'RADAM', value: 'RADAM' },
{ label: 'RMSPROP', value: 'RMSPROP' },
]}
placeholder="옵티마이저 선택"
value={optimizer}
onChange={(value) => setOptimizer(value as 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP')}
disabled={isTraining} // 학습 중일 때 옵티마이저 선택 비활성화
/>
<InputWithLabel
label="학습률(LR0)"
placeholder="예: 0.01 (초기 학습률)"
id="lr0"
value={lr0}
onChange={(e) => setLr0(parseFloat(e.target.value))}
disabled={isTraining}
/>
<InputWithLabel
label="최종 학습률(LRF)"
placeholder="예: 0.001 (최종 학습률)"
id="lrf"
value={lrf}
onChange={(e) => setLrf(parseFloat(e.target.value))}
disabled={isTraining}
/>
</div>
<Button
variant="outlinePrimary"
size="lg"
onClick={handleSubmit}
disabled={!selectedModel}
>
{isTraining ? '학습 중단' : '학습 시작'}
</Button>
{!selectedModel?.isTrain && (
<>
<div className="grid grid-cols-2 gap-4">
<InputWithLabel
label="훈련/검증 비율"
id="ratio"
value={ratio}
onChange={(e) => setRatio(parseFloat(e.target.value))}
placeholder="훈련/검증 비율"
/>
<InputWithLabel
label="에포크 수"
id="epochs"
value={epochs}
onChange={(e) => setEpochs(parseInt(e.target.value, 10))}
placeholder="에포크 수"
/>
<InputWithLabel
label="Batch 크기"
id="batch"
value={batchSize}
onChange={(e) => setBatchSize(parseInt(e.target.value, 10))}
placeholder="Batch 크기"
/>
<SelectWithLabel
label="옵티마이저"
id="optimizer"
options={[
{ label: 'AUTO', value: 'AUTO' },
{ label: 'SGD', value: 'SGD' },
{ label: 'ADAM', value: 'ADAM' },
{ label: 'ADAMW', value: 'ADAMW' },
{ label: 'NADAM', value: 'NADAM' },
{ label: 'RADAM', value: 'RADAM' },
{ label: 'RMSPROP', value: 'RMSPROP' },
]}
value={optimizer}
onChange={(value) =>
setOptimizer(value as 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP')
}
placeholder="옵티마이저"
/>
<InputWithLabel
label="학습률(LR0)"
id="lr0"
value={lr0}
onChange={(e) => setLr0(parseFloat(e.target.value))}
placeholder="초기 학습률"
/>
<InputWithLabel
label="최종 학습률(LRF)"
id="lrf"
value={lrf}
onChange={(e) => setLrf(parseFloat(e.target.value))}
placeholder="최종 학습률"
/>
</div>
<Button
variant="outlinePrimary"
size="lg"
onClick={handleSubmit}
disabled={!selectedModel}
>
{selectedModel?.isTrain ? '학습 중단' : '학습 시작'}
</Button>
</>
)}
</fieldset>
);
}

View File

@ -1,50 +1,31 @@
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
import useModelStore from '@/stores/useModelStore';
import TrainingSettings from './TrainingSettings';
import TrainingGraph from './TrainingGraph';
import { ModelTrainRequest } from '@/types';
import { ModelTrainRequest, ModelResponse } from '@/types';
import { useState } from 'react';
interface TrainingTabProps {
projectId: number | null;
}
//Todo : 로직 수정, isTrain을 서버 단에서 받고, 셀렉트 됐을 때 개별 조회로 isTrain을 판단해서, 학습 중이면 리패치하는 방식으로 관리한다.
export default function TrainingTab({ projectId }: TrainingTabProps) {
const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null;
const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, resetTrainingData } =
useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
setIsTraining: state.setIsTraining,
selectedModelByProject: state.selectedModelByProject,
setSelectedModel: state.selectModel,
resetTrainingData: state.resetTrainingData,
}));
const projectKey = numericProjectId?.toString() || '';
const isTraining = isTrainingByProject[projectKey] || false;
const selectedModel = selectedModelByProject[projectKey];
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
const handleTrainingStart = (trainData: ModelTrainRequest) => {
if (!isTraining && selectedModel !== null) {
setIsTraining(projectKey, true);
startTraining(trainData);
}
startTraining(trainData);
};
const handleTrainingStop = () => {
if (isTraining) {
setIsTraining(projectKey, false);
resetTrainingData(projectKey);
}
};
const handleTrainingStop = () => {};
return (
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
<TrainingSettings
projectId={numericProjectId}
selectedModel={selectedModel}
setSelectedModel={(modelId) => setSelectedModel(projectKey, modelId)}
setSelectedModel={setSelectedModel}
handleTrainingStart={handleTrainingStart}
handleTrainingStop={handleTrainingStop}
className="h-full"

View File

@ -1,12 +0,0 @@
import { useQuery } from '@tanstack/react-query';
import { getModelReports } from '@/api/modelApi';
import { ReportResponse } from '@/types';
export default function usePollingModelReportsQuery(projectId: number, modelId: number, enabled: boolean) {
return useQuery<ReportResponse[]>({
queryKey: ['pollingModelReports', projectId, modelId],
queryFn: () => getModelReports(projectId, modelId),
refetchInterval: 5000,
enabled,
});
}

View File

@ -0,0 +1,10 @@
import { useSuspenseQuery } from '@tanstack/react-query';
import { getCompletedModelReport } from '@/api/reportApi';
import { ReportResponse } from '@/types';
export default function useCompletedModelReport(projectId: number, modelId: number) {
return useSuspenseQuery<ReportResponse[]>({
queryKey: ['modelReport', projectId, modelId],
queryFn: () => getCompletedModelReport(projectId, modelId),
});
}

View File

@ -0,0 +1,12 @@
import { useQuery } from '@tanstack/react-query';
import { getTrainingModelReport } from '@/api/reportApi';
import { ReportResponse } from '@/types';
export default function usePollingTrainingModelReport(projectId: number, modelId: number, enabled: boolean) {
return useQuery<ReportResponse[]>({
queryKey: ['modelReports', projectId, modelId],
queryFn: () => getTrainingModelReport(projectId, modelId),
refetchInterval: 5000,
enabled,
});
}

View File

@ -1,10 +1,10 @@
import { useSuspenseQuery } from '@tanstack/react-query';
import { getModelReports } from '@/api/modelApi';
import { getTrainingModelReport } from '@/api/reportApi';
import { ReportResponse } from '@/types';
export default function useModelReportsQuery(projectId: number, modelId: number) {
export default function useTrainingModelReport(projectId: number, modelId: number) {
return useSuspenseQuery<ReportResponse[]>({
queryKey: ['modelReports', projectId, modelId],
queryFn: () => getModelReports(projectId, modelId),
queryFn: () => getTrainingModelReport(projectId, modelId),
});
}

View File

@ -1,10 +1,10 @@
import { useSuspenseQuery } from '@tanstack/react-query';
import { getModelResults } from '@/api/modelApi';
import { getModelResult } from '@/api/resultApi';
import { ResultResponse } from '@/types';
export default function useModelResultsQuery(modelId: number) {
return useSuspenseQuery<ResultResponse[]>({
queryKey: ['modelResults', modelId],
queryFn: () => getModelResults(modelId),
queryFn: () => getModelResult(modelId),
});
}

View File

@ -309,6 +309,7 @@ export interface ModelResponse {
id: number;
name: string;
isDefault: boolean;
isTrain: boolean;
}
// 프로젝트 모델 리스트 응답 DTO