Refactor: mocker를 통해 연결한 것 기반 오류 수정

This commit is contained in:
정현조 2024-09-25 17:38:20 +09:00
parent 4e844dd367
commit dc78d000e7
8 changed files with 110 additions and 98 deletions

View File

@ -1,11 +1,11 @@
import { Label } from '@/components/ui/label';
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
import { useState } from 'react';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import useModelReportsQuery from '@/queries/models/useModelReportsQuery';
import useModelResultsQuery from '@/queries/models/useModelResultsQuery';
import ModelBarChart from './ModelBarChart';
import ModelLineChart from './ModelLineChart';
import { useState } from 'react';
interface EvaluationTabProps {
projectId: number | null;
@ -13,7 +13,6 @@ interface EvaluationTabProps {
export default function EvaluationTab({ projectId }: EvaluationTabProps) {
const [selectedModel, setSelectedModel] = useState<number | null>(null);
const { data: models } = useProjectModelsQuery(projectId ?? 0);
return (
@ -70,47 +69,62 @@ function ModelEvaluation({ projectId, selectedModel }: ModelEvaluationProps) {
const { data: reportData } = useModelReportsQuery(projectId, selectedModel);
const { data: resultData } = useModelResultsQuery(selectedModel);
if (!reportData || !resultData) {
return null;
}
if (!reportData || !resultData) return null;
const trainingInfoRow = (
<div className="flex justify-between rounded-lg bg-gray-100 p-4">
<div className="flex-1 text-center">
<strong>Epochs</strong>
<p>{resultData[0]?.epochs}</p>
</div>
<div className="flex-1 text-center">
<strong>Batch Size</strong>
<p>{resultData[0]?.batch}</p>
</div>
<div className="flex-1 text-center">
<strong>Learning Rate (Start)</strong>
<p>{resultData[0]?.lr0}</p>
</div>
<div className="flex-1 text-center">
<strong>Learning Rate (End)</strong>
<p>{resultData[0]?.lrf}</p>
</div>
<div className="flex-1 text-center">
<strong>Optimizer</strong>
<p>{resultData[0]?.optimizer}</p>
</div>
</div>
);
return (
<div className="grid gap-8 md:grid-cols-2">
<div className="flex flex-col gap-6">
<ModelBarChart
data={[
{ name: 'precision', value: resultData[0]?.precision, fill: 'var(--color-precision)' },
{ name: 'recall', value: resultData[0]?.recall, fill: 'var(--color-recall)' },
{ name: 'mAP50', value: resultData[0]?.map50, fill: 'var(--color-map50)' },
{ name: 'mAP50_95', value: resultData[0]?.map5095, fill: 'var(--color-map50-95)' },
{ name: 'fitness', value: resultData[0]?.fitness, fill: 'var(--color-fitness)' },
]}
/>
<div>
{trainingInfoRow} {/* 학습 정보 표시 */}
<div className="mt-4 grid h-[400px] gap-8 md:grid-cols-2">
{' '}
{/* grid와 높이 설정 */}
<div className="flex h-full flex-col gap-6">
{' '}
{/* 차트의 높이를 100%로 맞춤 */}
<ModelBarChart
data={[
{ name: 'precision', value: resultData[0]?.precision, fill: 'var(--color-precision)' },
{ name: 'recall', value: resultData[0]?.recall, fill: 'var(--color-recall)' },
{ name: 'mAP50', value: resultData[0]?.map50, fill: 'var(--color-map50)' },
{ name: 'mAP50_95', value: resultData[0]?.map5095, fill: 'var(--color-map50-95)' },
{ name: 'fitness', value: resultData[0]?.fitness, fill: 'var(--color-fitness)' },
]}
className="h-full"
/>
</div>
<div className="flex h-full flex-col gap-6">
{' '}
{/* 차트의 높이를 100%로 맞춤 */}
<ModelLineChart
data={reportData}
className="h-full"
/>
</div>
</div>
<div className="flex flex-col gap-6">
<ModelLineChart
data={reportData.map((report) => ({
epoch: report.epoch.toString(),
boxLoss: report.boxLoss,
classLoss: report.clsLoss,
dflLoss: report.dflLoss,
fitness: report.fitness,
}))}
/>
</div>
{/* <div className="flex flex-col justify-center">
<LabelingPreview />
</div> */}
</div>
);
}
// function LabelingPreview() {
// return (
// <div className="flex items-center justify-center rounded-lg border bg-white p-4">
// <p>레이블링 프리뷰</p>
// </div>
// );
// }

View File

@ -6,9 +6,10 @@ interface InputWithLabelProps {
placeholder: string;
value: number;
onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
disabled?: boolean;
}
export default function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) {
export default function InputWithLabel({ label, id, placeholder, value, disabled, onChange }: InputWithLabelProps) {
return (
<div className="grid gap-3">
<Label htmlFor={id}>{label}</Label>
@ -18,6 +19,7 @@ export default function InputWithLabel({ label, id, placeholder, value, onChange
placeholder={placeholder}
value={value}
onChange={onChange}
disabled={disabled}
/>
</div>
);

View File

@ -1,9 +1,8 @@
'use client';
import { TrendingUp } from 'lucide-react';
import { Bar, BarChart, CartesianGrid, Rectangle, XAxis } from 'recharts';
import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from '@/components/ui/card';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart';
interface MetricData {
@ -14,10 +13,9 @@ interface MetricData {
interface ModelBarChartProps {
data: MetricData[];
className?: string;
}
export const description = 'A bar chart with an active bar';
const chartConfig = {
precision: {
label: 'Precision',
@ -41,9 +39,9 @@ const chartConfig = {
},
} satisfies ChartConfig;
export default function ModelBarChart({ data }: ModelBarChartProps) {
export default function ModelBarChart({ data, className }: ModelBarChartProps) {
return (
<Card>
<Card className={className}>
<CardHeader>
<CardTitle>Model Metrics</CardTitle>
<CardDescription>Performance metrics of the model</CardDescription>
@ -86,12 +84,6 @@ export default function ModelBarChart({ data }: ModelBarChartProps) {
</BarChart>
</ChartContainer>
</CardContent>
<CardFooter className="flex-col items-start gap-2 text-sm">
<div className="flex gap-2 font-medium leading-none">
Model metrics are trending well <TrendingUp className="h-4 w-4" />
</div>
<div className="text-muted-foreground leading-none">Showing current performance metrics</div>
</CardFooter>
</Card>
);
}

View File

@ -3,20 +3,11 @@
import { CartesianGrid, Line, LineChart, XAxis, YAxis, Tooltip, Legend } from 'recharts';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { ChartConfig, ChartContainer } from '@/components/ui/chart';
interface MetricData {
epoch: string;
boxLoss?: number;
classLoss?: number;
dflLoss?: number;
fitness?: number;
}
import { ReportResponse } from '@/types';
interface ModelLineChartProps {
data: MetricData[];
currentEpoch?: number;
totalEpochs?: number;
remainingTime?: number;
data: ReportResponse[];
className?: string;
}
const chartConfig = {
@ -38,8 +29,11 @@ const chartConfig = {
},
} satisfies ChartConfig;
export default function ModelLineChart({ data, currentEpoch, totalEpochs, remainingTime }: ModelLineChartProps) {
const emptyData = Array.from({ length: totalEpochs || 0 }, (_, i) => ({
export default function ModelLineChart({ data, className }: ModelLineChartProps) {
const latestData = data.length > 0 ? data[data.length - 1] : undefined;
const totalEpochs = latestData?.totalEpochs || 0;
const emptyData = Array.from({ length: totalEpochs }, (_, i) => ({
epoch: (i + 1).toString(),
boxLoss: null,
classLoss: null,
@ -53,16 +47,16 @@ export default function ModelLineChart({ data, currentEpoch, totalEpochs, remain
}));
return (
<Card>
<Card className={className}>
<CardHeader>
<CardTitle>Model Training Metrics</CardTitle>
</CardHeader>
<CardContent>
{currentEpoch !== undefined && totalEpochs !== undefined && remainingTime !== undefined && (
{latestData && latestData.totalEpochs !== Number(latestData.epoch) && (
<div className="mb-4 flex justify-between">
<p> : {currentEpoch}</p>
<p> : {totalEpochs}</p>
<p> : {remainingTime}</p>
<p> : {latestData.epoch}</p>
<p> : {latestData.totalEpochs}</p>
<p> : {latestData.leftSecond}</p>
</div>
)}
<ChartContainer config={chartConfig}>

View File

@ -12,16 +12,27 @@ interface SelectWithLabelProps {
options: SelectWithLabelOption[];
placeholder: string;
value: string;
disabled?: boolean;
onChange: (value: string) => void;
}
export default function SelectWithLabel({ label, id, options, placeholder, value, onChange }: SelectWithLabelProps) {
export default function SelectWithLabel({
label,
id,
options,
placeholder,
value,
disabled,
onChange,
}: SelectWithLabelProps) {
return (
<div className="grid gap-3">
<Label htmlFor={id}>{label}</Label>
<Select
value={value}
onValueChange={onChange}
disabled={disabled}
>
<SelectTrigger id={id}>
<SelectValue placeholder={placeholder} />

View File

@ -6,9 +6,10 @@ import useModelStore from '@/stores/useModelStore';
interface TrainingGraphProps {
projectId: number | null;
selectedModel: number | null;
className?: string;
}
export default function TrainingGraph({ projectId, selectedModel }: TrainingGraphProps) {
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
const { isTrainingByProject, setIsTraining, saveTrainingData, resetTrainingData, trainingDataByProject } =
useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
@ -48,6 +49,7 @@ export default function TrainingGraph({ projectId, selectedModel }: TrainingGrap
useEffect(() => {
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
alert('학습이 완료되었습니다!');
setIsTraining(projectId?.toString() || '', false);
resetTrainingData(projectId?.toString() || '');
}
@ -55,18 +57,8 @@ export default function TrainingGraph({ projectId, selectedModel }: TrainingGrap
return (
<ModelLineChart
data={
trainingDataList?.map((data) => ({
epoch: data.epoch.toString(),
boxLoss: data.boxLoss,
classLoss: data.clsLoss,
dflLoss: data.dflLoss,
fitness: data.fitness,
})) || []
}
currentEpoch={latestData.epoch}
totalEpochs={latestData.totalEpochs}
remainingTime={latestData.leftSecond}
data={trainingDataList}
className={className}
/>
);
}

View File

@ -2,8 +2,10 @@ import SelectWithLabel from './SelectWithLabel';
import InputWithLabel from './InputWithLabel';
import { Button } from '@/components/ui/button';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import useModelStore from '@/stores/useModelStore';
import { ModelTrainRequest } from '@/types';
import { useState } from 'react';
import { cn } from '@/lib/utils';
interface TrainingSettingsProps {
projectId: number | null;
@ -11,7 +13,7 @@ interface TrainingSettingsProps {
setSelectedModel: (model: number | null) => void;
handleTrainingStart: (trainData: ModelTrainRequest) => void;
handleTrainingStop: () => void;
isTraining: boolean;
className?: string;
}
export default function TrainingSettings({
@ -20,10 +22,12 @@ export default function TrainingSettings({
setSelectedModel,
handleTrainingStart,
handleTrainingStop,
isTraining,
className,
}: TrainingSettingsProps) {
const { data: models } = useProjectModelsQuery(projectId ?? 0);
const isTraining = useModelStore((state) => state.isTrainingByProject[projectId?.toString() || ''] || false);
const [ratio, setRatio] = useState<number>(0.8);
const [epochs, setEpochs] = useState<number>(50);
const [batchSize, setBatchSize] = useState<number>(32);
@ -49,12 +53,9 @@ export default function TrainingSettings({
};
return (
<fieldset
className="grid gap-6 rounded-lg border p-4"
disabled={isTraining}
>
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
{' '}
<legend className="-ml-1 px-1 text-sm font-medium"> </legend>
<div className="grid gap-3">
<SelectWithLabel
label="모델 선택"
@ -68,9 +69,9 @@ export default function TrainingSettings({
placeholder="모델을 선택하세요"
value={selectedModel ? selectedModel.toString() : ''}
onChange={(value) => setSelectedModel(parseInt(value, 10))}
disabled={isTraining}
/>
</div>
<div className="grid grid-cols-2 gap-4">
<InputWithLabel
label="훈련/검증 비율"
@ -78,6 +79,7 @@ export default function TrainingSettings({
id="ratio"
value={ratio}
onChange={(e) => setRatio(parseFloat(e.target.value))}
disabled={isTraining}
/>
<InputWithLabel
label="에포크 수"
@ -85,6 +87,7 @@ export default function TrainingSettings({
id="epochs"
value={epochs}
onChange={(e) => setEpochs(parseInt(e.target.value, 10))}
disabled={isTraining}
/>
<InputWithLabel
label="Batch 크기"
@ -92,6 +95,7 @@ export default function TrainingSettings({
id="batch"
value={batchSize}
onChange={(e) => setBatchSize(parseInt(e.target.value, 10))}
disabled={isTraining}
/>
<SelectWithLabel
label="옵티마이저"
@ -108,6 +112,7 @@ export default function TrainingSettings({
placeholder="옵티마이저 선택"
value={optimizer}
onChange={(value) => setOptimizer(value as 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP')}
disabled={isTraining} // 학습 중일 때 옵티마이저 선택 비활성화
/>
<InputWithLabel
label="학습률(LR0)"
@ -115,6 +120,7 @@ export default function TrainingSettings({
id="lr0"
value={lr0}
onChange={(e) => setLr0(parseFloat(e.target.value))}
disabled={isTraining}
/>
<InputWithLabel
label="최종 학습률(LRF)"
@ -122,14 +128,14 @@ export default function TrainingSettings({
id="lrf"
value={lrf}
onChange={(e) => setLrf(parseFloat(e.target.value))}
disabled={isTraining}
/>
</div>
<Button
variant="outlinePrimary"
size="lg"
onClick={handleSubmit}
disabled={!selectedModel || isTraining}
disabled={!selectedModel}
>
{isTraining ? '학습 중단' : '학습 시작'}
</Button>

View File

@ -39,19 +39,20 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
};
return (
<div className="grid gap-8 md:grid-cols-2">
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
<TrainingSettings
projectId={numericProjectId}
selectedModel={selectedModel}
setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)}
handleTrainingStart={handleTrainingStart}
handleTrainingStop={handleTrainingStop}
isTraining={isTraining}
className="h-full"
/>
<TrainingGraph
projectId={numericProjectId}
selectedModel={selectedModel}
className="h-full"
/>
</div>
);