Refactor: mocker를 통해 연결한 것 기반 오류 수정

This commit is contained in:
정현조 2024-09-25 17:38:20 +09:00
parent 4e844dd367
commit dc78d000e7
8 changed files with 110 additions and 98 deletions

View File

@ -1,11 +1,11 @@
import { Label } from '@/components/ui/label'; import { Label } from '@/components/ui/label';
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
import { useState } from 'react';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery'; import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import useModelReportsQuery from '@/queries/models/useModelReportsQuery'; import useModelReportsQuery from '@/queries/models/useModelReportsQuery';
import useModelResultsQuery from '@/queries/models/useModelResultsQuery'; import useModelResultsQuery from '@/queries/models/useModelResultsQuery';
import ModelBarChart from './ModelBarChart'; import ModelBarChart from './ModelBarChart';
import ModelLineChart from './ModelLineChart'; import ModelLineChart from './ModelLineChart';
import { useState } from 'react';
interface EvaluationTabProps { interface EvaluationTabProps {
projectId: number | null; projectId: number | null;
@ -13,7 +13,6 @@ interface EvaluationTabProps {
export default function EvaluationTab({ projectId }: EvaluationTabProps) { export default function EvaluationTab({ projectId }: EvaluationTabProps) {
const [selectedModel, setSelectedModel] = useState<number | null>(null); const [selectedModel, setSelectedModel] = useState<number | null>(null);
const { data: models } = useProjectModelsQuery(projectId ?? 0); const { data: models } = useProjectModelsQuery(projectId ?? 0);
return ( return (
@ -70,47 +69,62 @@ function ModelEvaluation({ projectId, selectedModel }: ModelEvaluationProps) {
const { data: reportData } = useModelReportsQuery(projectId, selectedModel); const { data: reportData } = useModelReportsQuery(projectId, selectedModel);
const { data: resultData } = useModelResultsQuery(selectedModel); const { data: resultData } = useModelResultsQuery(selectedModel);
if (!reportData || !resultData) { if (!reportData || !resultData) return null;
return null;
} const trainingInfoRow = (
<div className="flex justify-between rounded-lg bg-gray-100 p-4">
<div className="flex-1 text-center">
<strong>Epochs</strong>
<p>{resultData[0]?.epochs}</p>
</div>
<div className="flex-1 text-center">
<strong>Batch Size</strong>
<p>{resultData[0]?.batch}</p>
</div>
<div className="flex-1 text-center">
<strong>Learning Rate (Start)</strong>
<p>{resultData[0]?.lr0}</p>
</div>
<div className="flex-1 text-center">
<strong>Learning Rate (End)</strong>
<p>{resultData[0]?.lrf}</p>
</div>
<div className="flex-1 text-center">
<strong>Optimizer</strong>
<p>{resultData[0]?.optimizer}</p>
</div>
</div>
);
return ( return (
<div className="grid gap-8 md:grid-cols-2"> <div>
<div className="flex flex-col gap-6"> {trainingInfoRow} {/* 학습 정보 표시 */}
<ModelBarChart <div className="mt-4 grid h-[400px] gap-8 md:grid-cols-2">
data={[ {' '}
{ name: 'precision', value: resultData[0]?.precision, fill: 'var(--color-precision)' }, {/* grid와 높이 설정 */}
{ name: 'recall', value: resultData[0]?.recall, fill: 'var(--color-recall)' }, <div className="flex h-full flex-col gap-6">
{ name: 'mAP50', value: resultData[0]?.map50, fill: 'var(--color-map50)' }, {' '}
{ name: 'mAP50_95', value: resultData[0]?.map5095, fill: 'var(--color-map50-95)' }, {/* 차트의 높이를 100%로 맞춤 */}
{ name: 'fitness', value: resultData[0]?.fitness, fill: 'var(--color-fitness)' }, <ModelBarChart
]} data={[
/> { name: 'precision', value: resultData[0]?.precision, fill: 'var(--color-precision)' },
{ name: 'recall', value: resultData[0]?.recall, fill: 'var(--color-recall)' },
{ name: 'mAP50', value: resultData[0]?.map50, fill: 'var(--color-map50)' },
{ name: 'mAP50_95', value: resultData[0]?.map5095, fill: 'var(--color-map50-95)' },
{ name: 'fitness', value: resultData[0]?.fitness, fill: 'var(--color-fitness)' },
]}
className="h-full"
/>
</div>
<div className="flex h-full flex-col gap-6">
{' '}
{/* 차트의 높이를 100%로 맞춤 */}
<ModelLineChart
data={reportData}
className="h-full"
/>
</div>
</div> </div>
<div className="flex flex-col gap-6">
<ModelLineChart
data={reportData.map((report) => ({
epoch: report.epoch.toString(),
boxLoss: report.boxLoss,
classLoss: report.clsLoss,
dflLoss: report.dflLoss,
fitness: report.fitness,
}))}
/>
</div>
{/* <div className="flex flex-col justify-center">
<LabelingPreview />
</div> */}
</div> </div>
); );
} }
// function LabelingPreview() {
// return (
// <div className="flex items-center justify-center rounded-lg border bg-white p-4">
// <p>레이블링 프리뷰</p>
// </div>
// );
// }

View File

@ -6,9 +6,10 @@ interface InputWithLabelProps {
placeholder: string; placeholder: string;
value: number; value: number;
onChange: (e: React.ChangeEvent<HTMLInputElement>) => void; onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
disabled?: boolean;
} }
export default function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) { export default function InputWithLabel({ label, id, placeholder, value, disabled, onChange }: InputWithLabelProps) {
return ( return (
<div className="grid gap-3"> <div className="grid gap-3">
<Label htmlFor={id}>{label}</Label> <Label htmlFor={id}>{label}</Label>
@ -18,6 +19,7 @@ export default function InputWithLabel({ label, id, placeholder, value, onChange
placeholder={placeholder} placeholder={placeholder}
value={value} value={value}
onChange={onChange} onChange={onChange}
disabled={disabled}
/> />
</div> </div>
); );

View File

@ -1,9 +1,8 @@
'use client'; 'use client';
import { TrendingUp } from 'lucide-react';
import { Bar, BarChart, CartesianGrid, Rectangle, XAxis } from 'recharts'; import { Bar, BarChart, CartesianGrid, Rectangle, XAxis } from 'recharts';
import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from '@/components/ui/card'; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart'; import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart';
interface MetricData { interface MetricData {
@ -14,10 +13,9 @@ interface MetricData {
interface ModelBarChartProps { interface ModelBarChartProps {
data: MetricData[]; data: MetricData[];
className?: string;
} }
export const description = 'A bar chart with an active bar';
const chartConfig = { const chartConfig = {
precision: { precision: {
label: 'Precision', label: 'Precision',
@ -41,9 +39,9 @@ const chartConfig = {
}, },
} satisfies ChartConfig; } satisfies ChartConfig;
export default function ModelBarChart({ data }: ModelBarChartProps) { export default function ModelBarChart({ data, className }: ModelBarChartProps) {
return ( return (
<Card> <Card className={className}>
<CardHeader> <CardHeader>
<CardTitle>Model Metrics</CardTitle> <CardTitle>Model Metrics</CardTitle>
<CardDescription>Performance metrics of the model</CardDescription> <CardDescription>Performance metrics of the model</CardDescription>
@ -86,12 +84,6 @@ export default function ModelBarChart({ data }: ModelBarChartProps) {
</BarChart> </BarChart>
</ChartContainer> </ChartContainer>
</CardContent> </CardContent>
<CardFooter className="flex-col items-start gap-2 text-sm">
<div className="flex gap-2 font-medium leading-none">
Model metrics are trending well <TrendingUp className="h-4 w-4" />
</div>
<div className="text-muted-foreground leading-none">Showing current performance metrics</div>
</CardFooter>
</Card> </Card>
); );
} }

View File

@ -3,20 +3,11 @@
import { CartesianGrid, Line, LineChart, XAxis, YAxis, Tooltip, Legend } from 'recharts'; import { CartesianGrid, Line, LineChart, XAxis, YAxis, Tooltip, Legend } from 'recharts';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { ChartConfig, ChartContainer } from '@/components/ui/chart'; import { ChartConfig, ChartContainer } from '@/components/ui/chart';
import { ReportResponse } from '@/types';
interface MetricData {
epoch: string;
boxLoss?: number;
classLoss?: number;
dflLoss?: number;
fitness?: number;
}
interface ModelLineChartProps { interface ModelLineChartProps {
data: MetricData[]; data: ReportResponse[];
currentEpoch?: number; className?: string;
totalEpochs?: number;
remainingTime?: number;
} }
const chartConfig = { const chartConfig = {
@ -38,8 +29,11 @@ const chartConfig = {
}, },
} satisfies ChartConfig; } satisfies ChartConfig;
export default function ModelLineChart({ data, currentEpoch, totalEpochs, remainingTime }: ModelLineChartProps) { export default function ModelLineChart({ data, className }: ModelLineChartProps) {
const emptyData = Array.from({ length: totalEpochs || 0 }, (_, i) => ({ const latestData = data.length > 0 ? data[data.length - 1] : undefined;
const totalEpochs = latestData?.totalEpochs || 0;
const emptyData = Array.from({ length: totalEpochs }, (_, i) => ({
epoch: (i + 1).toString(), epoch: (i + 1).toString(),
boxLoss: null, boxLoss: null,
classLoss: null, classLoss: null,
@ -53,16 +47,16 @@ export default function ModelLineChart({ data, currentEpoch, totalEpochs, remain
})); }));
return ( return (
<Card> <Card className={className}>
<CardHeader> <CardHeader>
<CardTitle>Model Training Metrics</CardTitle> <CardTitle>Model Training Metrics</CardTitle>
</CardHeader> </CardHeader>
<CardContent> <CardContent>
{currentEpoch !== undefined && totalEpochs !== undefined && remainingTime !== undefined && ( {latestData && latestData.totalEpochs !== Number(latestData.epoch) && (
<div className="mb-4 flex justify-between"> <div className="mb-4 flex justify-between">
<p> : {currentEpoch}</p> <p> : {latestData.epoch}</p>
<p> : {totalEpochs}</p> <p> : {latestData.totalEpochs}</p>
<p> : {remainingTime}</p> <p> : {latestData.leftSecond}</p>
</div> </div>
)} )}
<ChartContainer config={chartConfig}> <ChartContainer config={chartConfig}>

View File

@ -12,16 +12,27 @@ interface SelectWithLabelProps {
options: SelectWithLabelOption[]; options: SelectWithLabelOption[];
placeholder: string; placeholder: string;
value: string; value: string;
disabled?: boolean;
onChange: (value: string) => void; onChange: (value: string) => void;
} }
export default function SelectWithLabel({ label, id, options, placeholder, value, onChange }: SelectWithLabelProps) { export default function SelectWithLabel({
label,
id,
options,
placeholder,
value,
disabled,
onChange,
}: SelectWithLabelProps) {
return ( return (
<div className="grid gap-3"> <div className="grid gap-3">
<Label htmlFor={id}>{label}</Label> <Label htmlFor={id}>{label}</Label>
<Select <Select
value={value} value={value}
onValueChange={onChange} onValueChange={onChange}
disabled={disabled}
> >
<SelectTrigger id={id}> <SelectTrigger id={id}>
<SelectValue placeholder={placeholder} /> <SelectValue placeholder={placeholder} />

View File

@ -6,9 +6,10 @@ import useModelStore from '@/stores/useModelStore';
interface TrainingGraphProps { interface TrainingGraphProps {
projectId: number | null; projectId: number | null;
selectedModel: number | null; selectedModel: number | null;
className?: string;
} }
export default function TrainingGraph({ projectId, selectedModel }: TrainingGraphProps) { export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
const { isTrainingByProject, setIsTraining, saveTrainingData, resetTrainingData, trainingDataByProject } = const { isTrainingByProject, setIsTraining, saveTrainingData, resetTrainingData, trainingDataByProject } =
useModelStore((state) => ({ useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject, isTrainingByProject: state.isTrainingByProject,
@ -48,6 +49,7 @@ export default function TrainingGraph({ projectId, selectedModel }: TrainingGrap
useEffect(() => { useEffect(() => {
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) { if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
alert('학습이 완료되었습니다!');
setIsTraining(projectId?.toString() || '', false); setIsTraining(projectId?.toString() || '', false);
resetTrainingData(projectId?.toString() || ''); resetTrainingData(projectId?.toString() || '');
} }
@ -55,18 +57,8 @@ export default function TrainingGraph({ projectId, selectedModel }: TrainingGrap
return ( return (
<ModelLineChart <ModelLineChart
data={ data={trainingDataList}
trainingDataList?.map((data) => ({ className={className}
epoch: data.epoch.toString(),
boxLoss: data.boxLoss,
classLoss: data.clsLoss,
dflLoss: data.dflLoss,
fitness: data.fitness,
})) || []
}
currentEpoch={latestData.epoch}
totalEpochs={latestData.totalEpochs}
remainingTime={latestData.leftSecond}
/> />
); );
} }

View File

@ -2,8 +2,10 @@ import SelectWithLabel from './SelectWithLabel';
import InputWithLabel from './InputWithLabel'; import InputWithLabel from './InputWithLabel';
import { Button } from '@/components/ui/button'; import { Button } from '@/components/ui/button';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery'; import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import useModelStore from '@/stores/useModelStore';
import { ModelTrainRequest } from '@/types'; import { ModelTrainRequest } from '@/types';
import { useState } from 'react'; import { useState } from 'react';
import { cn } from '@/lib/utils';
interface TrainingSettingsProps { interface TrainingSettingsProps {
projectId: number | null; projectId: number | null;
@ -11,7 +13,7 @@ interface TrainingSettingsProps {
setSelectedModel: (model: number | null) => void; setSelectedModel: (model: number | null) => void;
handleTrainingStart: (trainData: ModelTrainRequest) => void; handleTrainingStart: (trainData: ModelTrainRequest) => void;
handleTrainingStop: () => void; handleTrainingStop: () => void;
isTraining: boolean; className?: string;
} }
export default function TrainingSettings({ export default function TrainingSettings({
@ -20,10 +22,12 @@ export default function TrainingSettings({
setSelectedModel, setSelectedModel,
handleTrainingStart, handleTrainingStart,
handleTrainingStop, handleTrainingStop,
isTraining, className,
}: TrainingSettingsProps) { }: TrainingSettingsProps) {
const { data: models } = useProjectModelsQuery(projectId ?? 0); const { data: models } = useProjectModelsQuery(projectId ?? 0);
const isTraining = useModelStore((state) => state.isTrainingByProject[projectId?.toString() || ''] || false);
const [ratio, setRatio] = useState<number>(0.8); const [ratio, setRatio] = useState<number>(0.8);
const [epochs, setEpochs] = useState<number>(50); const [epochs, setEpochs] = useState<number>(50);
const [batchSize, setBatchSize] = useState<number>(32); const [batchSize, setBatchSize] = useState<number>(32);
@ -49,12 +53,9 @@ export default function TrainingSettings({
}; };
return ( return (
<fieldset <fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
className="grid gap-6 rounded-lg border p-4" {' '}
disabled={isTraining}
>
<legend className="-ml-1 px-1 text-sm font-medium"> </legend> <legend className="-ml-1 px-1 text-sm font-medium"> </legend>
<div className="grid gap-3"> <div className="grid gap-3">
<SelectWithLabel <SelectWithLabel
label="모델 선택" label="모델 선택"
@ -68,9 +69,9 @@ export default function TrainingSettings({
placeholder="모델을 선택하세요" placeholder="모델을 선택하세요"
value={selectedModel ? selectedModel.toString() : ''} value={selectedModel ? selectedModel.toString() : ''}
onChange={(value) => setSelectedModel(parseInt(value, 10))} onChange={(value) => setSelectedModel(parseInt(value, 10))}
disabled={isTraining}
/> />
</div> </div>
<div className="grid grid-cols-2 gap-4"> <div className="grid grid-cols-2 gap-4">
<InputWithLabel <InputWithLabel
label="훈련/검증 비율" label="훈련/검증 비율"
@ -78,6 +79,7 @@ export default function TrainingSettings({
id="ratio" id="ratio"
value={ratio} value={ratio}
onChange={(e) => setRatio(parseFloat(e.target.value))} onChange={(e) => setRatio(parseFloat(e.target.value))}
disabled={isTraining}
/> />
<InputWithLabel <InputWithLabel
label="에포크 수" label="에포크 수"
@ -85,6 +87,7 @@ export default function TrainingSettings({
id="epochs" id="epochs"
value={epochs} value={epochs}
onChange={(e) => setEpochs(parseInt(e.target.value, 10))} onChange={(e) => setEpochs(parseInt(e.target.value, 10))}
disabled={isTraining}
/> />
<InputWithLabel <InputWithLabel
label="Batch 크기" label="Batch 크기"
@ -92,6 +95,7 @@ export default function TrainingSettings({
id="batch" id="batch"
value={batchSize} value={batchSize}
onChange={(e) => setBatchSize(parseInt(e.target.value, 10))} onChange={(e) => setBatchSize(parseInt(e.target.value, 10))}
disabled={isTraining}
/> />
<SelectWithLabel <SelectWithLabel
label="옵티마이저" label="옵티마이저"
@ -108,6 +112,7 @@ export default function TrainingSettings({
placeholder="옵티마이저 선택" placeholder="옵티마이저 선택"
value={optimizer} value={optimizer}
onChange={(value) => setOptimizer(value as 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP')} onChange={(value) => setOptimizer(value as 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP')}
disabled={isTraining} // 학습 중일 때 옵티마이저 선택 비활성화
/> />
<InputWithLabel <InputWithLabel
label="학습률(LR0)" label="학습률(LR0)"
@ -115,6 +120,7 @@ export default function TrainingSettings({
id="lr0" id="lr0"
value={lr0} value={lr0}
onChange={(e) => setLr0(parseFloat(e.target.value))} onChange={(e) => setLr0(parseFloat(e.target.value))}
disabled={isTraining}
/> />
<InputWithLabel <InputWithLabel
label="최종 학습률(LRF)" label="최종 학습률(LRF)"
@ -122,14 +128,14 @@ export default function TrainingSettings({
id="lrf" id="lrf"
value={lrf} value={lrf}
onChange={(e) => setLrf(parseFloat(e.target.value))} onChange={(e) => setLrf(parseFloat(e.target.value))}
disabled={isTraining}
/> />
</div> </div>
<Button <Button
variant="outlinePrimary" variant="outlinePrimary"
size="lg" size="lg"
onClick={handleSubmit} onClick={handleSubmit}
disabled={!selectedModel || isTraining} disabled={!selectedModel}
> >
{isTraining ? '학습 중단' : '학습 시작'} {isTraining ? '학습 중단' : '학습 시작'}
</Button> </Button>

View File

@ -39,19 +39,20 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
}; };
return ( return (
<div className="grid gap-8 md:grid-cols-2"> <div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
<TrainingSettings <TrainingSettings
projectId={numericProjectId} projectId={numericProjectId}
selectedModel={selectedModel} selectedModel={selectedModel}
setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)} setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)}
handleTrainingStart={handleTrainingStart} handleTrainingStart={handleTrainingStart}
handleTrainingStop={handleTrainingStop} handleTrainingStop={handleTrainingStop}
isTraining={isTraining} className="h-full"
/> />
<TrainingGraph <TrainingGraph
projectId={numericProjectId} projectId={numericProjectId}
selectedModel={selectedModel} selectedModel={selectedModel}
className="h-full"
/> />
</div> </div>
); );