Refactor: 학습 모델 리팩토링
This commit is contained in:
parent
5081f342d1
commit
eb24f919b1
@ -1,44 +1,29 @@
|
||||
import { useEffect } from 'react';
|
||||
import ModelLineChart from './ModelLineChart';
|
||||
import usePollingTrainingModelReport from '@/hooks/usePollingTrainingModelReport';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { ModelResponse } from '@/types';
|
||||
|
||||
interface TrainingGraphProps {
|
||||
projectId: number | null;
|
||||
selectedModel: ModelResponse | null;
|
||||
isTraining: boolean;
|
||||
onTrainingEnd: () => void;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const isTrainingEnabled = Boolean(selectedModel?.isTrain);
|
||||
|
||||
const handleTrainingEnd = () => {
|
||||
queryClient.resetQueries({
|
||||
queryKey: ['modelReports', projectId, selectedModel?.id],
|
||||
exact: true,
|
||||
});
|
||||
alert('학습이 완료되었습니다.');
|
||||
};
|
||||
|
||||
const { data: trainingDataList } = usePollingTrainingModelReport({
|
||||
export default function TrainingGraph({
|
||||
projectId,
|
||||
selectedModel,
|
||||
isTraining,
|
||||
onTrainingEnd,
|
||||
className,
|
||||
}: TrainingGraphProps) {
|
||||
const { reportData: trainingDataList } = usePollingTrainingModelReport({
|
||||
projectId: projectId as number,
|
||||
modelId: selectedModel?.id as number,
|
||||
enabled: isTrainingEnabled,
|
||||
onTrainingEnd: handleTrainingEnd,
|
||||
enabled: isTraining,
|
||||
onTrainingEnd,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedModel || !selectedModel.isTrain) {
|
||||
queryClient.resetQueries({
|
||||
queryKey: ['modelReports', projectId, selectedModel?.id],
|
||||
exact: true,
|
||||
});
|
||||
}
|
||||
}, [selectedModel, queryClient, projectId]);
|
||||
|
||||
return (
|
||||
<ModelLineChart
|
||||
data={trainingDataList || []}
|
||||
|
@ -12,7 +12,8 @@ interface TrainingSettingsProps {
|
||||
setSelectedModel: (model: ModelResponse | null) => void;
|
||||
handleTrainingStart: (trainData: ModelTrainRequest) => void;
|
||||
handleTrainingStop: () => void;
|
||||
isPolling: boolean;
|
||||
isWaiting: boolean;
|
||||
isTraining: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
@ -22,7 +23,8 @@ export default function TrainingSettings({
|
||||
setSelectedModel,
|
||||
handleTrainingStart,
|
||||
handleTrainingStop,
|
||||
isPolling,
|
||||
isWaiting,
|
||||
isTraining,
|
||||
className,
|
||||
}: TrainingSettingsProps) {
|
||||
const { data: models } = useProjectModelsQuery(projectId ?? 0);
|
||||
@ -48,9 +50,6 @@ export default function TrainingSettings({
|
||||
}
|
||||
};
|
||||
|
||||
const isTraining = selectedModel?.isTrain;
|
||||
const isWaiting = isPolling && !isTraining;
|
||||
|
||||
return (
|
||||
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
|
||||
<legend className="-ml-1 px-1 text-sm font-medium">모델 설정</legend>
|
||||
@ -73,7 +72,8 @@ export default function TrainingSettings({
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{!isPolling && !isTraining && (
|
||||
|
||||
{!isWaiting && !isTraining && (
|
||||
<>
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<InputWithLabel
|
||||
|
@ -1,10 +1,11 @@
|
||||
import { useState, useEffect } from 'react';
|
||||
import { useState, useEffect, useRef } from 'react';
|
||||
import TrainingSettings from './TrainingSettings';
|
||||
import TrainingGraph from './TrainingGraph';
|
||||
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
|
||||
import usePollingTrainingModelReport from '@/hooks/usePollingTrainingModelReport';
|
||||
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
|
||||
import { ModelTrainRequest, ModelResponse } from '@/types';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import Swal from 'sweetalert2';
|
||||
|
||||
interface TrainingTabProps {
|
||||
projectId: number | null;
|
||||
@ -13,63 +14,103 @@ interface TrainingTabProps {
|
||||
export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||
const numericProjectId = projectId !== null ? Number(projectId) : null;
|
||||
const [selectedModel, setSelectedModel] = useState<ModelResponse | null>(null);
|
||||
const [isPolling, setIsPolling] = useState(false);
|
||||
const [isWaiting, setIsWaiting] = useState<{ [modelId: number]: boolean }>({});
|
||||
const [isTraining, setIsTraining] = useState<{ [modelId: number]: boolean }>({});
|
||||
const queryClient = useQueryClient();
|
||||
const prevModelRef = useRef<ModelResponse | null>(null);
|
||||
|
||||
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
|
||||
|
||||
const handleTrainingStart = (trainData: ModelTrainRequest) => {
|
||||
if (numericProjectId !== null) {
|
||||
startTraining(trainData);
|
||||
setIsPolling(true);
|
||||
}
|
||||
};
|
||||
|
||||
const handleTrainingEnd = () => {
|
||||
setIsPolling(false);
|
||||
setSelectedModel((prevModel) => (prevModel ? { ...prevModel, isTrain: false } : null));
|
||||
};
|
||||
const { data: models } = useProjectModelsQuery(numericProjectId ?? 0);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedModel || !numericProjectId || !isPolling) return;
|
||||
if (models) {
|
||||
const trainingModels = models.filter((model) => model.isTrain);
|
||||
const newIsTraining = trainingModels.reduce(
|
||||
(acc, model) => {
|
||||
acc[model.id] = true;
|
||||
return acc;
|
||||
},
|
||||
{} as { [modelId: number]: boolean }
|
||||
);
|
||||
setIsTraining(newIsTraining);
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
await queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
|
||||
|
||||
const models = await queryClient.getQueryData<ModelResponse[]>(['projectModels', numericProjectId]);
|
||||
|
||||
const updatedModel = models?.find((model) => model.id === selectedModel.id);
|
||||
if (selectedModel && trainingModels.some((model) => model.id === selectedModel.id)) {
|
||||
setSelectedModel(selectedModel);
|
||||
} else {
|
||||
setSelectedModel(null);
|
||||
}
|
||||
}
|
||||
}, [models, selectedModel]);
|
||||
|
||||
useEffect(() => {
|
||||
if (models && selectedModel) {
|
||||
const updatedModel = models.find((model) => model.id === selectedModel.id);
|
||||
if (updatedModel) {
|
||||
setSelectedModel(updatedModel);
|
||||
|
||||
if (updatedModel.isTrain) {
|
||||
setIsPolling(false);
|
||||
} else {
|
||||
setIsPolling(false);
|
||||
setSelectedModel({ ...updatedModel, isTrain: false });
|
||||
if (isWaiting[selectedModel.id] && updatedModel.isTrain) {
|
||||
setIsWaiting((prev) => ({ ...prev, [selectedModel.id]: false }));
|
||||
setIsTraining((prev) => ({ ...prev, [selectedModel.id]: true }));
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [models, selectedModel, isWaiting]);
|
||||
|
||||
useEffect(() => {
|
||||
let intervalId: NodeJS.Timeout | null = null;
|
||||
|
||||
if (selectedModel && isWaiting[selectedModel.id]) {
|
||||
intervalId = setInterval(async () => {
|
||||
await queryClient.invalidateQueries({ queryKey: ['projectModels', numericProjectId] });
|
||||
}, 2000);
|
||||
}
|
||||
|
||||
return () => {
|
||||
if (intervalId) {
|
||||
clearInterval(intervalId);
|
||||
}
|
||||
};
|
||||
}, [selectedModel, numericProjectId, queryClient, isPolling]);
|
||||
}, [isWaiting, selectedModel, queryClient, numericProjectId]);
|
||||
|
||||
usePollingTrainingModelReport({
|
||||
projectId: numericProjectId as number,
|
||||
modelId: selectedModel?.id as number,
|
||||
enabled: selectedModel?.isTrain || false,
|
||||
onTrainingEnd: handleTrainingEnd,
|
||||
const handleTrainingStart = (trainData: ModelTrainRequest) => {
|
||||
if (numericProjectId !== null && selectedModel) {
|
||||
startTraining(trainData);
|
||||
setIsWaiting((prev) => ({ ...prev, [selectedModel.id]: true }));
|
||||
}
|
||||
};
|
||||
|
||||
const handleTrainingEnd = (modelId: number) => {
|
||||
if (prevModelRef.current && prevModelRef.current.id === modelId) {
|
||||
Swal.fire({
|
||||
title: '학습 완료',
|
||||
text: `모델 "${prevModelRef.current.name}"의 학습이 완료되었습니다.`,
|
||||
icon: 'success',
|
||||
confirmButtonText: '확인',
|
||||
});
|
||||
}
|
||||
|
||||
setIsTraining((prev) => ({ ...prev, [modelId]: false }));
|
||||
|
||||
if (selectedModel && selectedModel.id === modelId) {
|
||||
setSelectedModel(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleTrainingStop = () => {
|
||||
setIsPolling(false);
|
||||
setSelectedModel((prevModel) => (prevModel ? { ...prevModel, isTrain: false } : null));
|
||||
//todo: 중단 함수 연결
|
||||
if (selectedModel) {
|
||||
setIsWaiting((prev) => ({ ...prev, [selectedModel.id]: false }));
|
||||
setIsTraining((prev) => ({ ...prev, [selectedModel.id]: false }));
|
||||
setSelectedModel(null);
|
||||
// TODO: 학습 중단 기능 구현
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedModel) {
|
||||
prevModelRef.current = selectedModel;
|
||||
}
|
||||
}, [selectedModel]);
|
||||
|
||||
return (
|
||||
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
|
||||
<TrainingSettings
|
||||
@ -78,12 +119,15 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||
setSelectedModel={setSelectedModel}
|
||||
handleTrainingStart={handleTrainingStart}
|
||||
handleTrainingStop={handleTrainingStop}
|
||||
isPolling={isPolling}
|
||||
isWaiting={selectedModel ? isWaiting[selectedModel.id] || false : false}
|
||||
isTraining={selectedModel ? isTraining[selectedModel.id] || false : false}
|
||||
className="h-full"
|
||||
/>
|
||||
<TrainingGraph
|
||||
projectId={numericProjectId}
|
||||
selectedModel={selectedModel}
|
||||
isTraining={selectedModel ? isTraining[selectedModel.id] || false : false}
|
||||
onTrainingEnd={() => selectedModel && handleTrainingEnd(selectedModel.id)}
|
||||
className="h-full"
|
||||
/>
|
||||
</div>
|
||||
|
@ -5,37 +5,71 @@ import { Project } from '@/types';
|
||||
import { Select, SelectTrigger, SelectContent, SelectItem, SelectValue } from '../ui/select';
|
||||
import useCanvasStore from '@/stores/useCanvasStore';
|
||||
import { webPath } from '@/router';
|
||||
import { Suspense, useEffect } from 'react';
|
||||
import { useState } from 'react';
|
||||
import useUploadImageFileQuery from '@/queries/projects/useUploadImageFileQuery';
|
||||
import useAuthStore from '@/stores/useAuthStore';
|
||||
|
||||
export default function WorkspaceSidebar({ workspaceName, projects }: { workspaceName: string; projects: Project[] }) {
|
||||
const { setImage } = useCanvasStore();
|
||||
const { projectId: selectedProjectId, workspaceId } = useParams<{ projectId: string; workspaceId: string }>();
|
||||
const selectedProject = projects.find((project) => project.id.toString() === selectedProjectId) || null;
|
||||
const { projectId: selectedProjectId } = useParams<{ projectId: string }>();
|
||||
const selectedProject = projects.find((project) => project.id.toString() === selectedProjectId);
|
||||
const setSidebarSize = useCanvasStore((state) => state.setSidebarSize);
|
||||
const navigate = useNavigate();
|
||||
const { workspaceId } = useParams<{ workspaceId: string }>();
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
const uploadImageFileMutation = useUploadImageFileQuery();
|
||||
const { profile } = useAuthStore();
|
||||
|
||||
const handleSelectProject = (projectId: string) => {
|
||||
setImage(null);
|
||||
navigate(`${webPath.workspace()}/${workspaceId}/${projectId}`);
|
||||
navigate(`${webPath.workspace()}/${workspaceId}/project/${projectId}`);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedProject) {
|
||||
setImage(null);
|
||||
}
|
||||
}, [selectedProject, setImage]);
|
||||
const handleDragOver = (e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
setIsDragging(true);
|
||||
};
|
||||
|
||||
const handleDragLeave = () => {
|
||||
setIsDragging(false);
|
||||
};
|
||||
|
||||
const handleDrop = (e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
setIsDragging(false);
|
||||
|
||||
if (!selectedProjectId || !profile) return;
|
||||
|
||||
const files = Array.from(e.dataTransfer.files);
|
||||
const memberId = profile.id;
|
||||
const projectId = parseInt(selectedProjectId);
|
||||
const folderId = 0;
|
||||
|
||||
uploadImageFileMutation.mutate({
|
||||
memberId,
|
||||
projectId,
|
||||
folderId,
|
||||
files,
|
||||
progressCallback: (progress) => {
|
||||
console.log(`업로드 진행률: ${progress}%`);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<ResizablePanel
|
||||
minSize={10}
|
||||
maxSize={35}
|
||||
defaultSize={15}
|
||||
defaultSize={20}
|
||||
className={`flex h-full flex-col bg-gray-50 ${isDragging ? 'bg-blue-100' : ''}`}
|
||||
onResize={(size) => setSidebarSize(size)}
|
||||
onDragOver={(e) => handleDragOver(e as unknown as React.DragEvent<Element>)}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDrop={(e) => handleDrop(e as unknown as React.DragEvent<Element>)}
|
||||
>
|
||||
<div className="box-border flex h-full flex-col gap-2 bg-gray-50 p-3">
|
||||
<header className="body flex w-full items-center gap-2">
|
||||
<h1 className="subheading w-full overflow-hidden text-ellipsis whitespace-nowrap">{workspaceName}</h1>
|
||||
<header className="body flex w-full items-center gap-2 p-2">
|
||||
<h1 className="w-full overflow-hidden text-ellipsis whitespace-nowrap">{workspaceName}</h1>
|
||||
</header>
|
||||
<div className="p-2">
|
||||
<Select
|
||||
onValueChange={handleSelectProject}
|
||||
defaultValue={selectedProjectId}
|
||||
@ -54,10 +88,8 @@ export default function WorkspaceSidebar({ workspaceName, projects }: { workspac
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<Suspense fallback={<div></div>}>
|
||||
{selectedProject && <ProjectStructure project={selectedProject} />}
|
||||
</Suspense>
|
||||
</div>
|
||||
{selectedProject && <ProjectStructure project={selectedProject} />}
|
||||
</ResizablePanel>
|
||||
<ResizableHandle className="bg-gray-300" />
|
||||
</>
|
||||
|
@ -1,33 +1,36 @@
|
||||
import { GripVertical } from "lucide-react"
|
||||
import * as ResizablePrimitive from "react-resizable-panels"
|
||||
import { GripVertical } from 'lucide-react';
|
||||
import * as ResizablePrimitive from 'react-resizable-panels';
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
import { cn } from '@/lib/utils';
|
||||
|
||||
const ResizablePanelGroup = ({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof ResizablePrimitive.PanelGroup>) => (
|
||||
type PanelGroupProps = React.ComponentProps<typeof ResizablePrimitive.PanelGroup>;
|
||||
|
||||
const ResizablePanelGroup = ({ className, ...props }: PanelGroupProps) => (
|
||||
<ResizablePrimitive.PanelGroup
|
||||
className={cn(
|
||||
"flex h-full w-full data-[panel-group-direction=vertical]:flex-col",
|
||||
className
|
||||
)}
|
||||
className={cn('flex h-full w-full data-[panel-group-direction=vertical]:flex-col', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
);
|
||||
|
||||
const ResizablePanel = ResizablePrimitive.Panel
|
||||
type PanelProps = React.ComponentProps<typeof ResizablePrimitive.Panel>;
|
||||
|
||||
const ResizablePanel = ({ className, ...props }: PanelProps) => (
|
||||
<ResizablePrimitive.Panel
|
||||
className={cn('resizable-panel', className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
|
||||
const ResizableHandle = ({
|
||||
withHandle,
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof ResizablePrimitive.PanelResizeHandle> & {
|
||||
withHandle?: boolean
|
||||
withHandle?: boolean;
|
||||
}) => (
|
||||
<ResizablePrimitive.PanelResizeHandle
|
||||
className={cn(
|
||||
"relative flex w-px items-center justify-center bg-gray-200 after:absolute after:inset-y-0 after:left-1/2 after:w-1 after:-translate-x-1/2 focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-gray-950 focus-visible:ring-offset-1 data-[panel-group-direction=vertical]:h-px data-[panel-group-direction=vertical]:w-full data-[panel-group-direction=vertical]:after:left-0 data-[panel-group-direction=vertical]:after:h-1 data-[panel-group-direction=vertical]:after:w-full data-[panel-group-direction=vertical]:after:-translate-y-1/2 data-[panel-group-direction=vertical]:after:translate-x-0 [&[data-panel-group-direction=vertical]>div]:rotate-90 dark:bg-gray-800 dark:focus-visible:ring-gray-300",
|
||||
'relative flex w-px items-center justify-center bg-gray-200 after:absolute after:inset-y-0 after:left-1/2 after:w-1 after:-translate-x-1/2 focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-gray-950 focus-visible:ring-offset-1 data-[panel-group-direction=vertical]:h-px data-[panel-group-direction=vertical]:w-full data-[panel-group-direction=vertical]:after:left-0 data-[panel-group-direction=vertical]:after:h-1 data-[panel-group-direction=vertical]:after:w-full data-[panel-group-direction=vertical]:after:-translate-y-1/2 data-[panel-group-direction=vertical]:after:translate-x-0 dark:bg-gray-800 dark:focus-visible:ring-gray-300 [&[data-panel-group-direction=vertical]>div]:rotate-90',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
@ -38,6 +41,6 @@ const ResizableHandle = ({
|
||||
</div>
|
||||
)}
|
||||
</ResizablePrimitive.PanelResizeHandle>
|
||||
)
|
||||
);
|
||||
|
||||
export { ResizablePanelGroup, ResizablePanel, ResizableHandle }
|
||||
export { ResizablePanelGroup, ResizablePanel, ResizableHandle };
|
||||
|
@ -1,7 +1,8 @@
|
||||
import { useEffect } from 'react';
|
||||
import { useEffect, useRef } from 'react';
|
||||
import { useQuery } from '@tanstack/react-query';
|
||||
import { getTrainingModelReport } from '@/api/reportApi';
|
||||
import { ReportResponse } from '@/types';
|
||||
import { getProjectModels } from '@/api/modelApi';
|
||||
import { ReportResponse, ProjectModelsResponse } from '@/types';
|
||||
|
||||
interface UsePollingTrainingModelReportProps {
|
||||
projectId: number;
|
||||
@ -16,7 +17,7 @@ export default function usePollingTrainingModelReport({
|
||||
enabled,
|
||||
onTrainingEnd,
|
||||
}: UsePollingTrainingModelReportProps) {
|
||||
const query = useQuery<ReportResponse[]>({
|
||||
const reportQuery = useQuery<ReportResponse[]>({
|
||||
queryKey: ['modelReports', projectId, modelId],
|
||||
queryFn: () => getTrainingModelReport(projectId, modelId),
|
||||
enabled,
|
||||
@ -30,14 +31,30 @@ export default function usePollingTrainingModelReport({
|
||||
},
|
||||
});
|
||||
|
||||
const modelQuery = useQuery<ProjectModelsResponse>({
|
||||
queryKey: ['projectModels', projectId],
|
||||
queryFn: () => getProjectModels(projectId),
|
||||
enabled,
|
||||
refetchInterval: 2000,
|
||||
});
|
||||
|
||||
const prevIsTrainRef = useRef<boolean | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (query.data && query.data.length > 0) {
|
||||
const lastReport = query.data[query.data.length - 1];
|
||||
if (lastReport.epoch >= lastReport.totalEpochs) {
|
||||
if (modelQuery.data) {
|
||||
const model = modelQuery.data.find((m) => m.id === modelId);
|
||||
if (model) {
|
||||
const currentIsTrain = model.isTrain;
|
||||
const prevIsTrain = prevIsTrainRef.current;
|
||||
|
||||
if (prevIsTrain === true && currentIsTrain === false) {
|
||||
onTrainingEnd();
|
||||
}
|
||||
}
|
||||
}, [query.data, onTrainingEnd]);
|
||||
|
||||
return query;
|
||||
prevIsTrainRef.current = currentIsTrain;
|
||||
}
|
||||
}
|
||||
}, [modelQuery.data, modelId, onTrainingEnd]);
|
||||
|
||||
return { reportData: reportQuery.data, modelData: modelQuery.data };
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user