Refactor: 모델 학습 작동하게 함

This commit is contained in:
정현조 2024-09-30 16:19:15 +09:00
parent 7381b67cb6
commit 9d0b3b0c7b
3 changed files with 41 additions and 18 deletions

View File

@ -48,9 +48,13 @@ export default function TrainingSettings({
} }
}; };
const isTraining = selectedModel?.isTrain;
const isWaiting = isPolling && !isTraining;
return ( return (
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}> <fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
<legend className="-ml-1 px-1 text-sm font-medium"> </legend> <legend className="-ml-1 px-1 text-sm font-medium"> </legend>
<div className="grid gap-3"> <div className="grid gap-3">
<SelectWithLabel <SelectWithLabel
label="모델 선택" label="모델 선택"
@ -69,7 +73,7 @@ export default function TrainingSettings({
}} }}
/> />
</div> </div>
{!selectedModel?.isTrain && ( {!isPolling && !isTraining && (
<> <>
<div className="grid grid-cols-2 gap-4"> <div className="grid grid-cols-2 gap-4">
<InputWithLabel <InputWithLabel
@ -130,19 +134,30 @@ export default function TrainingSettings({
variant="outlinePrimary" variant="outlinePrimary"
size="lg" size="lg"
onClick={handleSubmit} onClick={handleSubmit}
disabled={!selectedModel || isPolling} disabled={!selectedModel}
> >
{isPolling ? '대기 중...' : '학습 시작'}
</Button> </Button>
</> </>
)} )}
{selectedModel?.isTrain && (
{isWaiting && (
<Button <Button
variant="secondary" variant="secondary"
size="lg" size="lg"
onClick={handleTrainingStop} onClick={handleTrainingStop}
> >
</Button>
)}
{isTraining && (
<Button
variant="secondary"
size="lg"
onClick={handleTrainingStop}
>
</Button> </Button>
)} )}
</fieldset> </fieldset>

View File

@ -16,9 +16,6 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number, { const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number, {
onSuccess: () => {
setIsPolling(true);
},
onError: () => { onError: () => {
alert('학습 요청 실패'); alert('학습 요청 실패');
setIsPolling(false); setIsPolling(false);
@ -28,24 +25,34 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
const handleTrainingStart = (trainData: ModelTrainRequest) => { const handleTrainingStart = (trainData: ModelTrainRequest) => {
if (numericProjectId !== null) { if (numericProjectId !== null) {
startTraining(trainData); startTraining(trainData);
setIsPolling(true);
} }
}; };
useEffect(() => { useEffect(() => {
if (!selectedModel || !numericProjectId || !isPolling) return; if (!selectedModel || !numericProjectId || !isPolling) return;
const intervalId = setInterval(() => { const intervalId = setInterval(async () => {
queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] }); await queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
}, 2000);
const timeoutId = setTimeout(() => { const models = await queryClient.getQueryData<ModelResponse[]>(['projectModels', numericProjectId]);
clearInterval(intervalId);
setIsPolling(false); const updatedModel = models?.find((model) => model.id === selectedModel.id);
}, 30000);
if (updatedModel) {
setSelectedModel(updatedModel);
if (updatedModel.isTrain) {
setIsPolling(true);
} else if (!updatedModel.isTrain && selectedModel.isTrain) {
setIsPolling(false);
setSelectedModel(null);
}
}
}, 2000);
return () => { return () => {
clearInterval(intervalId); clearInterval(intervalId);
clearTimeout(timeoutId);
}; };
}, [selectedModel, numericProjectId, queryClient, isPolling]); }, [selectedModel, numericProjectId, queryClient, isPolling]);
@ -56,7 +63,7 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
return ( return (
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2"> <div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
<TrainingSettings <TrainingSettings
key={selectedModel?.isTrain ? 'training' : 'settings'} key={`${selectedModel?.isTrain ? 'training' : 'settings'}-${isPolling}`}
projectId={numericProjectId} projectId={numericProjectId}
selectedModel={selectedModel} selectedModel={selectedModel}
setSelectedModel={setSelectedModel} setSelectedModel={setSelectedModel}
@ -66,7 +73,7 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
className="h-full" className="h-full"
/> />
<TrainingGraph <TrainingGraph
key={selectedModel?.isTrain ? 'training' : 'graph'} key={`${selectedModel?.isTrain ? 'training' : 'graph'}-${isPolling}`}
projectId={numericProjectId} projectId={numericProjectId}
selectedModel={selectedModel} selectedModel={selectedModel}
className="h-full" className="h-full"

View File

@ -5,5 +5,6 @@ export default function useProjectModelsQuery(projectId: number) {
return useSuspenseQuery({ return useSuspenseQuery({
queryKey: ['projectModels', projectId], queryKey: ['projectModels', projectId],
queryFn: () => getProjectModels(projectId), queryFn: () => getProjectModels(projectId),
refetchOnWindowFocus: false,
}); });
} }