Refactor: 모델 리팩토링
This commit is contained in:
parent
5859e45ab4
commit
f5e00d5b42
@ -1,6 +1,7 @@
|
|||||||
import { useEffect, useState } from 'react';
|
import { useEffect } from 'react';
|
||||||
import ModelLineChart from './ModelLineChart';
|
import ModelLineChart from './ModelLineChart';
|
||||||
import usePollingModelReportsQuery from '@/queries/reports/usePollingModelReportsQuery';
|
import usePollingTrainingModelReport from '@/queries/reports/usePollingModelReportsQuery';
|
||||||
|
import { useQueryClient } from '@tanstack/react-query';
|
||||||
import { ModelResponse } from '@/types';
|
import { ModelResponse } from '@/types';
|
||||||
|
|
||||||
interface TrainingGraphProps {
|
interface TrainingGraphProps {
|
||||||
@ -10,20 +11,22 @@ interface TrainingGraphProps {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
|
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
|
||||||
const [isPolling, setIsPolling] = useState(false);
|
const queryClient = useQueryClient();
|
||||||
const { data: trainingDataList } = usePollingModelReportsQuery(
|
|
||||||
|
const { data: trainingDataList } = usePollingTrainingModelReport(
|
||||||
projectId as number,
|
projectId as number,
|
||||||
selectedModel?.id as number,
|
selectedModel?.id as number,
|
||||||
isPolling
|
selectedModel?.isTrain || false
|
||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (selectedModel) {
|
if (!selectedModel || !selectedModel.isTrain) {
|
||||||
setIsPolling(true);
|
queryClient.resetQueries({
|
||||||
} else {
|
queryKey: [{ type: 'modelReports', projectId, modelId: selectedModel?.id }],
|
||||||
setIsPolling(false);
|
exact: true,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}, [selectedModel]);
|
}, [selectedModel, queryClient, projectId]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ModelLineChart
|
<ModelLineChart
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { useState, useEffect, useRef } from 'react';
|
import { useState, useEffect } from 'react';
|
||||||
import { Button } from '@/components/ui/button';
|
import { Button } from '@/components/ui/button';
|
||||||
import SelectWithLabel from './SelectWithLabel';
|
import SelectWithLabel from './SelectWithLabel';
|
||||||
import InputWithLabel from './InputWithLabel';
|
import InputWithLabel from './InputWithLabel';
|
||||||
@ -31,14 +31,16 @@ export default function TrainingSettings({
|
|||||||
const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO');
|
const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO');
|
||||||
const [lr0, setLr0] = useState<number>(0.01);
|
const [lr0, setLr0] = useState<number>(0.01);
|
||||||
const [lrf, setLrf] = useState<number>(0.001);
|
const [lrf, setLrf] = useState<number>(0.001);
|
||||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
const intervalRef = useRef<NodeJS.Timeout | null>(null);
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (selectedModel?.isTrain) {
|
||||||
|
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
|
||||||
|
}
|
||||||
|
}, [selectedModel?.isTrain, queryClient, projectId]);
|
||||||
|
|
||||||
const handleSubmit = () => {
|
const handleSubmit = () => {
|
||||||
if (selectedModel?.isTrain) {
|
if (selectedModel) {
|
||||||
handleTrainingStop();
|
|
||||||
} else if (selectedModel) {
|
|
||||||
const trainData: ModelTrainRequest = {
|
const trainData: ModelTrainRequest = {
|
||||||
modelId: selectedModel.id,
|
modelId: selectedModel.id,
|
||||||
ratio,
|
ratio,
|
||||||
@ -48,34 +50,10 @@ export default function TrainingSettings({
|
|||||||
lr0,
|
lr0,
|
||||||
lrf,
|
lrf,
|
||||||
};
|
};
|
||||||
setIsSubmitting(true);
|
|
||||||
handleTrainingStart(trainData);
|
handleTrainingStart(trainData);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (isSubmitting) {
|
|
||||||
intervalRef.current = setInterval(() => {
|
|
||||||
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
|
|
||||||
}, 1000);
|
|
||||||
} else if (intervalRef.current) {
|
|
||||||
clearInterval(intervalRef.current);
|
|
||||||
intervalRef.current = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return () => {
|
|
||||||
if (intervalRef.current) {
|
|
||||||
clearInterval(intervalRef.current);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}, [isSubmitting, queryClient, projectId]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (selectedModel?.isTrain) {
|
|
||||||
setIsSubmitting(false);
|
|
||||||
}
|
|
||||||
}, [selectedModel]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
|
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
|
||||||
<legend className="-ml-1 px-1 text-sm font-medium">모델 설정</legend>
|
<legend className="-ml-1 px-1 text-sm font-medium">모델 설정</legend>
|
||||||
@ -158,12 +136,21 @@ export default function TrainingSettings({
|
|||||||
variant="outlinePrimary"
|
variant="outlinePrimary"
|
||||||
size="lg"
|
size="lg"
|
||||||
onClick={handleSubmit}
|
onClick={handleSubmit}
|
||||||
disabled={!selectedModel || isSubmitting}
|
disabled={!selectedModel}
|
||||||
>
|
>
|
||||||
{isSubmitting ? '기다리는 중...' : '학습 시작'}
|
학습 시작
|
||||||
</Button>
|
</Button>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
{selectedModel?.isTrain && (
|
||||||
|
<Button
|
||||||
|
variant="secondary"
|
||||||
|
size="lg"
|
||||||
|
onClick={handleTrainingStop}
|
||||||
|
>
|
||||||
|
학습 중단
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
</fieldset>
|
</fieldset>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
|
import { useState } from 'react';
|
||||||
import TrainingSettings from './TrainingSettings';
|
import TrainingSettings from './TrainingSettings';
|
||||||
import TrainingGraph from './TrainingGraph';
|
import TrainingGraph from './TrainingGraph';
|
||||||
|
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
|
||||||
import { ModelTrainRequest, ModelResponse } from '@/types';
|
import { ModelTrainRequest, ModelResponse } from '@/types';
|
||||||
import { useState } from 'react';
|
|
||||||
|
|
||||||
interface TrainingTabProps {
|
interface TrainingTabProps {
|
||||||
projectId: number | null;
|
projectId: number | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function TrainingTab({ projectId }: TrainingTabProps) {
|
export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||||
const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null;
|
const numericProjectId = projectId !== null ? Number(projectId) : null;
|
||||||
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
|
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
|
||||||
|
|
||||||
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
|
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
|
||||||
@ -18,7 +18,9 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
|||||||
startTraining(trainData);
|
startTraining(trainData);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleTrainingStop = () => {};
|
const handleTrainingStop = () => {
|
||||||
|
// Todo: 학습 중단 로직
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
|
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
|
||||||
|
@ -1,9 +1,15 @@
|
|||||||
import { useMutation } from '@tanstack/react-query';
|
import { useMutation } from '@tanstack/react-query';
|
||||||
import { trainModel } from '@/api/modelApi';
|
import { trainModel } from '@/api/modelApi';
|
||||||
import { ModelTrainRequest } from '@/types';
|
import { ModelTrainRequest } from '@/types';
|
||||||
|
import { QueryClient } from '@tanstack/react-query';
|
||||||
|
|
||||||
|
const queryClient = new QueryClient();
|
||||||
|
|
||||||
export default function useTrainModelQuery(projectId: number) {
|
export default function useTrainModelQuery(projectId: number) {
|
||||||
return useMutation({
|
return useMutation({
|
||||||
mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData),
|
mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData),
|
||||||
|
onSuccess: () => {
|
||||||
|
queryClient.invalidateQueries({ queryKey: ['projectModels', projectId] });
|
||||||
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ export default function usePollingTrainingModelReport(projectId: number, modelId
|
|||||||
return useQuery<ReportResponse[]>({
|
return useQuery<ReportResponse[]>({
|
||||||
queryKey: ['modelReports', projectId, modelId],
|
queryKey: ['modelReports', projectId, modelId],
|
||||||
queryFn: () => getTrainingModelReport(projectId, modelId),
|
queryFn: () => getTrainingModelReport(projectId, modelId),
|
||||||
refetchInterval: 5000,
|
refetchInterval: enabled ? 5000 : false,
|
||||||
enabled,
|
enabled,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user