Refactor: mocker를 통해 연결한 것 기반 오류 수정
This commit is contained in:
parent
4e844dd367
commit
dc78d000e7
@ -1,11 +1,11 @@
|
||||
import { Label } from '@/components/ui/label';
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
|
||||
import { useState } from 'react';
|
||||
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
|
||||
import useModelReportsQuery from '@/queries/models/useModelReportsQuery';
|
||||
import useModelResultsQuery from '@/queries/models/useModelResultsQuery';
|
||||
import ModelBarChart from './ModelBarChart';
|
||||
import ModelLineChart from './ModelLineChart';
|
||||
import { useState } from 'react';
|
||||
|
||||
interface EvaluationTabProps {
|
||||
projectId: number | null;
|
||||
@ -13,7 +13,6 @@ interface EvaluationTabProps {
|
||||
|
||||
export default function EvaluationTab({ projectId }: EvaluationTabProps) {
|
||||
const [selectedModel, setSelectedModel] = useState<number | null>(null);
|
||||
|
||||
const { data: models } = useProjectModelsQuery(projectId ?? 0);
|
||||
|
||||
return (
|
||||
@ -70,47 +69,62 @@ function ModelEvaluation({ projectId, selectedModel }: ModelEvaluationProps) {
|
||||
const { data: reportData } = useModelReportsQuery(projectId, selectedModel);
|
||||
const { data: resultData } = useModelResultsQuery(selectedModel);
|
||||
|
||||
if (!reportData || !resultData) {
|
||||
return null;
|
||||
}
|
||||
if (!reportData || !resultData) return null;
|
||||
|
||||
const trainingInfoRow = (
|
||||
<div className="flex justify-between rounded-lg bg-gray-100 p-4">
|
||||
<div className="flex-1 text-center">
|
||||
<strong>Epochs</strong>
|
||||
<p>{resultData[0]?.epochs}</p>
|
||||
</div>
|
||||
<div className="flex-1 text-center">
|
||||
<strong>Batch Size</strong>
|
||||
<p>{resultData[0]?.batch}</p>
|
||||
</div>
|
||||
<div className="flex-1 text-center">
|
||||
<strong>Learning Rate (Start)</strong>
|
||||
<p>{resultData[0]?.lr0}</p>
|
||||
</div>
|
||||
<div className="flex-1 text-center">
|
||||
<strong>Learning Rate (End)</strong>
|
||||
<p>{resultData[0]?.lrf}</p>
|
||||
</div>
|
||||
<div className="flex-1 text-center">
|
||||
<strong>Optimizer</strong>
|
||||
<p>{resultData[0]?.optimizer}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="grid gap-8 md:grid-cols-2">
|
||||
<div className="flex flex-col gap-6">
|
||||
<ModelBarChart
|
||||
data={[
|
||||
{ name: 'precision', value: resultData[0]?.precision, fill: 'var(--color-precision)' },
|
||||
{ name: 'recall', value: resultData[0]?.recall, fill: 'var(--color-recall)' },
|
||||
{ name: 'mAP50', value: resultData[0]?.map50, fill: 'var(--color-map50)' },
|
||||
{ name: 'mAP50_95', value: resultData[0]?.map5095, fill: 'var(--color-map50-95)' },
|
||||
{ name: 'fitness', value: resultData[0]?.fitness, fill: 'var(--color-fitness)' },
|
||||
]}
|
||||
/>
|
||||
<div>
|
||||
{trainingInfoRow} {/* 학습 정보 표시 */}
|
||||
<div className="mt-4 grid h-[400px] gap-8 md:grid-cols-2">
|
||||
{' '}
|
||||
{/* grid와 높이 설정 */}
|
||||
<div className="flex h-full flex-col gap-6">
|
||||
{' '}
|
||||
{/* 차트의 높이를 100%로 맞춤 */}
|
||||
<ModelBarChart
|
||||
data={[
|
||||
{ name: 'precision', value: resultData[0]?.precision, fill: 'var(--color-precision)' },
|
||||
{ name: 'recall', value: resultData[0]?.recall, fill: 'var(--color-recall)' },
|
||||
{ name: 'mAP50', value: resultData[0]?.map50, fill: 'var(--color-map50)' },
|
||||
{ name: 'mAP50_95', value: resultData[0]?.map5095, fill: 'var(--color-map50-95)' },
|
||||
{ name: 'fitness', value: resultData[0]?.fitness, fill: 'var(--color-fitness)' },
|
||||
]}
|
||||
className="h-full"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex h-full flex-col gap-6">
|
||||
{' '}
|
||||
{/* 차트의 높이를 100%로 맞춤 */}
|
||||
<ModelLineChart
|
||||
data={reportData}
|
||||
className="h-full"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-6">
|
||||
<ModelLineChart
|
||||
data={reportData.map((report) => ({
|
||||
epoch: report.epoch.toString(),
|
||||
boxLoss: report.boxLoss,
|
||||
classLoss: report.clsLoss,
|
||||
dflLoss: report.dflLoss,
|
||||
fitness: report.fitness,
|
||||
}))}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* <div className="flex flex-col justify-center">
|
||||
<LabelingPreview />
|
||||
</div> */}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// function LabelingPreview() {
|
||||
// return (
|
||||
// <div className="flex items-center justify-center rounded-lg border bg-white p-4">
|
||||
// <p>레이블링 프리뷰</p>
|
||||
// </div>
|
||||
// );
|
||||
// }
|
||||
|
@ -6,9 +6,10 @@ interface InputWithLabelProps {
|
||||
placeholder: string;
|
||||
value: number;
|
||||
onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export default function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) {
|
||||
export default function InputWithLabel({ label, id, placeholder, value, disabled, onChange }: InputWithLabelProps) {
|
||||
return (
|
||||
<div className="grid gap-3">
|
||||
<Label htmlFor={id}>{label}</Label>
|
||||
@ -18,6 +19,7 @@ export default function InputWithLabel({ label, id, placeholder, value, onChange
|
||||
placeholder={placeholder}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
@ -1,9 +1,8 @@
|
||||
'use client';
|
||||
|
||||
import { TrendingUp } from 'lucide-react';
|
||||
import { Bar, BarChart, CartesianGrid, Rectangle, XAxis } from 'recharts';
|
||||
|
||||
import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart';
|
||||
|
||||
interface MetricData {
|
||||
@ -14,10 +13,9 @@ interface MetricData {
|
||||
|
||||
interface ModelBarChartProps {
|
||||
data: MetricData[];
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export const description = 'A bar chart with an active bar';
|
||||
|
||||
const chartConfig = {
|
||||
precision: {
|
||||
label: 'Precision',
|
||||
@ -41,9 +39,9 @@ const chartConfig = {
|
||||
},
|
||||
} satisfies ChartConfig;
|
||||
|
||||
export default function ModelBarChart({ data }: ModelBarChartProps) {
|
||||
export default function ModelBarChart({ data, className }: ModelBarChartProps) {
|
||||
return (
|
||||
<Card>
|
||||
<Card className={className}>
|
||||
<CardHeader>
|
||||
<CardTitle>Model Metrics</CardTitle>
|
||||
<CardDescription>Performance metrics of the model</CardDescription>
|
||||
@ -86,12 +84,6 @@ export default function ModelBarChart({ data }: ModelBarChartProps) {
|
||||
</BarChart>
|
||||
</ChartContainer>
|
||||
</CardContent>
|
||||
<CardFooter className="flex-col items-start gap-2 text-sm">
|
||||
<div className="flex gap-2 font-medium leading-none">
|
||||
Model metrics are trending well <TrendingUp className="h-4 w-4" />
|
||||
</div>
|
||||
<div className="text-muted-foreground leading-none">Showing current performance metrics</div>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
@ -3,20 +3,11 @@
|
||||
import { CartesianGrid, Line, LineChart, XAxis, YAxis, Tooltip, Legend } from 'recharts';
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { ChartConfig, ChartContainer } from '@/components/ui/chart';
|
||||
|
||||
interface MetricData {
|
||||
epoch: string;
|
||||
boxLoss?: number;
|
||||
classLoss?: number;
|
||||
dflLoss?: number;
|
||||
fitness?: number;
|
||||
}
|
||||
import { ReportResponse } from '@/types';
|
||||
|
||||
interface ModelLineChartProps {
|
||||
data: MetricData[];
|
||||
currentEpoch?: number;
|
||||
totalEpochs?: number;
|
||||
remainingTime?: number;
|
||||
data: ReportResponse[];
|
||||
className?: string;
|
||||
}
|
||||
|
||||
const chartConfig = {
|
||||
@ -38,8 +29,11 @@ const chartConfig = {
|
||||
},
|
||||
} satisfies ChartConfig;
|
||||
|
||||
export default function ModelLineChart({ data, currentEpoch, totalEpochs, remainingTime }: ModelLineChartProps) {
|
||||
const emptyData = Array.from({ length: totalEpochs || 0 }, (_, i) => ({
|
||||
export default function ModelLineChart({ data, className }: ModelLineChartProps) {
|
||||
const latestData = data.length > 0 ? data[data.length - 1] : undefined;
|
||||
|
||||
const totalEpochs = latestData?.totalEpochs || 0;
|
||||
const emptyData = Array.from({ length: totalEpochs }, (_, i) => ({
|
||||
epoch: (i + 1).toString(),
|
||||
boxLoss: null,
|
||||
classLoss: null,
|
||||
@ -53,16 +47,16 @@ export default function ModelLineChart({ data, currentEpoch, totalEpochs, remain
|
||||
}));
|
||||
|
||||
return (
|
||||
<Card>
|
||||
<Card className={className}>
|
||||
<CardHeader>
|
||||
<CardTitle>Model Training Metrics</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{currentEpoch !== undefined && totalEpochs !== undefined && remainingTime !== undefined && (
|
||||
{latestData && latestData.totalEpochs !== Number(latestData.epoch) && (
|
||||
<div className="mb-4 flex justify-between">
|
||||
<p>현재 에포크: {currentEpoch}</p>
|
||||
<p>총 에포크: {totalEpochs}</p>
|
||||
<p>예상 남은시간: {remainingTime}</p>
|
||||
<p>현재 에포크: {latestData.epoch}</p>
|
||||
<p>총 에포크: {latestData.totalEpochs}</p>
|
||||
<p>예상 남은시간: {latestData.leftSecond}초</p>
|
||||
</div>
|
||||
)}
|
||||
<ChartContainer config={chartConfig}>
|
||||
|
@ -12,16 +12,27 @@ interface SelectWithLabelProps {
|
||||
options: SelectWithLabelOption[];
|
||||
placeholder: string;
|
||||
value: string;
|
||||
disabled?: boolean;
|
||||
|
||||
onChange: (value: string) => void;
|
||||
}
|
||||
|
||||
export default function SelectWithLabel({ label, id, options, placeholder, value, onChange }: SelectWithLabelProps) {
|
||||
export default function SelectWithLabel({
|
||||
label,
|
||||
id,
|
||||
options,
|
||||
placeholder,
|
||||
value,
|
||||
disabled,
|
||||
onChange,
|
||||
}: SelectWithLabelProps) {
|
||||
return (
|
||||
<div className="grid gap-3">
|
||||
<Label htmlFor={id}>{label}</Label>
|
||||
<Select
|
||||
value={value}
|
||||
onValueChange={onChange}
|
||||
disabled={disabled}
|
||||
>
|
||||
<SelectTrigger id={id}>
|
||||
<SelectValue placeholder={placeholder} />
|
||||
|
@ -6,9 +6,10 @@ import useModelStore from '@/stores/useModelStore';
|
||||
interface TrainingGraphProps {
|
||||
projectId: number | null;
|
||||
selectedModel: number | null;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function TrainingGraph({ projectId, selectedModel }: TrainingGraphProps) {
|
||||
export default function TrainingGraph({ projectId, selectedModel, className }: TrainingGraphProps) {
|
||||
const { isTrainingByProject, setIsTraining, saveTrainingData, resetTrainingData, trainingDataByProject } =
|
||||
useModelStore((state) => ({
|
||||
isTrainingByProject: state.isTrainingByProject,
|
||||
@ -48,6 +49,7 @@ export default function TrainingGraph({ projectId, selectedModel }: TrainingGrap
|
||||
|
||||
useEffect(() => {
|
||||
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
|
||||
alert('학습이 완료되었습니다!');
|
||||
setIsTraining(projectId?.toString() || '', false);
|
||||
resetTrainingData(projectId?.toString() || '');
|
||||
}
|
||||
@ -55,18 +57,8 @@ export default function TrainingGraph({ projectId, selectedModel }: TrainingGrap
|
||||
|
||||
return (
|
||||
<ModelLineChart
|
||||
data={
|
||||
trainingDataList?.map((data) => ({
|
||||
epoch: data.epoch.toString(),
|
||||
boxLoss: data.boxLoss,
|
||||
classLoss: data.clsLoss,
|
||||
dflLoss: data.dflLoss,
|
||||
fitness: data.fitness,
|
||||
})) || []
|
||||
}
|
||||
currentEpoch={latestData.epoch}
|
||||
totalEpochs={latestData.totalEpochs}
|
||||
remainingTime={latestData.leftSecond}
|
||||
data={trainingDataList}
|
||||
className={className}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
@ -2,8 +2,10 @@ import SelectWithLabel from './SelectWithLabel';
|
||||
import InputWithLabel from './InputWithLabel';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
|
||||
import useModelStore from '@/stores/useModelStore';
|
||||
import { ModelTrainRequest } from '@/types';
|
||||
import { useState } from 'react';
|
||||
import { cn } from '@/lib/utils';
|
||||
|
||||
interface TrainingSettingsProps {
|
||||
projectId: number | null;
|
||||
@ -11,7 +13,7 @@ interface TrainingSettingsProps {
|
||||
setSelectedModel: (model: number | null) => void;
|
||||
handleTrainingStart: (trainData: ModelTrainRequest) => void;
|
||||
handleTrainingStop: () => void;
|
||||
isTraining: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function TrainingSettings({
|
||||
@ -20,10 +22,12 @@ export default function TrainingSettings({
|
||||
setSelectedModel,
|
||||
handleTrainingStart,
|
||||
handleTrainingStop,
|
||||
isTraining,
|
||||
className,
|
||||
}: TrainingSettingsProps) {
|
||||
const { data: models } = useProjectModelsQuery(projectId ?? 0);
|
||||
|
||||
const isTraining = useModelStore((state) => state.isTrainingByProject[projectId?.toString() || ''] || false);
|
||||
|
||||
const [ratio, setRatio] = useState<number>(0.8);
|
||||
const [epochs, setEpochs] = useState<number>(50);
|
||||
const [batchSize, setBatchSize] = useState<number>(32);
|
||||
@ -49,12 +53,9 @@ export default function TrainingSettings({
|
||||
};
|
||||
|
||||
return (
|
||||
<fieldset
|
||||
className="grid gap-6 rounded-lg border p-4"
|
||||
disabled={isTraining}
|
||||
>
|
||||
<fieldset className={cn('grid gap-6 rounded-lg border p-4', className)}>
|
||||
{' '}
|
||||
<legend className="-ml-1 px-1 text-sm font-medium">모델 설정</legend>
|
||||
|
||||
<div className="grid gap-3">
|
||||
<SelectWithLabel
|
||||
label="모델 선택"
|
||||
@ -68,9 +69,9 @@ export default function TrainingSettings({
|
||||
placeholder="모델을 선택하세요"
|
||||
value={selectedModel ? selectedModel.toString() : ''}
|
||||
onChange={(value) => setSelectedModel(parseInt(value, 10))}
|
||||
disabled={isTraining}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<InputWithLabel
|
||||
label="훈련/검증 비율"
|
||||
@ -78,6 +79,7 @@ export default function TrainingSettings({
|
||||
id="ratio"
|
||||
value={ratio}
|
||||
onChange={(e) => setRatio(parseFloat(e.target.value))}
|
||||
disabled={isTraining}
|
||||
/>
|
||||
<InputWithLabel
|
||||
label="에포크 수"
|
||||
@ -85,6 +87,7 @@ export default function TrainingSettings({
|
||||
id="epochs"
|
||||
value={epochs}
|
||||
onChange={(e) => setEpochs(parseInt(e.target.value, 10))}
|
||||
disabled={isTraining}
|
||||
/>
|
||||
<InputWithLabel
|
||||
label="Batch 크기"
|
||||
@ -92,6 +95,7 @@ export default function TrainingSettings({
|
||||
id="batch"
|
||||
value={batchSize}
|
||||
onChange={(e) => setBatchSize(parseInt(e.target.value, 10))}
|
||||
disabled={isTraining}
|
||||
/>
|
||||
<SelectWithLabel
|
||||
label="옵티마이저"
|
||||
@ -108,6 +112,7 @@ export default function TrainingSettings({
|
||||
placeholder="옵티마이저 선택"
|
||||
value={optimizer}
|
||||
onChange={(value) => setOptimizer(value as 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP')}
|
||||
disabled={isTraining} // 학습 중일 때 옵티마이저 선택 비활성화
|
||||
/>
|
||||
<InputWithLabel
|
||||
label="학습률(LR0)"
|
||||
@ -115,6 +120,7 @@ export default function TrainingSettings({
|
||||
id="lr0"
|
||||
value={lr0}
|
||||
onChange={(e) => setLr0(parseFloat(e.target.value))}
|
||||
disabled={isTraining}
|
||||
/>
|
||||
<InputWithLabel
|
||||
label="최종 학습률(LRF)"
|
||||
@ -122,14 +128,14 @@ export default function TrainingSettings({
|
||||
id="lrf"
|
||||
value={lrf}
|
||||
onChange={(e) => setLrf(parseFloat(e.target.value))}
|
||||
disabled={isTraining}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
variant="outlinePrimary"
|
||||
size="lg"
|
||||
onClick={handleSubmit}
|
||||
disabled={!selectedModel || isTraining}
|
||||
disabled={!selectedModel}
|
||||
>
|
||||
{isTraining ? '학습 중단' : '학습 시작'}
|
||||
</Button>
|
||||
|
@ -39,19 +39,20 @@ export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="grid gap-8 md:grid-cols-2">
|
||||
<div className="grid grid-rows-[auto_1fr] gap-8 md:grid-cols-2">
|
||||
<TrainingSettings
|
||||
projectId={numericProjectId}
|
||||
selectedModel={selectedModel}
|
||||
setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)}
|
||||
handleTrainingStart={handleTrainingStart}
|
||||
handleTrainingStop={handleTrainingStop}
|
||||
isTraining={isTraining}
|
||||
className="h-full"
|
||||
/>
|
||||
|
||||
<TrainingGraph
|
||||
projectId={numericProjectId}
|
||||
selectedModel={selectedModel}
|
||||
className="h-full"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user