Refactor: 모델 학습 리팩토링 중, 테스트 필요
This commit is contained in:
parent
f5e00d5b42
commit
9ba7e677bc
@ -1,11 +1,10 @@
|
|||||||
import { useState, useEffect } from 'react';
|
import { useState } from 'react';
|
||||||
import { Button } from '@/components/ui/button';
|
import { Button } from '@/components/ui/button';
|
||||||
import SelectWithLabel from './SelectWithLabel';
|
import SelectWithLabel from './SelectWithLabel';
|
||||||
import InputWithLabel from './InputWithLabel';
|
import InputWithLabel from './InputWithLabel';
|
||||||
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
|
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
|
||||||
import { ModelTrainRequest, ModelResponse } from '@/types';
|
import { ModelTrainRequest, ModelResponse } from '@/types';
|
||||||
import { cn } from '@/lib/utils';
|
import { cn } from '@/lib/utils';
|
||||||
import { useQueryClient } from '@tanstack/react-query';
|
|
||||||
|
|
||||||
interface TrainingSettingsProps {
|
interface TrainingSettingsProps {
|
||||||
projectId: number | null;
|
projectId: number | null;
|
||||||
@ -13,6 +12,7 @@ interface TrainingSettingsProps {
|
|||||||
setSelectedModel: (model: ModelResponse | null) => void;
|
setSelectedModel: (model: ModelResponse | null) => void;
|
||||||
handleTrainingStart: (trainData: ModelTrainRequest) => void;
|
handleTrainingStart: (trainData: ModelTrainRequest) => void;
|
||||||
handleTrainingStop: () => void;
|
handleTrainingStop: () => void;
|
||||||
|
isPolling: boolean;
|
||||||
className?: string;
|
className?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -22,6 +22,7 @@ export default function TrainingSettings({
|
|||||||
setSelectedModel,
|
setSelectedModel,
|
||||||
handleTrainingStart,
|
handleTrainingStart,
|
||||||
handleTrainingStop,
|
handleTrainingStop,
|
||||||
|
isPolling,
|
||||||
className,
|
className,
|
||||||
}: TrainingSettingsProps) {
|
}: TrainingSettingsProps) {
|
||||||
const { data: models } = useProjectModelsQuery(projectId ?? 0);
|
const { data: models } = useProjectModelsQuery(projectId ?? 0);
|
||||||
@ -31,13 +32,6 @@ export default function TrainingSettings({
|
|||||||
const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO');
|
const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO');
|
||||||
const [lr0, setLr0] = useState<number>(0.01);
|
const [lr0, setLr0] = useState<number>(0.01);
|
||||||
const [lrf, setLrf] = useState<number>(0.001);
|
const [lrf, setLrf] = useState<number>(0.001);
|
||||||
const queryClient = useQueryClient();
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (selectedModel?.isTrain) {
|
|
||||||
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
|
|
||||||
}
|
|
||||||
}, [selectedModel?.isTrain, queryClient, projectId]);
|
|
||||||
|
|
||||||
const handleSubmit = () => {
|
const handleSubmit = () => {
|
||||||
if (selectedModel) {
|
if (selectedModel) {
|
||||||
@ -136,9 +130,9 @@ export default function TrainingSettings({
|
|||||||
variant="outlinePrimary"
|
variant="outlinePrimary"
|
||||||
size="lg"
|
size="lg"
|
||||||
onClick={handleSubmit}
|
onClick={handleSubmit}
|
||||||
disabled={!selectedModel}
|
disabled={!selectedModel || isPolling}
|
||||||
>
|
>
|
||||||
학습 시작
|
{isPolling ? '대기 중...' : '학습 시작'}
|
||||||
</Button>
|
</Button>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import { useState } from 'react';
|
import { useState, useEffect } from 'react';
|
||||||
import TrainingSettings from './TrainingSettings';
|
import TrainingSettings from './TrainingSettings';
|
||||||
import TrainingGraph from './TrainingGraph';
|
import TrainingGraph from './TrainingGraph';
|
||||||
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
|
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
|
||||||
import { ModelTrainRequest, ModelResponse } from '@/types';
|
import { ModelTrainRequest, ModelResponse } from '@/types';
|
||||||
|
import { useQueryClient } from '@tanstack/react-query';
|
||||||
|
|
||||||
interface TrainingTabProps {
|
interface TrainingTabProps {
|
||||||
projectId: number | null;
|
projectId: number | null;
|
||||||
@ -11,15 +12,43 @@ interface TrainingTabProps {
|
|||||||
export default function TrainingTab({ projectId }: TrainingTabProps) {
|
export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||||
const numericProjectId = projectId !== null ? Number(projectId) : null;
|
const numericProjectId = projectId !== null ? Number(projectId) : null;
|
||||||
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
|
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
|
||||||
|
const [isPolling, setIsPolling] = useState(false);
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
|
||||||
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
|
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number, {
|
||||||
|
onSuccess: () => {
|
||||||
|
setIsPolling(true);
|
||||||
|
},
|
||||||
|
onError: () => {
|
||||||
|
alert('학습 요청 실패');
|
||||||
|
setIsPolling(false);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
const handleTrainingStart = (trainData: ModelTrainRequest) => {
|
const handleTrainingStart = (trainData: ModelTrainRequest) => {
|
||||||
startTraining(trainData);
|
startTraining(trainData);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!selectedModel || !numericProjectId || !isPolling) return;
|
||||||
|
|
||||||
|
const intervalId = setInterval(() => {
|
||||||
|
queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
|
||||||
|
}, 2000);
|
||||||
|
|
||||||
|
const timeoutId = setTimeout(() => {
|
||||||
|
clearInterval(intervalId);
|
||||||
|
setIsPolling(false);
|
||||||
|
}, 30000);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
clearInterval(intervalId);
|
||||||
|
clearTimeout(timeoutId);
|
||||||
|
};
|
||||||
|
}, [selectedModel, numericProjectId, queryClient, isPolling]);
|
||||||
|
|
||||||
const handleTrainingStop = () => {
|
const handleTrainingStop = () => {
|
||||||
// Todo: 학습 중단 로직
|
setIsPolling(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -30,9 +59,9 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
|||||||
setSelectedModel={setSelectedModel}
|
setSelectedModel={setSelectedModel}
|
||||||
handleTrainingStart={handleTrainingStart}
|
handleTrainingStart={handleTrainingStart}
|
||||||
handleTrainingStop={handleTrainingStop}
|
handleTrainingStop={handleTrainingStop}
|
||||||
|
isPolling={isPolling}
|
||||||
className="h-full"
|
className="h-full"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<TrainingGraph
|
<TrainingGraph
|
||||||
projectId={numericProjectId}
|
projectId={numericProjectId}
|
||||||
selectedModel={selectedModel}
|
selectedModel={selectedModel}
|
||||||
|
@ -5,11 +5,20 @@ import { QueryClient } from '@tanstack/react-query';
|
|||||||
|
|
||||||
const queryClient = new QueryClient();
|
const queryClient = new QueryClient();
|
||||||
|
|
||||||
export default function useTrainModelQuery(projectId: number) {
|
interface UseTrainModelOptions {
|
||||||
|
onSuccess?: () => void;
|
||||||
|
onError?: (error: unknown) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function useTrainModelQuery(projectId: number, options?: UseTrainModelOptions) {
|
||||||
return useMutation({
|
return useMutation({
|
||||||
mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData),
|
mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData),
|
||||||
onSuccess: () => {
|
onSuccess: () => {
|
||||||
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
|
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
|
||||||
|
options?.onSuccess?.();
|
||||||
|
},
|
||||||
|
onError: (error) => {
|
||||||
|
options?.onError?.(error);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user