Refactor: 학습 모델 리팩토링

This commit is contained in:
정현조 2024-10-02 16:21:57 +09:00
parent 5081f342d1
commit eb24f919b1
6 changed files with 199 additions and 118 deletions

View File

@ -1,44 +1,29 @@
import { useEffect } from 'react';
import ModelLineChart from './ModelLineChart'; import ModelLineChart from './ModelLineChart';
import usePollingTrainingModelReport from '@/hooks/usePollingTrainingModelReport'; import usePollingTrainingModelReport from '@/hooks/usePollingTrainingModelReport';
import { useQueryClient } from '@tanstack/react-query';
import { ModelResponse } from '@/types'; import { ModelResponse } from '@/types';
interface TrainingGraphProps { interface TrainingGraphProps {
projectId: number | null; projectId: number | null;
selectedModel: ModelResponse | null; selectedModel: ModelResponse | null;
isTraining: boolean;
onTrainingEnd: () => void;
className?: string; className?: string;
} }
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) { export default function TrainingGraph({
const queryClient = useQueryClient(); projectId,
selectedModel,
const isTrainingEnabled = Boolean(selectedModel?.isTrain); isTraining,
onTrainingEnd,
const handleTrainingEnd = () => { className,
queryClient.resetQueries({ }: TrainingGraphProps) {
queryKey: ['modelReports', projectId, selectedModel?.id], const { reportData: trainingDataList } = usePollingTrainingModelReport({
exact: true,
});
alert('학습이 완료되었습니다.');
};
const { data: trainingDataList } = usePollingTrainingModelReport({
projectId: projectId as number, projectId: projectId as number,
modelId: selectedModel?.id as number, modelId: selectedModel?.id as number,
enabled: isTrainingEnabled, enabled: isTraining,
onTrainingEnd: handleTrainingEnd, onTrainingEnd,
}); });
useEffect(() => {
if (!selectedModel || !selectedModel.isTrain) {
queryClient.resetQueries({
queryKey: ['modelReports', projectId, selectedModel?.id],
exact: true,
});
}
}, [selectedModel, queryClient, projectId]);
return ( return (
<ModelLineChart <ModelLineChart
data={trainingDataList || []} data={trainingDataList || []}

View File

@ -12,7 +12,8 @@ interface TrainingSettingsProps {
setSelectedModel: (model: ModelResponse | null) => void; setSelectedModel: (model: ModelResponse | null) => void;
handleTrainingStart: (trainData: ModelTrainRequest) => void; handleTrainingStart: (trainData: ModelTrainRequest) => void;
handleTrainingStop: () => void; handleTrainingStop: () => void;
isPolling: boolean; isWaiting: boolean;
isTraining: boolean;
className?: string; className?: string;
} }
@ -22,7 +23,8 @@ export default function TrainingSettings({
setSelectedModel, setSelectedModel,
handleTrainingStart, handleTrainingStart,
handleTrainingStop, handleTrainingStop,
isPolling, isWaiting,
isTraining,
className, className,
}: TrainingSettingsProps) { }: TrainingSettingsProps) {
const { data: models } = useProjectModelsQuery(projectId ?? 0); const { data: models } = useProjectModelsQuery(projectId ?? 0);
@ -48,9 +50,6 @@ export default function TrainingSettings({
} }
}; };
const isTraining = selectedModel?.isTrain;
const isWaiting = isPolling && !isTraining;
return ( return (
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}> <fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
<legend className="-ml-1 px-1 text-sm font-medium"> </legend> <legend className="-ml-1 px-1 text-sm font-medium"> </legend>
@ -73,7 +72,8 @@ export default function TrainingSettings({
}} }}
/> />
</div> </div>
{!isPolling && !isTraining && (
{!isWaiting && !isTraining && (
<> <>
<div className="grid grid-cols-2 gap-4"> <div className="grid grid-cols-2 gap-4">
<InputWithLabel <InputWithLabel

View File

@ -1,10 +1,11 @@
import { useState, useEffect } from 'react'; import { useState, useEffect, useRef } from 'react';
import TrainingSettings from './TrainingSettings'; import TrainingSettings from './TrainingSettings';
import TrainingGraph from './TrainingGraph'; import TrainingGraph from './TrainingGraph';
import useTrainModelQuery from '@/queries/models/useTrainModelQuery'; import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
import usePollingTrainingModelReport from '@/hooks/usePollingTrainingModelReport'; import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import { ModelTrainRequest, ModelResponse } from '@/types'; import { ModelTrainRequest, ModelResponse } from '@/types';
import { useQueryClient } from '@tanstack/react-query'; import { useQueryClient } from '@tanstack/react-query';
import Swal from 'sweetalert2';
interface TrainingTabProps { interface TrainingTabProps {
projectId: number | null; projectId: number | null;
@ -13,63 +14,103 @@ interface TrainingTabProps {
export default function TrainingTab({ projectId }: TrainingTabProps) { export default function TrainingTab({ projectId }: TrainingTabProps) {
const numericProjectId = projectId !== null ? Number(projectId) : null; const numericProjectId = projectId !== null ? Number(projectId) : null;
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null); const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
const [isPolling, setIsPolling] = useState(false); const [isWaiting, setIsWaiting] = useState<{ [modelId: number]: boolean }>({});
const [isTraining, setIsTraining] = useState<{ [modelId: number]: boolean }>({});
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const prevModelRef = useRef<ModelResponse | null>(null);
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number); const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
const { data: models } = useProjectModelsQuery(numericProjectId ?? 0);
const handleTrainingStart = (trainData: ModelTrainRequest) => {
if (numericProjectId !== null) {
startTraining(trainData);
setIsPolling(true);
}
};
const handleTrainingEnd = () => {
setIsPolling(false);
setSelectedModel((prevModel) => (prevModel ? { ...prevModel, isTrain: false } : null));
};
useEffect(() => { useEffect(() => {
if (!selectedModel || !numericProjectId || !isPolling) return; if (models) {
const trainingModels = models.filter((model) => model.isTrain);
const newIsTraining = trainingModels.reduce(
(acc, model) => {
acc[model.id] = true;
return acc;
},
{} as { [modelId: number]: boolean }
);
setIsTraining(newIsTraining);
const intervalId = setInterval(async () => { if (selectedModel && trainingModels.some((model) => model.id === selectedModel.id)) {
await queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] }); setSelectedModel(selectedModel);
} else {
const models = await queryClient.getQueryData<ModelResponse[]>(['projectModels', numericProjectId]); setSelectedModel(null);
}
const updatedModel = models?.find((model) => model.id === selectedModel.id); }
}, [models, selectedModel]);
useEffect(() => {
if (models && selectedModel) {
const updatedModel = models.find((model) => model.id === selectedModel.id);
if (updatedModel) { if (updatedModel) {
setSelectedModel(updatedModel); setSelectedModel(updatedModel);
if (updatedModel.isTrain) { if (isWaiting[selectedModel.id] && updatedModel.isTrain) {
setIsPolling(false); setIsWaiting((prev) => ({ ...prev, [selectedModel.id]: false }));
} else { setIsTraining((prev) => ({ ...prev, [selectedModel.id]: true }));
setIsPolling(false);
setSelectedModel({ ...updatedModel, isTrain: false });
} }
} }
}, 2000); }
}, [models, selectedModel, isWaiting]);
useEffect(() => {
let intervalId: NodeJS.Timeout | null = null;
if (selectedModel && isWaiting[selectedModel.id]) {
intervalId = setInterval(async () => {
await queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
}, 2000);
}
return () => { return () => {
clearInterval(intervalId); if (intervalId) {
clearInterval(intervalId);
}
}; };
}, [selectedModel, numericProjectId, queryClient, isPolling]); }, [isWaiting, selectedModel, queryClient, numericProjectId]);
usePollingTrainingModelReport({ const handleTrainingStart = (trainData: ModelTrainRequest) => {
projectId: numericProjectId as number, if (numericProjectId !== null && selectedModel) {
modelId: selectedModel?.id as number, startTraining(trainData);
enabled: selectedModel?.isTrain || false, setIsWaiting((prev) => ({ ...prev, [selectedModel.id]: true }));
onTrainingEnd: handleTrainingEnd, }
}); };
const handleTrainingEnd = (modelId: number) => {
if (prevModelRef.current && prevModelRef.current.id === modelId) {
Swal.fire({
title: '학습 완료',
text: `모델 "${prevModelRef.current.name}"의 학습이 완료되었습니다.`,
icon: 'success',
confirmButtonText: '확인',
});
}
setIsTraining((prev) => ({ ...prev, [modelId]: false }));
if (selectedModel && selectedModel.id === modelId) {
setSelectedModel(null);
}
};
const handleTrainingStop = () => { const handleTrainingStop = () => {
setIsPolling(false); if (selectedModel) {
setSelectedModel((prevModel) => (prevModel ? { ...prevModel, isTrain: false } : null)); setIsWaiting((prev) => ({ ...prev, [selectedModel.id]: false }));
//todo: 중단 함수 연결 setIsTraining((prev) => ({ ...prev, [selectedModel.id]: false }));
setSelectedModel(null);
// TODO: 학습 중단 기능 구현
}
}; };
useEffect(() => {
if (selectedModel) {
prevModelRef.current = selectedModel;
}
}, [selectedModel]);
return ( return (
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2"> <div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
<TrainingSettings <TrainingSettings
@ -78,12 +119,15 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
setSelectedModel={setSelectedModel} setSelectedModel={setSelectedModel}
handleTrainingStart={handleTrainingStart} handleTrainingStart={handleTrainingStart}
handleTrainingStop={handleTrainingStop} handleTrainingStop={handleTrainingStop}
isPolling={isPolling} isWaiting={selectedModel ? isWaiting[selectedModel.id] || false : false}
isTraining={selectedModel ? isTraining[selectedModel.id] || false : false}
className="h-full" className="h-full"
/> />
<TrainingGraph <TrainingGraph
projectId={numericProjectId} projectId={numericProjectId}
selectedModel={selectedModel} selectedModel={selectedModel}
isTraining={selectedModel ? isTraining[selectedModel.id] || false : false}
onTrainingEnd={() => selectedModel && handleTrainingEnd(selectedModel.id)}
className="h-full" className="h-full"
/> />
</div> </div>

View File

@ -5,37 +5,71 @@ import { Project } from '@/types';
import { Select, SelectTrigger, SelectContent, SelectItem, SelectValue } from '../ui/select'; import { Select, SelectTrigger, SelectContent, SelectItem, SelectValue } from '../ui/select';
import useCanvasStore from '@/stores/useCanvasStore'; import useCanvasStore from '@/stores/useCanvasStore';
import { webPath } from '@/router'; import { webPath } from '@/router';
import { Suspense, useEffect } from 'react'; import { useState } from 'react';
import useUploadImageFileQuery from '@/queries/projects/useUploadImageFileQuery';
import useAuthStore from '@/stores/useAuthStore';
export default function WorkspaceSidebar({ workspaceName, projects }: { workspaceName: string; projects: Project[] }) { export default function WorkspaceSidebar({ workspaceName, projects }: { workspaceName: string; projects: Project[] }) {
const { setImage } = useCanvasStore(); const { projectId: selectedProjectId } = useParams<{ projectId: string }>();
const { projectId: selectedProjectId, workspaceId } = useParams<{ projectId: string; workspaceId: string }>(); const selectedProject = projects.find((project) => project.id.toString() === selectedProjectId);
const selectedProject = projects.find((project) => project.id.toString() === selectedProjectId) || null;
const setSidebarSize = useCanvasStore((state) => state.setSidebarSize); const setSidebarSize = useCanvasStore((state) => state.setSidebarSize);
const navigate = useNavigate(); const navigate = useNavigate();
const { workspaceId } = useParams<{ workspaceId: string }>();
const [isDragging, setIsDragging] = useState(false);
const uploadImageFileMutation = useUploadImageFileQuery();
const { profile } = useAuthStore();
const handleSelectProject = (projectId: string) => { const handleSelectProject = (projectId: string) => {
setImage(null); navigate(`${webPath.workspace()}/${workspaceId}/project/${projectId}`);
navigate(`${webPath.workspace()}/${workspaceId}/${projectId}`);
}; };
useEffect(() => { const handleDragOver = (e: React.DragEvent) => {
if (!selectedProject) { e.preventDefault();
setImage(null); setIsDragging(true);
} };
}, [selectedProject, setImage]);
const handleDragLeave = () => {
setIsDragging(false);
};
const handleDrop = (e: React.DragEvent) => {
e.preventDefault();
setIsDragging(false);
if (!selectedProjectId || !profile) return;
const files = Array.from(e.dataTransfer.files);
const memberId = profile.id;
const projectId = parseInt(selectedProjectId);
const folderId = 0;
uploadImageFileMutation.mutate({
memberId,
projectId,
folderId,
files,
progressCallback: (progress) => {
console.log(`업로드 진행률: ${progress}%`);
},
});
};
return ( return (
<> <>
<ResizablePanel <ResizablePanel
minSize={10} minSize={10}
maxSize={35} maxSize={35}
defaultSize={15} defaultSize={20}
className={`flex h-full flex-col bg-gray-50 ${isDragging ? 'bg-blue-100' : ''}`}
onResize={(size) => setSidebarSize(size)} onResize={(size) => setSidebarSize(size)}
onDragOver={(e) => handleDragOver(e as unknown as React.DragEvent<Element>)}
onDragLeave={handleDragLeave}
onDrop={(e) => handleDrop(e as unknown as React.DragEvent<Element>)}
> >
<div className="box-border flex h-full flex-col gap-2 bg-gray-50 p-3"> <header className="body flex w-full items-center gap-2 p-2">
<header className="body flex w-full items-center gap-2"> <h1 className="w-full overflow-hidden text-ellipsis whitespace-nowrap">{workspaceName}</h1>
<h1 className="subheading w-full overflow-hidden text-ellipsis whitespace-nowrap">{workspaceName}</h1> </header>
</header> <div className="p-2">
<Select <Select
onValueChange={handleSelectProject} onValueChange={handleSelectProject}
defaultValue={selectedProjectId} defaultValue={selectedProjectId}
@ -54,10 +88,8 @@ export default function WorkspaceSidebar({ workspaceName, projects }: { workspac
))} ))}
</SelectContent> </SelectContent>
</Select> </Select>
<Suspense fallback={<div></div>}>
{selectedProject && <ProjectStructure project={selectedProject} />}
</Suspense>
</div> </div>
{selectedProject && <ProjectStructure project={selectedProject} />}
</ResizablePanel> </ResizablePanel>
<ResizableHandle className="bg-gray-300" /> <ResizableHandle className="bg-gray-300" />
</> </>

View File

@ -1,33 +1,36 @@
import { GripVertical } from "lucide-react" import { GripVertical } from 'lucide-react';
import * as ResizablePrimitive from "react-resizable-panels" import * as ResizablePrimitive from 'react-resizable-panels';
import { cn } from "@/lib/utils" import { cn } from '@/lib/utils';
const ResizablePanelGroup = ({ type PanelGroupProps = React.ComponentProps<typeof ResizablePrimitive.PanelGroup>;
className,
...props const ResizablePanelGroup = ({ className, ...props }: PanelGroupProps) => (
}: React.ComponentProps<typeof ResizablePrimitive.PanelGroup>) => (
<ResizablePrimitive.PanelGroup <ResizablePrimitive.PanelGroup
className={cn( className={cn('flex h-full w-full data-[panel-group-direction=vertical]:flex-col', className)}
"flex h-full w-full data-[panel-group-direction=vertical]:flex-col",
className
)}
{...props} {...props}
/> />
) );
const ResizablePanel = ResizablePrimitive.Panel type PanelProps = React.ComponentProps<typeof ResizablePrimitive.Panel>;
const ResizablePanel = ({ className, ...props }: PanelProps) => (
<ResizablePrimitive.Panel
className={cn('resizable-panel', className)}
{...props}
/>
);
const ResizableHandle = ({ const ResizableHandle = ({
withHandle, withHandle,
className, className,
...props ...props
}: React.ComponentProps<typeof ResizablePrimitive.PanelResizeHandle> & { }: React.ComponentProps<typeof ResizablePrimitive.PanelResizeHandle> & {
withHandle?: boolean withHandle?: boolean;
}) => ( }) => (
<ResizablePrimitive.PanelResizeHandle <ResizablePrimitive.PanelResizeHandle
className={cn( className={cn(
"relative flex w-px items-center justify-center bg-gray-200 after:absolute after:inset-y-0 after:left-1/2 after:w-1 after:-translate-x-1/2 focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-gray-950 focus-visible:ring-offset-1 data-[panel-group-direction=vertical]:h-px data-[panel-group-direction=vertical]:w-full data-[panel-group-direction=vertical]:after:left-0 data-[panel-group-direction=vertical]:after:h-1 data-[panel-group-direction=vertical]:after:w-full data-[panel-group-direction=vertical]:after:-translate-y-1/2 data-[panel-group-direction=vertical]:after:translate-x-0 [&[data-panel-group-direction=vertical]>div]:rotate-90 dark:bg-gray-800 dark:focus-visible:ring-gray-300", 'relative flex w-px items-center justify-center bg-gray-200 after:absolute after:inset-y-0 after:left-1/2 after:w-1 after:-translate-x-1/2 focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-gray-950 focus-visible:ring-offset-1 data-[panel-group-direction=vertical]:h-px data-[panel-group-direction=vertical]:w-full data-[panel-group-direction=vertical]:after:left-0 data-[panel-group-direction=vertical]:after:h-1 data-[panel-group-direction=vertical]:after:w-full data-[panel-group-direction=vertical]:after:-translate-y-1/2 data-[panel-group-direction=vertical]:after:translate-x-0 dark:bg-gray-800 dark:focus-visible:ring-gray-300 [&[data-panel-group-direction=vertical]>div]:rotate-90',
className className
)} )}
{...props} {...props}
@ -38,6 +41,6 @@ const ResizableHandle = ({
</div> </div>
)} )}
</ResizablePrimitive.PanelResizeHandle> </ResizablePrimitive.PanelResizeHandle>
) );
export { ResizablePanelGroup, ResizablePanel, ResizableHandle } export { ResizablePanelGroup, ResizablePanel, ResizableHandle };

View File

@ -1,7 +1,8 @@
import { useEffect } from 'react'; import { useEffect, useRef } from 'react';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
import { getTrainingModelReport } from '@/api/reportApi'; import { getTrainingModelReport } from '@/api/reportApi';
import { ReportResponse } from '@/types'; import { getProjectModels } from '@/api/modelApi';
import { ReportResponse, ProjectModelsResponse } from '@/types';
interface UsePollingTrainingModelReportProps { interface UsePollingTrainingModelReportProps {
projectId: number; projectId: number;
@ -16,7 +17,7 @@ export default function usePollingTrainingModelReport({
enabled, enabled,
onTrainingEnd, onTrainingEnd,
}: UsePollingTrainingModelReportProps) { }: UsePollingTrainingModelReportProps) {
const query = useQuery<ReportResponse[]>({ const reportQuery = useQuery<ReportResponse[]>({
queryKey: ['modelReports', projectId, modelId], queryKey: ['modelReports', projectId, modelId],
queryFn: () => getTrainingModelReport(projectId, modelId), queryFn: () => getTrainingModelReport(projectId, modelId),
enabled, enabled,
@ -30,14 +31,30 @@ export default function usePollingTrainingModelReport({
}, },
}); });
const modelQuery = useQuery<ProjectModelsResponse>({
queryKey: ['projectModels', projectId],
queryFn: () => getProjectModels(projectId),
enabled,
refetchInterval: 2000,
});
const prevIsTrainRef = useRef<boolean | null>(null);
useEffect(() => { useEffect(() => {
if (query.data && query.data.length > 0) { if (modelQuery.data) {
const lastReport = query.data[query.data.length - 1]; const model = modelQuery.data.find((m) => m.id === modelId);
if (lastReport.epoch >= lastReport.totalEpochs) { if (model) {
onTrainingEnd(); const currentIsTrain = model.isTrain;
const prevIsTrain = prevIsTrainRef.current;
if (prevIsTrain === true && currentIsTrain === false) {
onTrainingEnd();
}
prevIsTrainRef.current = currentIsTrain;
} }
} }
}, [query.data, onTrainingEnd]); }, [modelQuery.data, modelId, onTrainingEnd]);
return query; return { reportData: reportQuery.data, modelData: modelQuery.data };
} }