Merge branch 'fe/refactor/model-learn' into 'fe/develop'

Refator: 모델 학습 부분, 개선 필요

See merge request s11-s-project/S11P21S002!238
This commit is contained in:
홍창기 2024-09-30 10:36:10 +09:00
commit ab38cc0746
5 changed files with 85 additions and 53 deletions

View File

@ -1,6 +1,7 @@
import { useEffect, useState } from 'react'; import { useEffect } from 'react';
import ModelLineChart from './ModelLineChart'; import ModelLineChart from './ModelLineChart';
import usePollingModelReportsQuery from '@/queries/reports/usePollingModelReportsQuery'; import usePollingTrainingModelReport from '@/queries/reports/usePollingModelReportsQuery';
import { useQueryClient } from '@tanstack/react-query';
import { ModelResponse } from '@/types'; import { ModelResponse } from '@/types';
interface TrainingGraphProps { interface TrainingGraphProps {
@ -10,20 +11,22 @@ interface TrainingGraphProps {
} }
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) { export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
const [isPolling, setIsPolling] = useState(false); const queryClient = useQueryClient();
const { data: trainingDataList } = usePollingModelReportsQuery(
const { data: trainingDataList } = usePollingTrainingModelReport(
projectId as number, projectId as number,
selectedModel?.id as number, selectedModel?.id as number,
isPolling selectedModel?.isTrain || false
); );
useEffect(() => { useEffect(() => {
if (selectedModel) { if (!selectedModel || !selectedModel.isTrain) {
setIsPolling(true); queryClient.resetQueries({
} else { queryKey: [{ type: 'modelReports', projectId, modelId: selectedModel?.id }],
setIsPolling(false); exact: true,
});
} }
}, [selectedModel]); }, [selectedModel, queryClient, projectId]);
return ( return (
<ModelLineChart <ModelLineChart

View File

@ -1,11 +1,10 @@
import { useState, useEffect, useRef } from 'react'; import { useState } from 'react';
import { Button } from '@/components/ui/button'; import { Button } from '@/components/ui/button';
import SelectWithLabel from './SelectWithLabel'; import SelectWithLabel from './SelectWithLabel';
import InputWithLabel from './InputWithLabel'; import InputWithLabel from './InputWithLabel';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery'; import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import { ModelTrainRequest, ModelResponse } from '@/types'; import { ModelTrainRequest, ModelResponse } from '@/types';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import { useQueryClient } from '@tanstack/react-query';
interface TrainingSettingsProps { interface TrainingSettingsProps {
projectId: number | null; projectId: number | null;
@ -13,6 +12,7 @@ interface TrainingSettingsProps {
setSelectedModel: (model: ModelResponse | null) => void; setSelectedModel: (model: ModelResponse | null) => void;
handleTrainingStart: (trainData: ModelTrainRequest) => void; handleTrainingStart: (trainData: ModelTrainRequest) => void;
handleTrainingStop: () => void; handleTrainingStop: () => void;
isPolling: boolean;
className?: string; className?: string;
} }
@ -22,6 +22,7 @@ export default function TrainingSettings({
setSelectedModel, setSelectedModel,
handleTrainingStart, handleTrainingStart,
handleTrainingStop, handleTrainingStop,
isPolling,
className, className,
}: TrainingSettingsProps) { }: TrainingSettingsProps) {
const { data: models } = useProjectModelsQuery(projectId ?? 0); const { data: models } = useProjectModelsQuery(projectId ?? 0);
@ -31,14 +32,9 @@ export default function TrainingSettings({
const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO'); const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO');
const [lr0, setLr0] = useState<number>(0.01); const [lr0, setLr0] = useState<number>(0.01);
const [lrf, setLrf] = useState<number>(0.001); const [lrf, setLrf] = useState<number>(0.001);
const [isSubmitting, setIsSubmitting] = useState(false);
const queryClient = useQueryClient();
const intervalRef = useRef<NodeJS.Timeout | null>(null);
const handleSubmit = () => { const handleSubmit = () => {
if (selectedModel?.isTrain) { if (selectedModel) {
handleTrainingStop();
} else if (selectedModel) {
const trainData: ModelTrainRequest = { const trainData: ModelTrainRequest = {
modelId: selectedModel.id, modelId: selectedModel.id,
ratio, ratio,
@ -48,34 +44,10 @@ export default function TrainingSettings({
lr0, lr0,
lrf, lrf,
}; };
setIsSubmitting(true);
handleTrainingStart(trainData); handleTrainingStart(trainData);
} }
}; };
useEffect(() => {
if (isSubmitting) {
intervalRef.current = setInterval(() => {
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
}, 1000);
} else if (intervalRef.current) {
clearInterval(intervalRef.current);
intervalRef.current = null;
}
return () => {
if (intervalRef.current) {
clearInterval(intervalRef.current);
}
};
}, [isSubmitting, queryClient, projectId]);
useEffect(() => {
if (selectedModel?.isTrain) {
setIsSubmitting(false);
}
}, [selectedModel]);
return ( return (
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}> <fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
<legend className="-ml-1 px-1 text-sm font-medium"> </legend> <legend className="-ml-1 px-1 text-sm font-medium"> </legend>
@ -158,12 +130,21 @@ export default function TrainingSettings({
variant="outlinePrimary" variant="outlinePrimary"
size="lg" size="lg"
onClick={handleSubmit} onClick={handleSubmit}
disabled={!selectedModel || isSubmitting} disabled={!selectedModel || isPolling}
> >
{isSubmitting ? '기다리는 중...' : '학습 시작'} {isPolling ? '대기 중...' : '학습 시작'}
</Button> </Button>
</> </>
)} )}
{selectedModel?.isTrain && (
<Button
variant="secondary"
size="lg"
onClick={handleTrainingStop}
>
</Button>
)}
</fieldset> </fieldset>
); );
} }

View File

@ -1,24 +1,57 @@
import useTrainModelQuery from '@/queries/models/useTrainModelQuery'; import { useState, useEffect } from 'react';
import TrainingSettings from './TrainingSettings'; import TrainingSettings from './TrainingSettings';
import TrainingGraph from './TrainingGraph'; import TrainingGraph from './TrainingGraph';
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
import { ModelTrainRequest, ModelResponse } from '@/types'; import { ModelTrainRequest, ModelResponse } from '@/types';
import { useState } from 'react'; import { useQueryClient } from '@tanstack/react-query';
interface TrainingTabProps { interface TrainingTabProps {
projectId: number | null; projectId: number | null;
} }
export default function TrainingTab({ projectId }: TrainingTabProps) { export default function TrainingTab({ projectId }: TrainingTabProps) {
const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null; const numericProjectId = projectId !== null ? Number(projectId) : null;
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null); const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
const [isPolling, setIsPolling] = useState(false);
const queryClient = useQueryClient();
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number); const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number, {
onSuccess: () => {
setIsPolling(true);
},
onError: () => {
alert('학습 요청 실패');
setIsPolling(false);
},
});
const handleTrainingStart = (trainData: ModelTrainRequest) => { const handleTrainingStart = (trainData: ModelTrainRequest) => {
if (numericProjectId !== null) {
startTraining(trainData); startTraining(trainData);
}
}; };
const handleTrainingStop = () => {}; useEffect(() => {
if (!selectedModel || !numericProjectId || !isPolling) return;
const intervalId = setInterval(() => {
queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
}, 2000);
const timeoutId = setTimeout(() => {
clearInterval(intervalId);
setIsPolling(false);
}, 30000);
return () => {
clearInterval(intervalId);
clearTimeout(timeoutId);
};
}, [selectedModel, numericProjectId, queryClient, isPolling]);
const handleTrainingStop = () => {
setIsPolling(false);
};
return ( return (
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2"> <div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
@ -28,9 +61,9 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
setSelectedModel={setSelectedModel} setSelectedModel={setSelectedModel}
handleTrainingStart={handleTrainingStart} handleTrainingStart={handleTrainingStart}
handleTrainingStop={handleTrainingStop} handleTrainingStop={handleTrainingStop}
isPolling={isPolling}
className="h-full" className="h-full"
/> />
<TrainingGraph <TrainingGraph
projectId={numericProjectId} projectId={numericProjectId}
selectedModel={selectedModel} selectedModel={selectedModel}

View File

@ -1,9 +1,24 @@
import { useMutation } from '@tanstack/react-query'; import { useMutation } from '@tanstack/react-query';
import { trainModel } from '@/api/modelApi'; import { trainModel } from '@/api/modelApi';
import { ModelTrainRequest } from '@/types'; import { ModelTrainRequest } from '@/types';
import { QueryClient } from '@tanstack/react-query';
export default function useTrainModelQuery(projectId: number) { const queryClient = new QueryClient();
interface UseTrainModelOptions {
onSuccess?: () => void;
onError?: (error: unknown) => void;
}
export default function useTrainModelQuery(projectId: number, options?: UseTrainModelOptions) {
return useMutation({ return useMutation({
mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData), mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
options?.onSuccess?.();
},
onError: (error) => {
options?.onError?.(error);
},
}); });
} }

View File

@ -6,7 +6,7 @@ export default function usePollingTrainingModelReport(projectId: number, modelId
return useQuery<ReportResponse[]>({ return useQuery<ReportResponse[]>({
queryKey: ['modelReports', projectId, modelId], queryKey: ['modelReports', projectId, modelId],
queryFn: () => getTrainingModelReport(projectId, modelId), queryFn: () => getTrainingModelReport(projectId, modelId),
refetchInterval: 5000, refetchInterval: enabled ? 5000 : false,
enabled, enabled,
}); });
} }