Refactor: 모델 학습 작동하게 함
This commit is contained in:
parent
7381b67cb6
commit
9d0b3b0c7b
@ -48,9 +48,13 @@ export default function TrainingSettings({
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const isTraining = selectedModel?.isTrain;
|
||||||
|
const isWaiting = isPolling && !isTraining;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
|
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
|
||||||
<legend className="-ml-1 px-1 text-sm font-medium">모델 설정</legend>
|
<legend className="-ml-1 px-1 text-sm font-medium">모델 설정</legend>
|
||||||
|
|
||||||
<div className="grid gap-3">
|
<div className="grid gap-3">
|
||||||
<SelectWithLabel
|
<SelectWithLabel
|
||||||
label="모델 선택"
|
label="모델 선택"
|
||||||
@ -69,7 +73,7 @@ export default function TrainingSettings({
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
{!selectedModel?.isTrain && (
|
{!isPolling && !isTraining && (
|
||||||
<>
|
<>
|
||||||
<div className="grid grid-cols-2 gap-4">
|
<div className="grid grid-cols-2 gap-4">
|
||||||
<InputWithLabel
|
<InputWithLabel
|
||||||
@ -130,19 +134,30 @@ export default function TrainingSettings({
|
|||||||
variant="outlinePrimary"
|
variant="outlinePrimary"
|
||||||
size="lg"
|
size="lg"
|
||||||
onClick={handleSubmit}
|
onClick={handleSubmit}
|
||||||
disabled={!selectedModel || isPolling}
|
disabled={!selectedModel}
|
||||||
>
|
>
|
||||||
{isPolling ? '대기 중...' : '학습 시작'}
|
학습 시작
|
||||||
</Button>
|
</Button>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
{selectedModel?.isTrain && (
|
|
||||||
|
{isWaiting && (
|
||||||
<Button
|
<Button
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
size="lg"
|
size="lg"
|
||||||
onClick={handleTrainingStop}
|
onClick={handleTrainingStop}
|
||||||
>
|
>
|
||||||
학습 중단
|
대기 중
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{isTraining && (
|
||||||
|
<Button
|
||||||
|
variant="secondary"
|
||||||
|
size="lg"
|
||||||
|
onClick={handleTrainingStop}
|
||||||
|
>
|
||||||
|
학습 중
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
</fieldset>
|
</fieldset>
|
||||||
|
@ -16,9 +16,6 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
|||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
|
|
||||||
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number, {
|
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number, {
|
||||||
onSuccess: () => {
|
|
||||||
setIsPolling(true);
|
|
||||||
},
|
|
||||||
onError: () => {
|
onError: () => {
|
||||||
alert('학습 요청 실패');
|
alert('학습 요청 실패');
|
||||||
setIsPolling(false);
|
setIsPolling(false);
|
||||||
@ -28,24 +25,34 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
|||||||
const handleTrainingStart = (trainData: ModelTrainRequest) => {
|
const handleTrainingStart = (trainData: ModelTrainRequest) => {
|
||||||
if (numericProjectId !== null) {
|
if (numericProjectId !== null) {
|
||||||
startTraining(trainData);
|
startTraining(trainData);
|
||||||
|
setIsPolling(true);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!selectedModel || !numericProjectId || !isPolling) return;
|
if (!selectedModel || !numericProjectId || !isPolling) return;
|
||||||
|
|
||||||
const intervalId = setInterval(() => {
|
const intervalId = setInterval(async () => {
|
||||||
queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
|
await queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
|
||||||
}, 2000);
|
|
||||||
|
|
||||||
const timeoutId = setTimeout(() => {
|
const models = await queryClient.getQueryData<ModelResponse[]>(['projectModels', numericProjectId]);
|
||||||
clearInterval(intervalId);
|
|
||||||
setIsPolling(false);
|
const updatedModel = models?.find((model) => model.id === selectedModel.id);
|
||||||
}, 30000);
|
|
||||||
|
if (updatedModel) {
|
||||||
|
setSelectedModel(updatedModel);
|
||||||
|
|
||||||
|
if (updatedModel.isTrain) {
|
||||||
|
setIsPolling(true);
|
||||||
|
} else if (!updatedModel.isTrain && selectedModel.isTrain) {
|
||||||
|
setIsPolling(false);
|
||||||
|
setSelectedModel(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, 2000);
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
clearInterval(intervalId);
|
clearInterval(intervalId);
|
||||||
clearTimeout(timeoutId);
|
|
||||||
};
|
};
|
||||||
}, [selectedModel, numericProjectId, queryClient, isPolling]);
|
}, [selectedModel, numericProjectId, queryClient, isPolling]);
|
||||||
|
|
||||||
@ -56,7 +63,7 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
|||||||
return (
|
return (
|
||||||
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
|
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
|
||||||
<TrainingSettings
|
<TrainingSettings
|
||||||
key={selectedModel?.isTrain ? 'training' : 'settings'}
|
key={`${selectedModel?.isTrain ? 'training' : 'settings'}-${isPolling}`}
|
||||||
projectId={numericProjectId}
|
projectId={numericProjectId}
|
||||||
selectedModel={selectedModel}
|
selectedModel={selectedModel}
|
||||||
setSelectedModel={setSelectedModel}
|
setSelectedModel={setSelectedModel}
|
||||||
@ -66,7 +73,7 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
|||||||
className="h-full"
|
className="h-full"
|
||||||
/>
|
/>
|
||||||
<TrainingGraph
|
<TrainingGraph
|
||||||
key={selectedModel?.isTrain ? 'training' : 'graph'}
|
key={`${selectedModel?.isTrain ? 'training' : 'graph'}-${isPolling}`}
|
||||||
projectId={numericProjectId}
|
projectId={numericProjectId}
|
||||||
selectedModel={selectedModel}
|
selectedModel={selectedModel}
|
||||||
className="h-full"
|
className="h-full"
|
||||||
|
@ -5,5 +5,6 @@ export default function useProjectModelsQuery(projectId: number) {
|
|||||||
return useSuspenseQuery({
|
return useSuspenseQuery({
|
||||||
queryKey: ['projectModels', projectId],
|
queryKey: ['projectModels', projectId],
|
||||||
queryFn: () => getProjectModels(projectId),
|
queryFn: () => getProjectModels(projectId),
|
||||||
|
refetchOnWindowFocus: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user