Feat: 학습 중단 등 로직 추가

This commit is contained in:
정현조 2024-09-25 07:31:33 +09:00
parent c40cd0741d
commit fbf4c7a6a5
6 changed files with 107 additions and 48 deletions

View File

@ -1,54 +1,74 @@
'use client';
import { TrendingUp } from 'lucide-react';
import { CartesianGrid, Line, LineChart, XAxis } from 'recharts';
import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from '@/components/ui/card';
import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart';
import { CartesianGrid, Line, LineChart, XAxis, YAxis, Tooltip, Legend } from 'recharts';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { ChartConfig, ChartContainer } from '@/components/ui/chart';
interface MetricData {
epoch: string;
loss1: number;
loss2: number;
loss3: number;
fitness: number;
loss1?: number;
loss2?: number;
loss3?: number;
fitness?: number;
}
interface ModelLineChartProps {
data: MetricData[];
currentEpoch?: number;
totalEpochs?: number;
remainingTime?: number;
}
const chartConfig = {
loss1: {
label: 'Loss 1',
color: '#FF6347', // 토마토색
color: '#FF6347',
},
loss2: {
label: 'Loss 2',
color: '#1E90FF', // 다저블루색
color: '#1E90FF',
},
loss3: {
label: 'Loss 3',
color: '#32CD32', // 라임색
color: '#32CD32',
},
fitness: {
label: 'Fitness',
color: '#FFD700', // 골드색
color: '#FFD700',
},
} satisfies ChartConfig;
export default function ModelLineChart({ data }: ModelLineChartProps) {
export default function ModelLineChart({ data, currentEpoch, totalEpochs, remainingTime }: ModelLineChartProps) {
const emptyData = Array.from({ length: totalEpochs || 0 }, (_, i) => ({
epoch: (i + 1).toString(),
loss1: null,
loss2: null,
loss3: null,
fitness: null,
}));
const filledData = emptyData.map((item, index) => ({
...item,
...(data[index] || {}),
}));
return (
<Card>
<CardHeader>
<CardTitle>Model Training Metrics</CardTitle>
<CardDescription>Loss and Fitness over Epochs</CardDescription>
</CardHeader>
<CardContent>
{currentEpoch !== undefined && totalEpochs !== undefined && remainingTime !== undefined && (
<div className="mb-4 flex justify-between">
<p> : {currentEpoch}</p>
<p> : {totalEpochs}</p>
<p> : {remainingTime}</p>
</div>
)}
<ChartContainer config={chartConfig}>
<LineChart
accessibilityLayer
data={data}
data={filledData}
margin={{
left: 12,
right: 12,
@ -62,10 +82,9 @@ export default function ModelLineChart({ data }: ModelLineChartProps) {
tickMargin={8}
tickFormatter={(value) => `Epoch ${value}`}
/>
<ChartTooltip
cursor={false}
content={<ChartTooltipContent />}
/>
<YAxis />
<Tooltip />
<Legend />
<Line
dataKey="loss1"
type="monotone"
@ -97,18 +116,6 @@ export default function ModelLineChart({ data }: ModelLineChartProps) {
</LineChart>
</ChartContainer>
</CardContent>
<CardFooter>
<div className="flex w-full items-start gap-2 text-sm">
<div className="grid gap-2">
<div className="flex items-center gap-2 font-medium leading-none">
Trending up by 5.2% this epoch <TrendingUp className="h-4 w-4" />
</div>
<div className="text-muted-foreground flex items-center gap-2 leading-none">
Showing training loss and fitness for the current model
</div>
</div>
</div>
</CardFooter>
</Card>
);
}

View File

@ -1,5 +1,7 @@
import { useEffect, useMemo } from 'react';
import ModelLineChart from './ModelLineChart';
import usePollingModelReportsQuery from '@/queries/models/usePollingModelReportsQuery';
import useModelStore from '@/stores/useModelStore';
interface TrainingGraphProps {
projectId: number | null;
@ -7,7 +9,36 @@ interface TrainingGraphProps {
}
export default function TrainingGraph({ projectId, selectedModel }: TrainingGraphProps) {
const { data: trainingDataList } = usePollingModelReportsQuery(projectId as number, selectedModel ?? 0);
const { isTrainingByProject, setIsTraining, resetTrainingData } = useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
setIsTraining: state.setIsTraining,
resetTrainingData: state.resetTrainingData,
}));
const isTraining = isTrainingByProject[projectId?.toString() || ''] || false;
const { data: trainingDataList } = usePollingModelReportsQuery(
projectId as number,
selectedModel ?? 0,
isTraining && !!projectId && !!selectedModel
);
const latestData = useMemo(() => {
return (
trainingDataList?.[trainingDataList.length - 1] || {
epoch: 0,
totalEpochs: 0,
leftSecond: 0,
}
);
}, [trainingDataList]);
useEffect(() => {
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
setIsTraining(projectId?.toString() || '', false);
resetTrainingData(projectId?.toString() || '');
}
}, [latestData.epoch, latestData.totalEpochs, setIsTraining, resetTrainingData, projectId]);
return (
<ModelLineChart
@ -20,6 +51,9 @@ export default function TrainingGraph({ projectId, selectedModel }: TrainingGrap
fitness: data.fitness,
})) || []
}
currentEpoch={latestData.epoch}
totalEpochs={latestData.totalEpochs}
remainingTime={latestData.leftSecond}
/>
);
}

