Refactor: 학습 모델 리팩토링

This commit is contained in:
정현조 2024-10-02 16:21:57 +09:00
parent 5081f342d1
commit eb24f919b1
6 changed files with 199 additions and 118 deletions

View File

@ -1,44 +1,29 @@
import { useEffect } from 'react';
import ModelLineChart from './ModelLineChart';
import usePollingTrainingModelReport from '@/hooks/usePollingTrainingModelReport';
import { useQueryClient } from '@tanstack/react-query';
import { ModelResponse } from '@/types';
interface TrainingGraphProps {
projectId: number | null;
selectedModel: ModelResponse | null;
isTraining: boolean;
onTrainingEnd: () => void;
className?: string;
}
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
const queryClient = useQueryClient();
const isTrainingEnabled = Boolean(selectedModel?.isTrain);
const handleTrainingEnd = () => {
queryClient.resetQueries({
queryKey: ['modelReports', projectId, selectedModel?.id],
exact: true,
});
alert('학습이 완료되었습니다.');
};
const { data: trainingDataList } = usePollingTrainingModelReport({
export default function TrainingGraph({
projectId,
selectedModel,
isTraining,
onTrainingEnd,
className,
}: TrainingGraphProps) {
const { reportData: trainingDataList } = usePollingTrainingModelReport({
projectId: projectId as number,
modelId: selectedModel?.id as number,
enabled: isTrainingEnabled,
onTrainingEnd: handleTrainingEnd,
enabled: isTraining,
onTrainingEnd,
});
useEffect(() => {
if (!selectedModel || !selectedModel.isTrain) {
queryClient.resetQueries({
queryKey: ['modelReports', projectId, selectedModel?.id],
exact: true,
});
}
}, [selectedModel, queryClient, projectId]);
return (
<ModelLineChart
data={trainingDataList || []}

View File

@ -12,7 +12,8 @@ interface TrainingSettingsProps {
setSelectedModel: (model: ModelResponse | null) => void;
handleTrainingStart: (trainData: ModelTrainRequest) => void;
handleTrainingStop: () => void;
isPolling: boolean;
isWaiting: boolean;
isTraining: boolean;
className?: string;
}
@ -22,7 +23,8 @@ export default function TrainingSettings({
setSelectedModel,
handleTrainingStart,
handleTrainingStop,
isPolling,
isWaiting,
isTraining,
className,
}: TrainingSettingsProps) {
const { data: models } = useProjectModelsQuery(projectId ?? 0);
@ -48,9 +50,6 @@ export default function TrainingSettings({
}
};
const isTraining = selectedModel?.isTrain;
const isWaiting = isPolling && !isTraining;
return (
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
<legend className="-ml-1 px-1 text-sm font-medium"> </legend>
@ -73,7 +72,8 @@ export default function TrainingSettings({
}}
/>
</div>
{!isPolling && !isTraining && (
{!isWaiting && !isTraining && (
<>
<div className="grid grid-cols-2 gap-4">
<InputWithLabel

View File

@ -1,10 +1,11 @@
import { useState, useEffect } from 'react';
import { useState, useEffect, useRef } from 'react';
import TrainingSettings from './TrainingSettings';
import TrainingGraph from './TrainingGraph';
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
import usePollingTrainingModelReport from '@/hooks/usePollingTrainingModelReport';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import { ModelTrainRequest, ModelResponse } from '@/types';
import { useQueryClient } from '@tanstack/react-query';
import Swal from 'sweetalert2';
interface TrainingTabProps {
projectId: number | null;
@ -13,63 +14,103 @@ interface TrainingTabProps {
export default function TrainingTab({ projectId }: TrainingTabProps) {
const numericProjectId = projectId !== null ? Number(projectId) : null;
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
const [isPolling, setIsPolling] = useState(false);
const [isWaiting, setIsWaiting] = useState<{ [modelId: number]: boolean }>({});
const [isTraining, setIsTraining] = useState<{ [modelId: number]: boolean }>({});
const queryClient = useQueryClient();
const prevModelRef = useRef<ModelResponse | null>(null);
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
const handleTrainingStart = (trainData: ModelTrainRequest) => {
if (numericProjectId !== null) {
startTraining(trainData);
setIsPolling(true);
}
};
const handleTrainingEnd = () => {
setIsPolling(false);
setSelectedModel((prevModel) => (prevModel ? { ...prevModel, isTrain: false } : null));
};
const { data: models } = useProjectModelsQuery(numericProjectId ?? 0);
useEffect(() => {
if (!selectedModel || !numericProjectId || !isPolling) return;
if (models) {
const trainingModels = models.filter((model) => model.isTrain);
const newIsTraining = trainingModels.reduce(
(acc, model) => {
acc[model.id] = true;
return acc;
},
{} as { [modelId: number]: boolean }
);
setIsTraining(newIsTraining);
const intervalId = setInterval(async () => {
await queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
const models = await queryClient.getQueryData<ModelResponse[]>(['projectModels', numericProjectId]);
const updatedModel = models?.find((model) => model.id === selectedModel.id);
if (selectedModel && trainingModels.some((model) => model.id === selectedModel.id)) {
setSelectedModel(selectedModel);
} else {
setSelectedModel(null);
}
}
}, [models, selectedModel]);
useEffect(() => {
if (models && selectedModel) {
const updatedModel = models.find((model) => model.id === selectedModel.id);
if (updatedModel) {
setSelectedModel(updatedModel);
if (updatedModel.isTrain) {
setIsPolling(false);
} else {
setIsPolling(false);
setSelectedModel({ ...updatedModel, isTrain: false });
if (isWaiting[selectedModel.id] && updatedModel.isTrain) {
setIsWaiting((prev) => ({ ...prev, [selectedModel.id]: false }));
setIsTraining((prev) => ({ ...prev, [selectedModel.id]: true }));
}
}
}
}, [models, selectedModel, isWaiting]);
useEffect(() => {
let intervalId: NodeJS.Timeout | null = null;
if (selectedModel && isWaiting[selectedModel.id]) {
intervalId = setInterval(async () => {
await queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
}, 2000);
}
return () => {
if (intervalId) {
clearInterval(intervalId);
}
};
}, [selectedModel, numericProjectId, queryClient, isPolling]);
}, [isWaiting, selectedModel, queryClient, numericProjectId]);
usePollingTrainingModelReport({
projectId: numericProjectId as number,
modelId: selectedModel?.id as number,
enabled: selectedModel?.isTrain || false,
onTrainingEnd: handleTrainingEnd,
const handleTrainingStart = (trainData: ModelTrainRequest) => {
if (numericProjectId !== null && selectedModel) {
startTraining(trainData);
setIsWaiting((prev) => ({ ...prev, [selectedModel.id]: true }));
}
};
const handleTrainingEnd = (modelId: number) => {
if (prevModelRef.current && prevModelRef.current.id === modelId) {
Swal.fire({
title: '학습 완료',
text: `모델 "${prevModelRef.current.name}"의 학습이 완료되었습니다.`,
icon: 'success',
confirmButtonText: '확인',
});
}
setIsTraining((prev) => ({ ...prev, [modelId]: false }));
if (selectedModel && selectedModel.id === modelId) {
setSelectedModel(null);
}
};
const handleTrainingStop = () => {
setIsPolling(false);
setSelectedModel((prevModel) => (prevModel ? { ...prevModel, isTrain: false } : null));
//todo: 중단 함수 연결
if (selectedModel) {
setIsWaiting((prev) => ({ ...prev, [selectedModel.id]: false }));
setIsTraining((prev) => ({ ...prev, [selectedModel.id]: false }));
setSelectedModel(null);
// TODO: 학습 중단 기능 구현
}
};
useEffect(() => {
if (selectedModel) {
prevModelRef.current = selectedModel;
}
}, [selectedModel]);
return (
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
<TrainingSettings
@ -78,12 +119,15 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
setSelectedModel={setSelectedModel}
handleTrainingStart={handleTrainingStart}
handleTrainingStop={handleTrainingStop}
isPolling={isPolling}
isWaiting={selectedModel ? isWaiting[selectedModel.id] || false : false}
isTraining={selectedModel ? isTraining[selectedModel.id] || false : false}
className="h-full"
/>
<TrainingGraph
projectId={numericProjectId}
selectedModel={selectedModel}
isTraining={selectedModel ? isTraining[selectedModel.id] || false : false}
onTrainingEnd={() => selectedModel && handleTrainingEnd(selectedModel.id)}
className="h-full"
/>
</div>

View File

@ -5,37 +5,71 @@ import { Project } from '@/types';
import { Select, SelectTrigger, SelectContent, SelectItem, SelectValue } from '../ui/select';
import useCanvasStore from '@/stores/useCanvasStore';
import { webPath } from '@/router';
import { Suspense, useEffect } from 'react';
import { useState } from 'react';
import useUploadImageFileQuery from '@/queries/projects/useUploadImageFileQuery';
import useAuthStore from '@/stores/useAuthStore';
export default function WorkspaceSidebar({ workspaceName, projects }: { workspaceName: string; projects: Project[] }) {
const { setImage } = useCanvasStore();
const { projectId: selectedProjectId, workspaceId } = useParams<{ projectId: string; workspaceId: string }>();
const selectedProject = projects.find((project) => project.id.toString() === selectedProjectId) || null;
const { projectId: selectedProjectId } = useParams<{ projectId: string }>();
const selectedProject = projects.find((project) => project.id.toString() === selectedProjectId);
const setSidebarSize = useCanvasStore((state) => state.setSidebarSize);
const navigate = useNavigate();
const { workspaceId } = useParams<{ workspaceId: string }>();
const [isDragging, setIsDragging] = useState(false);
const uploadImageFileMutation = useUploadImageFileQuery();
const { profile } = useAuthStore();
const handleSelectProject = (projectId: string) => {
setImage(null);
navigate(`${webPath.workspace()}/${workspaceId}/${projectId}`);
navigate(`${webPath.workspace()}/${workspaceId}/project/${projectId}`);
};
useEffect(() => {
if (!selectedProject) {
setImage(null);
}
}, [selectedProject, setImage]);
const handleDragOver = (e: React.DragEvent) => {
e.preventDefault();
setIsDragging(true);
};
const handleDragLeave = () => {
setIsDragging(false);
};
const handleDrop = (e: React.DragEvent) => {
e.preventDefault();
setIsDragging(false);
if (!selectedProjectId || !profile) return;
const files = Array.from(e.dataTransfer.files);
const memberId = profile.id;
const projectId = parseInt(selectedProjectId);
const folderId = 0;
uploadImageFileMutation.mutate({
memberId,
projectId,
folderId,
files,
progressCallback: (progress) => {
console.log(`업로드 진행률: ${progress}%`);
},
});
};
return (
<>
<ResizablePanel
minSize={10}
maxSize={35}
defaultSize={15}
defaultSize={20}
className={`flex h-full flex-col bg-gray-50 ${isDragging ? 'bg-blue-100' : ''}`}
onResize={(size) => setSidebarSize(size)}
onDragOver={(e) => handleDragOver(e as unknown as React.DragEvent<Element>)}
onDragLeave={handleDragLeave}
onDrop={(e) => handleDrop(e as unknown as React.DragEvent<Element>)}
>
<div className="box-border flex h-full flex-col gap-2 bg-gray-50 p-3">
<header className="body flex w-full items-center gap-2">
<h1 className="subheading w-full overflow-hidden text-ellipsis whitespace-nowrap">{workspaceName}</h1>
<header className="body flex w-full items-center gap-2 p-2">
<h1 className="w-full overflow-hidden text-ellipsis whitespace-nowrap">{workspaceName}</h1>
</header>
<div className="p-2">
<Select
onValueChange={handleSelectProject}
defaultValue={selectedProjectId}
@ -54,10 +88,8 @@ export default function WorkspaceSidebar({ workspaceName, projects }: { workspac
))}
</SelectContent>
</Select>
<Suspense fallback={<div></div>}>
{selectedProject && <ProjectStructure project={selectedProject} />}
</Suspense>
</div>
{selectedProject && <ProjectStructure project={selectedProject} />}
</ResizablePanel>
<ResizableHandle className="bg-gray-300" />
</>

View File

@ -1,33 +1,36 @@
import { GripVertical } from "lucide-react"
import * as ResizablePrimitive from "react-resizable-panels"
import { GripVertical } from 'lucide-react';
import * as ResizablePrimitive from 'react-resizable-panels';
import { cn } from "@/lib/utils"
import { cn } from '@/lib/utils';
const ResizablePanelGroup = ({
className,
...props
}: React.ComponentProps<typeof ResizablePrimitive.PanelGroup>) => (
type PanelGroupProps = React.ComponentProps<typeof ResizablePrimitive.PanelGroup>;
const ResizablePanelGroup = ({ className, ...props }: PanelGroupProps) => (
<ResizablePrimitive.PanelGroup
className={cn(
"flex h-full w-full data-[panel-group-direction=vertical]:flex-col",
className
)}
className={cn('flex h-full w-full data-[panel-group-direction=vertical]:flex-col', className)}
{...props}
/>
)
);
const ResizablePanel = ResizablePrimitive.Panel
type PanelProps = React.ComponentProps<typeof ResizablePrimitive.Panel>;
const ResizablePanel = ({ className, ...props }: PanelProps) => (
<ResizablePrimitive.Panel
className={cn('resizable-panel', className)}
{...props}
/>
);
const ResizableHandle = ({
withHandle,
className,
...props
}: React.ComponentProps<typeof ResizablePrimitive.PanelResizeHandle> & {
withHandle?: boolean
withHandle?: boolean;
}) => (
<ResizablePrimitive.PanelResizeHandle
className={cn(
"relative flex w-px items-center justify-center bg-gray-200 after:absolute after:inset-y-0 after:left-1/2 after:w-1 after:-translate-x-1/2 focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-gray-950 focus-visible:ring-offset-1 data-[panel-group-direction=vertical]:h-px data-[panel-group-direction=vertical]:w-full data-[panel-group-direction=vertical]:after:left-0 data-[panel-group-direction=vertical]:after:h-1 data-[panel-group-direction=vertical]:after:w-full data-[panel-group-direction=vertical]:after:-translate-y-1/2 data-[panel-group-direction=vertical]:after:translate-x-0 [&[data-panel-group-direction=vertical]>div]:rotate-90 dark:bg-gray-800 dark:focus-visible:ring-gray-300",
'relative flex w-px items-center justify-center bg-gray-200 after:absolute after:inset-y-0 after:left-1/2 after:w-1 after:-translate-x-1/2 focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-gray-950 focus-visible:ring-offset-1 data-[panel-group-direction=vertical]:h-px data-[panel-group-direction=vertical]:w-full data-[panel-group-direction=vertical]:after:left-0 data-[panel-group-direction=vertical]:after:h-1 data-[panel-group-direction=vertical]:after:w-full data-[panel-group-direction=vertical]:after:-translate-y-1/2 data-[panel-group-direction=vertical]:after:translate-x-0 dark:bg-gray-800 dark:focus-visible:ring-gray-300 [&[data-panel-group-direction=vertical]>div]:rotate-90',
className
)}
{...props}
@ -38,6 +41,6 @@ const ResizableHandle = ({
</div>
)}
</ResizablePrimitive.PanelResizeHandle>
)
);
export { ResizablePanelGroup, ResizablePanel, ResizableHandle }
export { ResizablePanelGroup, ResizablePanel, ResizableHandle };

View File

@ -1,7 +1,8 @@
import { useEffect } from 'react';
import { useEffect, useRef } from 'react';
import { useQuery } from '@tanstack/react-query';
import { getTrainingModelReport } from '@/api/reportApi';
import { ReportResponse } from '@/types';
import { getProjectModels } from '@/api/modelApi';
import { ReportResponse, ProjectModelsResponse } from '@/types';
interface UsePollingTrainingModelReportProps {
projectId: number;
@ -16,7 +17,7 @@ export default function usePollingTrainingModelReport({
enabled,
onTrainingEnd,
}: UsePollingTrainingModelReportProps) {
const query = useQuery<ReportResponse[]>({
const reportQuery = useQuery<ReportResponse[]>({
queryKey: ['modelReports', projectId, modelId],
queryFn: () => getTrainingModelReport(projectId, modelId),
enabled,
@ -30,14 +31,30 @@ export default function usePollingTrainingModelReport({
},
});
const modelQuery = useQuery<ProjectModelsResponse>({
queryKey: ['projectModels', projectId],
queryFn: () => getProjectModels(projectId),
enabled,
refetchInterval: 2000,
});
const prevIsTrainRef = useRef<boolean | null>(null);
useEffect(() => {
if (query.data && query.data.length > 0) {
const lastReport = query.data[query.data.length - 1];
if (lastReport.epoch >= lastReport.totalEpochs) {
if (modelQuery.data) {
const model = modelQuery.data.find((m) => m.id === modelId);
if (model) {
const currentIsTrain = model.isTrain;
const prevIsTrain = prevIsTrainRef.current;
if (prevIsTrain === true && currentIsTrain === false) {
onTrainingEnd();
}
}
}, [query.data, onTrainingEnd]);
return query;
prevIsTrainRef.current = currentIsTrain;
}
}
}, [modelQuery.data, modelId, onTrainingEnd]);
return { reportData: reportQuery.data, modelData: modelQuery.data };
}