Refactor: 모델 리팩토링

This commit is contained in:
정현조 2024-09-29 23:50:42 +09:00
parent 5859e45ab4
commit f5e00d5b42
5 changed files with 45 additions and 47 deletions

View File

@ -1,6 +1,7 @@
import { useEffect, useState } from 'react';
import { useEffect } from 'react';
import ModelLineChart from './ModelLineChart';
import usePollingModelReportsQuery from '@/queries/reports/usePollingModelReportsQuery';
import usePollingTrainingModelReport from '@/queries/reports/usePollingModelReportsQuery';
import { useQueryClient } from '@tanstack/react-query';
import { ModelResponse } from '@/types';
interface TrainingGraphProps {
@ -10,20 +11,22 @@ interface TrainingGraphProps {
}
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
const [isPolling, setIsPolling] = useState(false);
const { data: trainingDataList } = usePollingModelReportsQuery(
const queryClient = useQueryClient();
const { data: trainingDataList } = usePollingTrainingModelReport(
projectId as number,
selectedModel?.id as number,
isPolling
selectedModel?.isTrain || false
);
useEffect(() => {
if (selectedModel) {
setIsPolling(true);
} else {
setIsPolling(false);
if (!selectedModel || !selectedModel.isTrain) {
queryClient.resetQueries({
queryKey: [{ type: 'modelReports', projectId, modelId: selectedModel?.id }],
exact: true,
});
}
}, [selectedModel]);
}, [selectedModel, queryClient, projectId]);
return (
<ModelLineChart

View File

@ -1,4 +1,4 @@
import { useState, useEffect, useRef } from 'react';
import { useState, useEffect } from 'react';
import { Button } from '@/components/ui/button';
import SelectWithLabel from './SelectWithLabel';
import InputWithLabel from './InputWithLabel';
@ -31,14 +31,16 @@ export default function TrainingSettings({
const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO');
const [lr0, setLr0] = useState<number>(0.01);
const [lrf, setLrf] = useState<number>(0.001);
const [isSubmitting, setIsSubmitting] = useState(false);
const queryClient = useQueryClient();
const intervalRef = useRef<NodeJS.Timeout | null>(null);
useEffect(() => {
if (selectedModel?.isTrain) {
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
}
}, [selectedModel?.isTrain, queryClient, projectId]);
const handleSubmit = () => {
if (selectedModel?.isTrain) {
handleTrainingStop();
} else if (selectedModel) {
if (selectedModel) {
const trainData: ModelTrainRequest = {
modelId: selectedModel.id,
ratio,
@ -48,34 +50,10 @@ export default function TrainingSettings({
lr0,
lrf,
};
setIsSubmitting(true);
handleTrainingStart(trainData);
}
};
useEffect(() => {
if (isSubmitting) {
intervalRef.current = setInterval(() => {
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
}, 1000);
} else if (intervalRef.current) {
clearInterval(intervalRef.current);
intervalRef.current = null;
}
return () => {
if (intervalRef.current) {
clearInterval(intervalRef.current);
}
};
}, [isSubmitting, queryClient, projectId]);
useEffect(() => {
if (selectedModel?.isTrain) {
setIsSubmitting(false);
}
}, [selectedModel]);
return (
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
<legend className="-ml-1 px-1 text-sm font-medium"> </legend>
@ -158,12 +136,21 @@ export default function TrainingSettings({
variant="outlinePrimary"
size="lg"
onClick={handleSubmit}
disabled={!selectedModel || isSubmitting}
disabled={!selectedModel}
>
{isSubmitting ? '기다리는 중...' : '학습 시작'}
</Button>
</>
)}
{selectedModel?.isTrain && (
<Button
variant="secondary"
size="lg"
onClick={handleTrainingStop}
>
</Button>
)}
</fieldset>
);
}

View File

@ -1,15 +1,15 @@
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
import { useState } from 'react';
import TrainingSettings from './TrainingSettings';
import TrainingGraph from './TrainingGraph';
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
import { ModelTrainRequest, ModelResponse } from '@/types';
import { useState } from 'react';
interface TrainingTabProps {
projectId: number | null;
}
export default function TrainingTab({ projectId }: TrainingTabProps) {
const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null;
const numericProjectId = projectId !== null ? Number(projectId) : null;
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
@ -18,7 +18,9 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
startTraining(trainData);
};
const handleTrainingStop = () => {};
const handleTrainingStop = () => {
// Todo: 학습 중단 로직
};
return (
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">

View File

@ -1,9 +1,15 @@
import { useMutation } from '@tanstack/react-query';
import { trainModel } from '@/api/modelApi';
import { ModelTrainRequest } from '@/types';
import { QueryClient } from '@tanstack/react-query';
const queryClient = new QueryClient();
export default function useTrainModelQuery(projectId: number) {
return useMutation({
mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
},
});
}

View File

@ -6,7 +6,7 @@ export default function usePollingTrainingModelReport(projectId: number, modelId
return useQuery<ReportResponse[]>({
queryKey: ['modelReports', projectId, modelId],
queryFn: () => getTrainingModelReport(projectId, modelId),
refetchInterval: 5000,
refetchInterval: enabled ? 5000 : false,
enabled,
});
}