Merge branch 'fe/refactor/model-learn' into 'fe/develop'
Refator: 모델 학습 부분, 개선 필요 See merge request s11-s-project/S11P21S002!238
This commit is contained in:
commit
ab38cc0746
@ -1,6 +1,7 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useEffect } from 'react';
|
||||
import ModelLineChart from './ModelLineChart';
|
||||
import usePollingModelReportsQuery from '@/queries/reports/usePollingModelReportsQuery';
|
||||
import usePollingTrainingModelReport from '@/queries/reports/usePollingModelReportsQuery';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { ModelResponse } from '@/types';
|
||||
|
||||
interface TrainingGraphProps {
|
||||
@ -10,20 +11,22 @@ interface TrainingGraphProps {
|
||||
}
|
||||
|
||||
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
|
||||
const [isPolling, setIsPolling] = useState(false);
|
||||
const { data: trainingDataList } = usePollingModelReportsQuery(
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { data: trainingDataList } = usePollingTrainingModelReport(
|
||||
projectId as number,
|
||||
selectedModel?.id as number,
|
||||
isPolling
|
||||
selectedModel?.isTrain || false
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedModel) {
|
||||
setIsPolling(true);
|
||||
} else {
|
||||
setIsPolling(false);
|
||||
if (!selectedModel || !selectedModel.isTrain) {
|
||||
queryClient.resetQueries({
|
||||
queryKey: [{ type: 'modelReports', projectId, modelId: selectedModel?.id }],
|
||||
exact: true,
|
||||
});
|
||||
}
|
||||
}, [selectedModel]);
|
||||
}, [selectedModel, queryClient, projectId]);
|
||||
|
||||
return (
|
||||
<ModelLineChart
|
||||
|
@ -1,11 +1,10 @@
|
||||
import { useState, useEffect, useRef } from 'react';
|
||||
import { useState } from 'react';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import SelectWithLabel from './SelectWithLabel';
|
||||
import InputWithLabel from './InputWithLabel';
|
||||
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
|
||||
import { ModelTrainRequest, ModelResponse } from '@/types';
|
||||
import { cn } from '@/lib/utils';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
|
||||
interface TrainingSettingsProps {
|
||||
projectId: number | null;
|
||||
@ -13,6 +12,7 @@ interface TrainingSettingsProps {
|
||||
setSelectedModel: (model: ModelResponse | null) => void;
|
||||
handleTrainingStart: (trainData: ModelTrainRequest) => void;
|
||||
handleTrainingStop: () => void;
|
||||
isPolling: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@ export default function TrainingSettings({
|
||||
setSelectedModel,
|
||||
handleTrainingStart,
|
||||
handleTrainingStop,
|
||||
isPolling,
|
||||
className,
|
||||
}: TrainingSettingsProps) {
|
||||
const { data: models } = useProjectModelsQuery(projectId ?? 0);
|
||||
@ -31,14 +32,9 @@ export default function TrainingSettings({
|
||||
const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO');
|
||||
const [lr0, setLr0] = useState<number>(0.01);
|
||||
const [lrf, setLrf] = useState<number>(0.001);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const queryClient = useQueryClient();
|
||||
const intervalRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
const handleSubmit = () => {
|
||||
if (selectedModel?.isTrain) {
|
||||
handleTrainingStop();
|
||||
} else if (selectedModel) {
|
||||
if (selectedModel) {
|
||||
const trainData: ModelTrainRequest = {
|
||||
modelId: selectedModel.id,
|
||||
ratio,
|
||||
@ -48,34 +44,10 @@ export default function TrainingSettings({
|
||||
lr0,
|
||||
lrf,
|
||||
};
|
||||
setIsSubmitting(true);
|
||||
handleTrainingStart(trainData);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (isSubmitting) {
|
||||
intervalRef.current = setInterval(() => {
|
||||
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
|
||||
}, 1000);
|
||||
} else if (intervalRef.current) {
|
||||
clearInterval(intervalRef.current);
|
||||
intervalRef.current = null;
|
||||
}
|
||||
|
||||
return () => {
|
||||
if (intervalRef.current) {
|
||||
clearInterval(intervalRef.current);
|
||||
}
|
||||
};
|
||||
}, [isSubmitting, queryClient, projectId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedModel?.isTrain) {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}, [selectedModel]);
|
||||
|
||||
return (
|
||||
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
|
||||
<legend className="-ml-1 px-1 text-sm font-medium">모델 설정</legend>
|
||||
@ -158,12 +130,21 @@ export default function TrainingSettings({
|
||||
variant="outlinePrimary"
|
||||
size="lg"
|
||||
onClick={handleSubmit}
|
||||
disabled={!selectedModel || isSubmitting}
|
||||
disabled={!selectedModel || isPolling}
|
||||
>
|
||||
{isSubmitting ? '기다리는 중...' : '학습 시작'}
|
||||
{isPolling ? '대기 중...' : '학습 시작'}
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
{selectedModel?.isTrain && (
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="lg"
|
||||
onClick={handleTrainingStop}
|
||||
>
|
||||
학습 중단
|
||||
</Button>
|
||||
)}
|
||||
</fieldset>
|
||||
);
|
||||
}
|
||||
|
@ -1,24 +1,57 @@
|
||||
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
|
||||
import { useState, useEffect } from 'react';
|
||||
import TrainingSettings from './TrainingSettings';
|
||||
import TrainingGraph from './TrainingGraph';
|
||||
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
|
||||
import { ModelTrainRequest, ModelResponse } from '@/types';
|
||||
import { useState } from 'react';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
|
||||
interface TrainingTabProps {
|
||||
projectId: number | null;
|
||||
}
|
||||
|
||||
export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||
const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null;
|
||||
const numericProjectId = projectId !== null ? Number(projectId) : null;
|
||||
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
|
||||
const [isPolling, setIsPolling] = useState(false);
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
|
||||
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number, {
|
||||
onSuccess: () => {
|
||||
setIsPolling(true);
|
||||
},
|
||||
onError: () => {
|
||||
alert('학습 요청 실패');
|
||||
setIsPolling(false);
|
||||
},
|
||||
});
|
||||
|
||||
const handleTrainingStart = (trainData: ModelTrainRequest) => {
|
||||
startTraining(trainData);
|
||||
if (numericProjectId !== null) {
|
||||
startTraining(trainData);
|
||||
}
|
||||
};
|
||||
|
||||
const handleTrainingStop = () => {};
|
||||
useEffect(() => {
|
||||
if (!selectedModel || !numericProjectId || !isPolling) return;
|
||||
|
||||
const intervalId = setInterval(() => {
|
||||
queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
|
||||
}, 2000);
|
||||
|
||||
const timeoutId = setTimeout(() => {
|
||||
clearInterval(intervalId);
|
||||
setIsPolling(false);
|
||||
}, 30000);
|
||||
|
||||
return () => {
|
||||
clearInterval(intervalId);
|
||||
clearTimeout(timeoutId);
|
||||
};
|
||||
}, [selectedModel, numericProjectId, queryClient, isPolling]);
|
||||
|
||||
const handleTrainingStop = () => {
|
||||
setIsPolling(false);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
|
||||
@ -28,9 +61,9 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||
setSelectedModel={setSelectedModel}
|
||||
handleTrainingStart={handleTrainingStart}
|
||||
handleTrainingStop={handleTrainingStop}
|
||||
isPolling={isPolling}
|
||||
className="h-full"
|
||||
/>
|
||||
|
||||
<TrainingGraph
|
||||
projectId={numericProjectId}
|
||||
selectedModel={selectedModel}
|
||||
|
@ -1,9 +1,24 @@
|
||||
import { useMutation } from '@tanstack/react-query';
|
||||
import { trainModel } from '@/api/modelApi';
|
||||
import { ModelTrainRequest } from '@/types';
|
||||
import { QueryClient } from '@tanstack/react-query';
|
||||
|
||||
export default function useTrainModelQuery(projectId: number) {
|
||||
const queryClient = new QueryClient();
|
||||
|
||||
interface UseTrainModelOptions {
|
||||
onSuccess?: () => void;
|
||||
onError?: (error: unknown) => void;
|
||||
}
|
||||
|
||||
export default function useTrainModelQuery(projectId: number, options?: UseTrainModelOptions) {
|
||||
return useMutation({
|
||||
mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
|
||||
options?.onSuccess?.();
|
||||
},
|
||||
onError: (error) => {
|
||||
options?.onError?.(error);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ export default function usePollingTrainingModelReport(projectId: number, modelId
|
||||
return useQuery<ReportResponse[]>({
|
||||
queryKey: ['modelReports', projectId, modelId],
|
||||
queryFn: () => getTrainingModelReport(projectId, modelId),
|
||||
refetchInterval: 5000,
|
||||
refetchInterval: enabled ? 5000 : false,
|
||||
enabled,
|
||||
});
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user