View File

@ -10,6 +10,7 @@ interface TrainingSettingsProps {
selectedModel: number | null;
setSelectedModel: (model: number | null) => void;
handleTrainingStart: (trainData: ModelTrainRequest) => void;
handleTrainingStop: () => void;
isTraining: boolean;
}
@ -18,6 +19,7 @@ export default function TrainingSettings({
selectedModel,
setSelectedModel,
handleTrainingStart,
handleTrainingStop,
isTraining,
}: TrainingSettingsProps) {
const { data: models } = useProjectModelsQuery(projectId ?? 0);
@ -30,7 +32,9 @@ export default function TrainingSettings({
const [lrf, setLrf] = useState<number>(0.001);
const handleSubmit = () => {
if (selectedModel !== null) {
if (isTraining) {
handleTrainingStop();
} else if (selectedModel !== null) {
const trainData: ModelTrainRequest = {
modelId: selectedModel,
ratio,
@ -127,7 +131,7 @@ export default function TrainingSettings({
onClick={handleSubmit}
disabled={!selectedModel || isTraining}
>
{isTraining ? '학습 중단' : '학습 시작'}
</Button>
</fieldset>
);

View File

@ -10,13 +10,13 @@ interface TrainingTabProps {
export default function TrainingTab({ projectId }: TrainingTabProps) {
const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null;
const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, trainingDataByProject } =
const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, resetTrainingData } =
useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
setIsTraining: state.setIsTraining,
selectedModelByProject: state.selectedModelByProject,
setSelectedModel: state.setSelectedModel,
trainingDataByProject: state.trainingDataByProject,
resetTrainingData: state.resetTrainingData,
}));
const isTraining = isTrainingByProject[numericProjectId?.toString() || ''] || false;
@ -31,7 +31,12 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
}
};
const trainingData = trainingDataByProject[numericProjectId?.toString() || ''];
const handleTrainingStop = () => {
if (isTraining) {
setIsTraining(numericProjectId?.toString() || '', false);
resetTrainingData(numericProjectId?.toString() || '');
}
};
return (
<div className="grid gap-8 md:grid-cols-2">
@ -40,6 +45,7 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
selectedModel={selectedModel}
setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)}
handleTrainingStart={handleTrainingStart}
handleTrainingStop={handleTrainingStop}
isTraining={isTraining}
/>
@ -47,14 +53,6 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
projectId={numericProjectId}
selectedModel={selectedModel}
/>
{trainingData && (
<div className="mt-4">
<p> : {trainingData[trainingData.length - 1]?.epoch}</p>
<p> : {trainingData[trainingData.length - 1]?.totalEpochs}</p>
<p> : {trainingData[trainingData.length - 1]?.leftSecond}</p>
</div>
)}
</div>
);
}

View File

@ -2,11 +2,11 @@ import { useQuery } from '@tanstack/react-query';
import { getModelReports } from '@/api/modelApi';
import { ReportResponse } from '@/types';
export default function usePollingModelReportsQuery(projectId: number, modelId: number) {
export default function usePollingModelReportsQuery(projectId: number, modelId: number, enabled: boolean) {
return useQuery<ReportResponse[]>({
queryKey: ['pollingModelReports', projectId, modelId],
queryFn: () => getModelReports(projectId, modelId),
refetchInterval: 5000,
enabled: !!projectId && !!modelId,
enabled,
});
}

View File

@ -8,6 +8,7 @@ interface ModelStoreState {
setIsTraining: (projectId: string, status: boolean) => void;
saveTrainingData: (projectId: string, data: ReportResponse[]) => void;
setSelectedModel: (projectId: string, modelId: number | null) => void;
resetTrainingData: (projectId: string) => void;
}
const useModelStore = create<ModelStoreState>((set) => ({
@ -35,6 +36,21 @@ const useModelStore = create<ModelStoreState>((set) => ({
[projectId]: modelId,
},
})),
resetTrainingData: (projectId) =>
set((state) => ({
trainingDataByProject: {
...state.trainingDataByProject,
[projectId]: [],
},
selectedModelByProject: {
...state.selectedModelByProject,
[projectId]: null,
},
isTrainingByProject: {
...state.isTrainingByProject,
[projectId]: false,
},
})),
}));
export default useModelStore;