Refactor: 모델 학습 리팩토링 중, 테스트 필요

This commit is contained in:
정현조 2024-09-30 08:47:37 +09:00
parent f5e00d5b42
commit 9ba7e677bc
3 changed files with 48 additions and 16 deletions

View File

@ -1,11 +1,10 @@
import { useState, useEffect } from 'react';
import { useState } from 'react';
import { Button } from '@/components/ui/button';
import SelectWithLabel from './SelectWithLabel';
import InputWithLabel from './InputWithLabel';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import { ModelTrainRequest, ModelResponse } from '@/types';
import { cn } from '@/lib/utils';
import { useQueryClient } from '@tanstack/react-query';
interface TrainingSettingsProps {
projectId: number | null;
@ -13,6 +12,7 @@ interface TrainingSettingsProps {
setSelectedModel: (model: ModelResponse | null) => void;
handleTrainingStart: (trainData: ModelTrainRequest) => void;
handleTrainingStop: () => void;
isPolling: boolean;
className?: string;
}
@ -22,6 +22,7 @@ export default function TrainingSettings({
setSelectedModel,
handleTrainingStart,
handleTrainingStop,
isPolling,
className,
}: TrainingSettingsProps) {
const { data: models } = useProjectModelsQuery(projectId ?? 0);
@ -31,13 +32,6 @@ export default function TrainingSettings({
const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO');
const [lr0, setLr0] = useState<number>(0.01);
const [lrf, setLrf] = useState<number>(0.001);
const queryClient = useQueryClient();
useEffect(() => {
if (selectedModel?.isTrain) {
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
}
}, [selectedModel?.isTrain, queryClient, projectId]);
const handleSubmit = () => {
if (selectedModel) {
@ -136,9 +130,9 @@ export default function TrainingSettings({
variant="outlinePrimary"
size="lg"
onClick={handleSubmit}
disabled={!selectedModel}
disabled={!selectedModel || isPolling}
>
{isPolling ? '대기 중...' : '학습 시작'}
</Button>
</>
)}

View File

@ -1,8 +1,9 @@
import { useState } from 'react';
import { useState, useEffect } from 'react';
import TrainingSettings from './TrainingSettings';
import TrainingGraph from './TrainingGraph';
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
import { ModelTrainRequest, ModelResponse } from '@/types';
import { useQueryClient } from '@tanstack/react-query';
interface TrainingTabProps {
projectId: number | null;
@ -11,15 +12,43 @@ interface TrainingTabProps {
export default function TrainingTab({ projectId }: TrainingTabProps) {
const numericProjectId = projectId !== null ? Number(projectId) : null;
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
const [isPolling, setIsPolling] = useState(false);
const queryClient = useQueryClient();
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number, {
onSuccess: () => {
setIsPolling(true);
},
onError: () => {
alert('학습 요청 실패');
setIsPolling(false);
},
});
const handleTrainingStart = (trainData: ModelTrainRequest) => {
startTraining(trainData);
};
useEffect(() => {
if (!selectedModel || !numericProjectId || !isPolling) return;
const intervalId = setInterval(() => {
queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
}, 2000);
const timeoutId = setTimeout(() => {
clearInterval(intervalId);
setIsPolling(false);
}, 30000);
return () => {
clearInterval(intervalId);
clearTimeout(timeoutId);
};
}, [selectedModel, numericProjectId, queryClient, isPolling]);
const handleTrainingStop = () => {
// Todo: 학습 중단 로직
setIsPolling(false);
};
return (
@ -30,9 +59,9 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
setSelectedModel={setSelectedModel}
handleTrainingStart={handleTrainingStart}
handleTrainingStop={handleTrainingStop}
isPolling={isPolling}
className="h-full"
/>
<TrainingGraph
projectId={numericProjectId}
selectedModel={selectedModel}

View File

@ -5,11 +5,20 @@ import { QueryClient } from '@tanstack/react-query';
const queryClient = new QueryClient();
export default function useTrainModelQuery(projectId: number) {
interface UseTrainModelOptions {
onSuccess?: () => void;
onError?: (error: unknown) => void;
}
export default function useTrainModelQuery(projectId: number, options?: UseTrainModelOptions) {
return useMutation({
mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
options?.onSuccess?.();
},
onError: (error) => {
options?.onError?.(error);
},
});
}