Feat: 학습 중단 등 로직 추가
This commit is contained in:
parent
c40cd0741d
commit
fbf4c7a6a5
@ -1,54 +1,74 @@
|
||||
'use client';
|
||||
|
||||
import { TrendingUp } from 'lucide-react';
|
||||
import { CartesianGrid, Line, LineChart, XAxis } from 'recharts';
|
||||
|
||||
import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart';
|
||||
import { CartesianGrid, Line, LineChart, XAxis, YAxis, Tooltip, Legend } from 'recharts';
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { ChartConfig, ChartContainer } from '@/components/ui/chart';
|
||||
|
||||
interface MetricData {
|
||||
epoch: string;
|
||||
loss1: number;
|
||||
loss2: number;
|
||||
loss3: number;
|
||||
fitness: number;
|
||||
loss1?: number;
|
||||
loss2?: number;
|
||||
loss3?: number;
|
||||
fitness?: number;
|
||||
}
|
||||
|
||||
interface ModelLineChartProps {
|
||||
data: MetricData[];
|
||||
currentEpoch?: number;
|
||||
totalEpochs?: number;
|
||||
remainingTime?: number;
|
||||
}
|
||||
|
||||
const chartConfig = {
|
||||
loss1: {
|
||||
label: 'Loss 1',
|
||||
color: '#FF6347', // 토마토색
|
||||
color: '#FF6347',
|
||||
},
|
||||
loss2: {
|
||||
label: 'Loss 2',
|
||||
color: '#1E90FF', // 다저블루색
|
||||
color: '#1E90FF',
|
||||
},
|
||||
loss3: {
|
||||
label: 'Loss 3',
|
||||
color: '#32CD32', // 라임색
|
||||
color: '#32CD32',
|
||||
},
|
||||
fitness: {
|
||||
label: 'Fitness',
|
||||
color: '#FFD700', // 골드색
|
||||
color: '#FFD700',
|
||||
},
|
||||
} satisfies ChartConfig;
|
||||
|
||||
export default function ModelLineChart({ data }: ModelLineChartProps) {
|
||||
export default function ModelLineChart({ data, currentEpoch, totalEpochs, remainingTime }: ModelLineChartProps) {
|
||||
const emptyData = Array.from({ length: totalEpochs || 0 }, (_, i) => ({
|
||||
epoch: (i + 1).toString(),
|
||||
loss1: null,
|
||||
loss2: null,
|
||||
loss3: null,
|
||||
fitness: null,
|
||||
}));
|
||||
|
||||
const filledData = emptyData.map((item, index) => ({
|
||||
...item,
|
||||
...(data[index] || {}),
|
||||
}));
|
||||
|
||||
return (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Model Training Metrics</CardTitle>
|
||||
<CardDescription>Loss and Fitness over Epochs</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{currentEpoch !== undefined && totalEpochs !== undefined && remainingTime !== undefined && (
|
||||
<div className="mb-4 flex justify-between">
|
||||
<p>현재 에포크: {currentEpoch}</p>
|
||||
<p>총 에포크: {totalEpochs}</p>
|
||||
<p>예상 남은시간: {remainingTime}</p>
|
||||
</div>
|
||||
)}
|
||||
<ChartContainer config={chartConfig}>
|
||||
<LineChart
|
||||
accessibilityLayer
|
||||
data={data}
|
||||
data={filledData}
|
||||
margin={{
|
||||
left: 12,
|
||||
right: 12,
|
||||
@ -62,10 +82,9 @@ export default function ModelLineChart({ data }: ModelLineChartProps) {
|
||||
tickMargin={8}
|
||||
tickFormatter={(value) => `Epoch ${value}`}
|
||||
/>
|
||||
<ChartTooltip
|
||||
cursor={false}
|
||||
content={<ChartTooltipContent />}
|
||||
/>
|
||||
<YAxis />
|
||||
<Tooltip />
|
||||
<Legend />
|
||||
<Line
|
||||
dataKey="loss1"
|
||||
type="monotone"
|
||||
@ -97,18 +116,6 @@ export default function ModelLineChart({ data }: ModelLineChartProps) {
|
||||
</LineChart>
|
||||
</ChartContainer>
|
||||
</CardContent>
|
||||
<CardFooter>
|
||||
<div className="flex w-full items-start gap-2 text-sm">
|
||||
<div className="grid gap-2">
|
||||
<div className="flex items-center gap-2 font-medium leading-none">
|
||||
Trending up by 5.2% this epoch <TrendingUp className="h-4 w-4" />
|
||||
</div>
|
||||
<div className="text-muted-foreground flex items-center gap-2 leading-none">
|
||||
Showing training loss and fitness for the current model
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import ModelLineChart from './ModelLineChart';
|
||||
import usePollingModelReportsQuery from '@/queries/models/usePollingModelReportsQuery';
|
||||
import useModelStore from '@/stores/useModelStore';
|
||||
|
||||
interface TrainingGraphProps {
|
||||
projectId: number | null;
|
||||
@ -7,7 +9,36 @@ interface TrainingGraphProps {
|
||||
}
|
||||
|
||||
export default function TrainingGraph({ projectId, selectedModel }: TrainingGraphProps) {
|
||||
const { data: trainingDataList } = usePollingModelReportsQuery(projectId as number, selectedModel ?? 0);
|
||||
const { isTrainingByProject, setIsTraining, resetTrainingData } = useModelStore((state) => ({
|
||||
isTrainingByProject: state.isTrainingByProject,
|
||||
setIsTraining: state.setIsTraining,
|
||||
resetTrainingData: state.resetTrainingData,
|
||||
}));
|
||||
|
||||
const isTraining = isTrainingByProject[projectId?.toString() || ''] || false;
|
||||
|
||||
const { data: trainingDataList } = usePollingModelReportsQuery(
|
||||
projectId as number,
|
||||
selectedModel ?? 0,
|
||||
isTraining && !!projectId && !!selectedModel
|
||||
);
|
||||
|
||||
const latestData = useMemo(() => {
|
||||
return (
|
||||
trainingDataList?.[trainingDataList.length - 1] || {
|
||||
epoch: 0,
|
||||
totalEpochs: 0,
|
||||
leftSecond: 0,
|
||||
}
|
||||
);
|
||||
}, [trainingDataList]);
|
||||
|
||||
useEffect(() => {
|
||||
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
|
||||
setIsTraining(projectId?.toString() || '', false);
|
||||
resetTrainingData(projectId?.toString() || '');
|
||||
}
|
||||
}, [latestData.epoch, latestData.totalEpochs, setIsTraining, resetTrainingData, projectId]);
|
||||
|
||||
return (
|
||||
<ModelLineChart
|
||||
@ -20,6 +51,9 @@ export default function TrainingGraph({ projectId, selectedModel }: TrainingGrap
|
||||
fitness: data.fitness,
|
||||
})) || []
|
||||
}
|
||||
currentEpoch={latestData.epoch}
|
||||
totalEpochs={latestData.totalEpochs}
|
||||
remainingTime={latestData.leftSecond}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ interface TrainingSettingsProps {
|
||||
selectedModel: number | null;
|
||||
setSelectedModel: (model: number | null) => void;
|
||||
handleTrainingStart: (trainData: ModelTrainRequest) => void;
|
||||
handleTrainingStop: () => void;
|
||||
isTraining: boolean;
|
||||
}
|
||||
|
||||
@ -18,6 +19,7 @@ export default function TrainingSettings({
|
||||
selectedModel,
|
||||
setSelectedModel,
|
||||
handleTrainingStart,
|
||||
handleTrainingStop,
|
||||
isTraining,
|
||||
}: TrainingSettingsProps) {
|
||||
const { data: models } = useProjectModelsQuery(projectId ?? 0);
|
||||
@ -30,7 +32,9 @@ export default function TrainingSettings({
|
||||
const [lrf, setLrf] = useState<number>(0.001);
|
||||
|
||||
const handleSubmit = () => {
|
||||
if (selectedModel !== null) {
|
||||
if (isTraining) {
|
||||
handleTrainingStop();
|
||||
} else if (selectedModel !== null) {
|
||||
const trainData: ModelTrainRequest = {
|
||||
modelId: selectedModel,
|
||||
ratio,
|
||||
@ -127,7 +131,7 @@ export default function TrainingSettings({
|
||||
onClick={handleSubmit}
|
||||
disabled={!selectedModel || isTraining}
|
||||
>
|
||||
학습 시작
|
||||
{isTraining ? '학습 중단' : '학습 시작'}
|
||||
</Button>
|
||||
</fieldset>
|
||||
);
|
||||
|
@ -10,13 +10,13 @@ interface TrainingTabProps {
|
||||
|
||||
export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||
const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null;
|
||||
const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, trainingDataByProject } =
|
||||
const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, resetTrainingData } =
|
||||
useModelStore((state) => ({
|
||||
isTrainingByProject: state.isTrainingByProject,
|
||||
setIsTraining: state.setIsTraining,
|
||||
selectedModelByProject: state.selectedModelByProject,
|
||||
setSelectedModel: state.setSelectedModel,
|
||||
trainingDataByProject: state.trainingDataByProject,
|
||||
resetTrainingData: state.resetTrainingData,
|
||||
}));
|
||||
|
||||
const isTraining = isTrainingByProject[numericProjectId?.toString() || ''] || false;
|
||||
@ -31,7 +31,12 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||
}
|
||||
};
|
||||
|
||||
const trainingData = trainingDataByProject[numericProjectId?.toString() || ''];
|
||||
const handleTrainingStop = () => {
|
||||
if (isTraining) {
|
||||
setIsTraining(numericProjectId?.toString() || '', false);
|
||||
resetTrainingData(numericProjectId?.toString() || '');
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="grid gap-8 md:grid-cols-2">
|
||||
@ -40,6 +45,7 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||
selectedModel={selectedModel}
|
||||
setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)}
|
||||
handleTrainingStart={handleTrainingStart}
|
||||
handleTrainingStop={handleTrainingStop}
|
||||
isTraining={isTraining}
|
||||
/>
|
||||
|
||||
@ -47,14 +53,6 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||
projectId={numericProjectId}
|
||||
selectedModel={selectedModel}
|
||||
/>
|
||||
|
||||
{trainingData && (
|
||||
<div className="mt-4">
|
||||
<p>현재 에포크: {trainingData[trainingData.length - 1]?.epoch}</p>
|
||||
<p>총 에포크: {trainingData[trainingData.length - 1]?.totalEpochs}</p>
|
||||
<p>예상 남은시간: {trainingData[trainingData.length - 1]?.leftSecond}</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
@ -2,11 +2,11 @@ import { useQuery } from '@tanstack/react-query';
|
||||
import { getModelReports } from '@/api/modelApi';
|
||||
import { ReportResponse } from '@/types';
|
||||
|
||||
export default function usePollingModelReportsQuery(projectId: number, modelId: number) {
|
||||
export default function usePollingModelReportsQuery(projectId: number, modelId: number, enabled: boolean) {
|
||||
return useQuery<ReportResponse[]>({
|
||||
queryKey: ['pollingModelReports', projectId, modelId],
|
||||
queryFn: () => getModelReports(projectId, modelId),
|
||||
refetchInterval: 5000,
|
||||
enabled: !!projectId && !!modelId,
|
||||
enabled,
|
||||
});
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ interface ModelStoreState {
|
||||
setIsTraining: (projectId: string, status: boolean) => void;
|
||||
saveTrainingData: (projectId: string, data: ReportResponse[]) => void;
|
||||
setSelectedModel: (projectId: string, modelId: number | null) => void;
|
||||
resetTrainingData: (projectId: string) => void;
|
||||
}
|
||||
|
||||
const useModelStore = create<ModelStoreState>((set) => ({
|
||||
@ -35,6 +36,21 @@ const useModelStore = create<ModelStoreState>((set) => ({
|
||||
[projectId]: modelId,
|
||||
},
|
||||
})),
|
||||
resetTrainingData: (projectId) =>
|
||||
set((state) => ({
|
||||
trainingDataByProject: {
|
||||
...state.trainingDataByProject,
|
||||
[projectId]: [],
|
||||
},
|
||||
selectedModelByProject: {
|
||||
...state.selectedModelByProject,
|
||||
[projectId]: null,
|
||||
},
|
||||
isTrainingByProject: {
|
||||
...state.isTrainingByProject,
|
||||
[projectId]: false,
|
||||
},
|
||||
})),
|
||||
}));
|
||||
|
||||
export default useModelStore;
|
||||
|
Loading…
Reference in New Issue
Block a user