Merge branch 'fe/refactor/admin-model' into 'fe/develop'

Refactor: 학습, 평가 부분 api 연결

See merge request s11-s-project/S11P21S002!163
This commit is contained in:
홍창기 2024-09-25 08:37:06 +09:00
commit d743cf3ce9
19 changed files with 611 additions and 429 deletions

View File

@ -1,12 +1,20 @@
import api from '@/api/axiosConfig'; import api from '@/api/axiosConfig';
import { ModelRequest, ModelResponse, ProjectModelsResponse, ModelCategoryResponse } from '@/types'; import {
ModelRequest,
ModelResponse,
ProjectModelsResponse,
ModelCategoryResponse,
ModelTrainRequest,
ResultResponse,
ReportResponse,
} from '@/types';
export async function updateModelName(projectId: number, modelId: number, modelData: ModelRequest) { export async function updateModelName(projectId: number, modelId: number, modelData: ModelRequest) {
return api.put<ModelResponse>(`/projects/${projectId}/models/${modelId}`, modelData).then(({ data }) => data); return api.put<ModelResponse>(`/projects/${projectId}/models/${modelId}`, modelData).then(({ data }) => data);
} }
export async function trainModel(projectId: number) { export async function trainModel(projectId: number, trainData: ModelTrainRequest) {
return api.post(`/projects/${projectId}/train`).then(({ data }) => data); return api.post(`/projects/${projectId}/train`, trainData).then(({ data }) => data);
} }
export async function getProjectModels(projectId: number) { export async function getProjectModels(projectId: number) {
@ -20,3 +28,11 @@ export async function addProjectModel(projectId: number, modelData: ModelRequest
export async function getModelCategories(modelId: number) { export async function getModelCategories(modelId: number) {
return api.get<ModelCategoryResponse[]>(`/models/${modelId}/categories`).then(({ data }) => data); return api.get<ModelCategoryResponse[]>(`/models/${modelId}/categories`).then(({ data }) => data);
} }
export async function getModelResults(modelId: number) {
return api.get<ResultResponse[]>(`/results/model/${modelId}`).then(({ data }) => data);
}
export async function getModelReports(projectId: number, modelId: number) {
return api.get<ReportResponse[]>(`/projects/${projectId}/reports/model/${modelId}`).then(({ data }) => data);
}

View File

@ -1,55 +1,116 @@
import { Label } from '@/components/ui/label'; import { Label } from '@/components/ui/label';
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
import ModelBarChart from '@/components/ModelBarChart'; import { useState } from 'react';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import useModelReportsQuery from '@/queries/models/useModelReportsQuery';
import useModelResultsQuery from '@/queries/models/useModelResultsQuery';
import ModelBarChart from './ModelBarChart';
import ModelLineChart from './ModelLineChart';
interface EvaluationTabProps { interface EvaluationTabProps {
selectedModel: string | null; projectId: number | null;
setSelectedModel: (model: string | null) => void;
} }
export default function EvaluationTab({ selectedModel, setSelectedModel }: EvaluationTabProps) { export default function EvaluationTab({ projectId }: EvaluationTabProps) {
const [selectedModel, setSelectedModel] = useState<number | null>(null);
const { data: models } = useProjectModelsQuery(projectId ?? 0);
return ( return (
<div> <div>
<div className="mb-4"> <ModelSelection
<Label htmlFor="select-model"> </Label> models={models}
<Select onValueChange={setSelectedModel}> setSelectedModel={setSelectedModel}
<SelectTrigger id="select-model"> />
<SelectValue placeholder="모델을 선택하세요" />
</SelectTrigger>
<SelectContent>
<SelectItem value="genesis">Genesis</SelectItem>
<SelectItem value="explorer">Explorer</SelectItem>
<SelectItem value="quantum">Quantum</SelectItem>
</SelectContent>
</Select>
</div>
{selectedModel && ( {selectedModel && (
<div className="grid gap-8 md:grid-cols-2"> <ModelEvaluation
<div className="flex flex-col gap-6"> projectId={projectId as number}
<ModelBarChart selectedModel={selectedModel}
data={[
{ name: 'precision', value: 0.734, fill: 'var(--color-precision)' },
{ name: 'recall', value: 0.75, fill: 'var(--color-recall)' },
{ name: 'mAP50', value: 0.995, fill: 'var(--color-map50)' },
{ name: 'mAP50_95', value: 0.97, fill: 'var(--color-map50-95)' },
{ name: 'fitness', value: 0.973, fill: 'var(--color-fitness)' },
]}
/> />
</div>
<div className="flex flex-col justify-center">
<LabelingPreview />
</div>
</div>
)} )}
</div> </div>
); );
} }
function LabelingPreview() { interface ModelSelectionProps {
models: Array<{ id: number; name: string }> | undefined;
setSelectedModel: (modelId: number) => void;
}
function ModelSelection({ models, setSelectedModel }: ModelSelectionProps) {
return ( return (
<div className="flex items-center justify-center rounded-lg border bg-white p-4"> <div className="mb-4">
<p> </p> <Label htmlFor="select-model"> </Label>
<Select onValueChange={(value) => setSelectedModel(parseInt(value))}>
<SelectTrigger id="select-model">
<SelectValue placeholder="모델을 선택하세요" />
</SelectTrigger>
<SelectContent>
{models?.map((model) => (
<SelectItem
key={model.id}
value={model.id.toString()}
>
{model.name}
</SelectItem>
))}
</SelectContent>
</Select>
</div> </div>
); );
} }
interface ModelEvaluationProps {
projectId: number;
selectedModel: number;
}
function ModelEvaluation({ projectId, selectedModel }: ModelEvaluationProps) {
const { data: reportData } = useModelReportsQuery(projectId, selectedModel);
const { data: resultData } = useModelResultsQuery(selectedModel);
if (!reportData || !resultData) {
return null;
}
return (
<div className="grid gap-8 md:grid-cols-2">
<div className="flex flex-col gap-6">
<ModelBarChart
data={[
{ name: 'precision', value: resultData[0]?.precision, fill: 'var(--color-precision)' },
{ name: 'recall', value: resultData[0]?.recall, fill: 'var(--color-recall)' },
{ name: 'mAP50', value: resultData[0]?.map50, fill: 'var(--color-map50)' },
{ name: 'mAP50_95', value: resultData[0]?.map5095, fill: 'var(--color-map50-95)' },
{ name: 'fitness', value: resultData[0]?.fitness, fill: 'var(--color-fitness)' },
]}
/>
</div>
<div className="flex flex-col gap-6">
<ModelLineChart
data={reportData.map((report) => ({
epoch: report.epoch.toString(),
loss1: report.boxLoss,
loss2: report.clsLoss,
loss3: report.dflLoss,
fitness: report.fitness,
}))}
/>
</div>
{/* <div className="flex flex-col justify-center">
<LabelingPreview />
</div> */}
</div>
);
}
// function LabelingPreview() {
// return (
// <div className="flex items-center justify-center rounded-lg border bg-white p-4">
// <p>레이블링 프리뷰</p>
// </div>
// );
// }

View File

@ -0,0 +1,24 @@
import { Label } from '@/components/ui/label';
import { Input } from '../ui/input';
interface InputWithLabelProps {
label: string;
id: string;
placeholder: string;
value: number;
onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
}
export default function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) {
return (
<div className="grid gap-3">
<Label htmlFor={id}>{label}</Label>
<Input
id={id}
type="number"
placeholder={placeholder}
value={value}
onChange={onChange}
/>
</div>
);
}

View File

@ -1,54 +1,74 @@
'use client'; 'use client';
import { TrendingUp } from 'lucide-react'; import { CartesianGrid, Line, LineChart, XAxis, YAxis, Tooltip, Legend } from 'recharts';
import { CartesianGrid, Line, LineChart, XAxis } from 'recharts'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { ChartConfig, ChartContainer } from '@/components/ui/chart';
import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from '@/components/ui/card';
import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart';
interface MetricData { interface MetricData {
epoch: string; epoch: string;
loss1: number; loss1?: number;
loss2: number; loss2?: number;
loss3: number; loss3?: number;
fitness: number; fitness?: number;
} }
interface ModelLineChartProps { interface ModelLineChartProps {
data: MetricData[]; data: MetricData[];
currentEpoch?: number;
totalEpochs?: number;
remainingTime?: number;
} }
const chartConfig = { const chartConfig = {
loss1: { loss1: {
label: 'Loss 1', label: 'Loss 1',
color: '#FF6347', // 토마토색 color: '#FF6347',
}, },
loss2: { loss2: {
label: 'Loss 2', label: 'Loss 2',
color: '#1E90FF', // 다저블루색 color: '#1E90FF',
}, },
loss3: { loss3: {
label: 'Loss 3', label: 'Loss 3',
color: '#32CD32', // 라임색 color: '#32CD32',
}, },
fitness: { fitness: {
label: 'Fitness', label: 'Fitness',
color: '#FFD700', // 골드색 color: '#FFD700',
}, },
} satisfies ChartConfig; } satisfies ChartConfig;
export default function ModelLineChart({ data }: ModelLineChartProps) { export default function ModelLineChart({ data, currentEpoch, totalEpochs, remainingTime }: ModelLineChartProps) {
const emptyData = Array.from({ length: totalEpochs || 0 }, (_, i) => ({
epoch: (i + 1).toString(),
loss1: null,
loss2: null,
loss3: null,
fitness: null,
}));
const filledData = emptyData.map((item, index) => ({
...item,
...(data[index] || {}),
}));
return ( return (
<Card> <Card>
<CardHeader> <CardHeader>
<CardTitle>Model Training Metrics</CardTitle> <CardTitle>Model Training Metrics</CardTitle>
<CardDescription>Loss and Fitness over Epochs</CardDescription>
</CardHeader> </CardHeader>
<CardContent> <CardContent>
{currentEpoch !== undefined && totalEpochs !== undefined && remainingTime !== undefined && (
<div className="mb-4 flex justify-between">
<p> : {currentEpoch}</p>
<p> : {totalEpochs}</p>
<p> : {remainingTime}</p>
</div>
)}
<ChartContainer config={chartConfig}> <ChartContainer config={chartConfig}>
<LineChart <LineChart
accessibilityLayer accessibilityLayer
data={data} data={filledData}
margin={{ margin={{
left: 12, left: 12,
right: 12, right: 12,
@ -62,10 +82,9 @@ export default function ModelLineChart({ data }: ModelLineChartProps) {
tickMargin={8} tickMargin={8}
tickFormatter={(value) => `Epoch ${value}`} tickFormatter={(value) => `Epoch ${value}`}
/> />
<ChartTooltip <YAxis />
cursor={false} <Tooltip />
content={<ChartTooltipContent />} <Legend />
/>
<Line <Line
dataKey="loss1" dataKey="loss1"
type="monotone" type="monotone"
@ -97,18 +116,6 @@ export default function ModelLineChart({ data }: ModelLineChartProps) {
</LineChart> </LineChart>
</ChartContainer> </ChartContainer>
</CardContent> </CardContent>
<CardFooter>
<div className="flex w-full items-start gap-2 text-sm">
<div className="grid gap-2">
<div className="flex items-center gap-2 font-medium leading-none">
Trending up by 5.2% this epoch <TrendingUp className="h-4 w-4" />
</div>
<div className="text-muted-foreground flex items-center gap-2 leading-none">
Showing training loss and fitness for the current model
</div>
</div>
</div>
</CardFooter>
</Card> </Card>
); );
} }

View File

@ -0,0 +1,42 @@
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
import { Label } from '@/components/ui/label';
interface SelectWithLabelOption {
label: string;
value: string;
}
interface SelectWithLabelProps {
label: string;
id: string;
options: SelectWithLabelOption[];
placeholder: string;
value: string;
onChange: (value: string) => void;
}
export default function SelectWithLabel({ label, id, options, placeholder, value, onChange }: SelectWithLabelProps) {
return (
<div className="grid gap-3">
<Label htmlFor={id}>{label}</Label>
<Select
value={value}
onValueChange={onChange}
>
<SelectTrigger id={id}>
<SelectValue placeholder={placeholder} />
</SelectTrigger>
<SelectContent>
{options.map((option) => (
<SelectItem
key={option.value}
value={option.value}
>
{option.label}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
);
}

View File

@ -1,189 +0,0 @@
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import { useState } from 'react';
interface SettingsFormProps {
projectId: string | null; // projectId를 프랍으로 받음
onSubmit?: (data: SettingsFormData) => void;
}
export interface SettingsFormData {
projectId: number | null;
selectedModel: string | null;
ratio: number;
epochs: number;
batchSize: number;
optimizer: string;
lr0: number;
lrf: number;
}
export default function SettingsForm({ projectId, onSubmit }: SettingsFormProps) {
const numericProjectId = projectId ? parseInt(projectId, 10) : null;
const { data: models } = useProjectModelsQuery(numericProjectId ?? 0);
const [selectedModel, setSelectedModel] = useState<string | null>(null);
const [ratio, setRatio] = useState<number>(0.8);
const [epochs, setEpochs] = useState<number>(50);
const [batchSize, setBatchSize] = useState<number>(32);
const [optimizer, setOptimizer] = useState<string>('SGD');
const [lr0, setLr0] = useState<number>(0.01);
const [lrf, setLrf] = useState<number>(0.001);
const handleSubmit = () => {
if (onSubmit) {
onSubmit({
projectId: numericProjectId,
selectedModel,
ratio,
epochs,
batchSize,
optimizer,
lr0,
lrf,
});
}
};
return (
<form
className="grid w-full gap-6"
onSubmit={handleSubmit}
>
<fieldset className="grid gap-6 rounded-lg border p-4">
<legend className="-ml-1 px-1 text-sm font-medium"> </legend>
{/* 모델 선택 */}
<div className="grid gap-3">
<Label htmlFor="model"> </Label>
<Select onValueChange={setSelectedModel}>
<SelectTrigger id="model">
<SelectValue placeholder="모델을 선택하세요" />
</SelectTrigger>
<SelectContent>
{models?.map((model) => (
<SelectItem
key={model.id}
value={model.name}
>
{model.name}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
{/* 훈련/검증 비율 및 학습 파라미터 */}
<div className="grid grid-cols-2 gap-4">
<InputWithLabel
label="훈련/검증 비율"
placeholder="예: 0.8 (80% 훈련, 20% 검증)"
id="ratio"
value={ratio}
onChange={(e) => setRatio(parseFloat(e.target.value))}
/>
<InputWithLabel
label="에포크 수"
placeholder="예: 50 (총 반복 횟수)"
id="epochs"
value={epochs}
onChange={(e) => setEpochs(parseInt(e.target.value, 10))}
/>
<InputWithLabel
label="Batch 크기"
placeholder="예: 32 (한번에 처리할 샘플 수)"
id="batch"
value={batchSize}
onChange={(e) => setBatchSize(parseInt(e.target.value, 10))}
/>
<SelectWithLabel
label="옵티마이저"
id="optimizer"
options={['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']}
value={optimizer}
onChange={setOptimizer}
placeholder="옵티마이저 선택"
/>
<InputWithLabel
label="학습률(LR0)"
placeholder="예: 0.01 (초기 학습률)"
id="lr0"
value={lr0}
onChange={(e) => setLr0(parseFloat(e.target.value))}
/>
<InputWithLabel
label="최종 학습률(LRF)"
placeholder="예: 0.001 (최종 학습률)"
id="lrf"
value={lrf}
onChange={(e) => setLrf(parseFloat(e.target.value))}
/>
</div>
<button
type="submit"
className="btn-primary"
>
</button>
</fieldset>
</form>
);
}
interface InputWithLabelProps {
label: string;
id: string;
placeholder: string;
value: number;
onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
}
function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) {
return (
<div className="grid gap-3">
<Label htmlFor={id}>{label}</Label>
<Input
id={id}
type="number"
placeholder={placeholder}
value={value}
onChange={onChange}
/>
</div>
);
}
interface SelectWithLabelProps {
label: string;
id: string;
options: string[];
placeholder: string;
value: string;
onChange: (value: string) => void;
}
function SelectWithLabel({ label, id, options, placeholder, onChange }: SelectWithLabelProps) {
return (
<div className="grid gap-3">
<Label htmlFor={id}>{label}</Label>
<Select onValueChange={onChange}>
<SelectTrigger id={id}>
<SelectValue placeholder={placeholder} />
</SelectTrigger>
<SelectContent>
{options.map((option) => (
<SelectItem
key={option}
value={option}
>
{option}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
);
}

View File

@ -0,0 +1,59 @@
import { useEffect, useMemo } from 'react';
import ModelLineChart from './ModelLineChart';
import usePollingModelReportsQuery from '@/queries/models/usePollingModelReportsQuery';
import useModelStore from '@/stores/useModelStore';
interface TrainingGraphProps {
projectId: number | null;
selectedModel: number | null;
}
export default function TrainingGraph({ projectId, selectedModel }: TrainingGraphProps) {
const { isTrainingByProject, setIsTraining, resetTrainingData } = useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
setIsTraining: state.setIsTraining,
resetTrainingData: state.resetTrainingData,
}));
const isTraining = isTrainingByProject[projectId?.toString() || ''] || false;
const { data: trainingDataList } = usePollingModelReportsQuery(
projectId as number,
selectedModel ?? 0,
isTraining && !!projectId && !!selectedModel
);
const latestData = useMemo(() => {
return (
trainingDataList?.[trainingDataList.length - 1] || {
epoch: 0,
totalEpochs: 0,
leftSecond: 0,
}
);
}, [trainingDataList]);
useEffect(() => {
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
setIsTraining(projectId?.toString() || '', false);
resetTrainingData(projectId?.toString() || '');
}
}, [latestData.epoch, latestData.totalEpochs, setIsTraining, resetTrainingData, projectId]);
return (
<ModelLineChart
data={
trainingDataList?.map((data) => ({
epoch: data.epoch.toString(),
loss1: data.boxLoss,
loss2: data.clsLoss,
loss3: data.dflLoss,
fitness: data.fitness,
})) || []
}
currentEpoch={latestData.epoch}
totalEpochs={latestData.totalEpochs}
remainingTime={latestData.leftSecond}
/>
);
}

View File

@ -0,0 +1,138 @@
import SelectWithLabel from './SelectWithLabel';
import InputWithLabel from './InputWithLabel';
import { Button } from '@/components/ui/button';
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
import { ModelTrainRequest } from '@/types';
import { useState } from 'react';
interface TrainingSettingsProps {
projectId: number | null;
selectedModel: number | null;
setSelectedModel: (model: number | null) => void;
handleTrainingStart: (trainData: ModelTrainRequest) => void;
handleTrainingStop: () => void;
isTraining: boolean;
}
export default function TrainingSettings({
projectId,
selectedModel,
setSelectedModel,
handleTrainingStart,
handleTrainingStop,
isTraining,
}: TrainingSettingsProps) {
const { data: models } = useProjectModelsQuery(projectId ?? 0);
const [ratio, setRatio] = useState<number>(0.8);
const [epochs, setEpochs] = useState<number>(50);
const [batchSize, setBatchSize] = useState<number>(32);
const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO');
const [lr0, setLr0] = useState<number>(0.01);
const [lrf, setLrf] = useState<number>(0.001);
const handleSubmit = () => {
if (isTraining) {
handleTrainingStop();
} else if (selectedModel !== null) {
const trainData: ModelTrainRequest = {
modelId: selectedModel,
ratio,
epochs,
batch: batchSize,
optimizer,
lr0,
lrf,
};
handleTrainingStart(trainData);
}
};
return (
<fieldset
className="grid gap-6 rounded-lg border p-4"
disabled={isTraining}
>
<legend className="-ml-1 px-1 text-sm font-medium"> </legend>
<div className="grid gap-3">
<SelectWithLabel
label="모델 선택"
id="model"
options={
models?.map((model) => ({
label: model.name,
value: model.id.toString(),
})) || []
}
placeholder="모델을 선택하세요"
value={selectedModel ? selectedModel.toString() : ''}
onChange={(value) => setSelectedModel(parseInt(value, 10))}
/>
</div>
<div className="grid grid-cols-2 gap-4">
<InputWithLabel
label="훈련/검증 비율"
placeholder="예: 0.8 (80% 훈련, 20% 검증)"
id="ratio"
value={ratio}
onChange={(e) => setRatio(parseFloat(e.target.value))}
/>
<InputWithLabel
label="에포크 수"
placeholder="예: 50 (총 반복 횟수)"
id="epochs"
value={epochs}
onChange={(e) => setEpochs(parseInt(e.target.value, 10))}
/>
<InputWithLabel
label="Batch 크기"
placeholder="예: 32 (한번에 처리할 샘플 수)"
id="batch"
value={batchSize}
onChange={(e) => setBatchSize(parseInt(e.target.value, 10))}
/>
<SelectWithLabel
label="옵티마이저"
id="optimizer"
options={[
{ label: 'AUTO', value: 'AUTO' },
{ label: 'SGD', value: 'SGD' },
{ label: 'ADAM', value: 'ADAM' },
{ label: 'ADAMW', value: 'ADAMW' },
{ label: 'NADAM', value: 'NADAM' },
{ label: 'RADAM', value: 'RADAM' },
{ label: 'RMSPROP', value: 'RMSPROP' },
]}
placeholder="옵티마이저 선택"
value={optimizer}
onChange={(value) => setOptimizer(value as 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP')}
/>
<InputWithLabel
label="학습률(LR0)"
placeholder="예: 0.01 (초기 학습률)"
id="lr0"
value={lr0}
onChange={(e) => setLr0(parseFloat(e.target.value))}
/>
<InputWithLabel
label="최종 학습률(LRF)"
placeholder="예: 0.001 (최종 학습률)"
id="lrf"
value={lrf}
onChange={(e) => setLrf(parseFloat(e.target.value))}
/>
</div>
<Button
variant="outlinePrimary"
size="lg"
onClick={handleSubmit}
disabled={!selectedModel || isTraining}
>
{isTraining ? '학습 중단' : '학습 시작'}
</Button>
</fieldset>
);
}

View File

@ -1,45 +1,58 @@
import { Button } from '@/components/ui/button'; import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
import ModelLineChart from '@/components/ModelLineChart'; import useModelStore from '@/stores/useModelStore';
import SettingsForm from './SettingsForm'; import TrainingSettings from './TrainingSettings';
import TrainingGraph from './TrainingGraph';
import { ModelTrainRequest } from '@/types';
interface TrainingTabProps { interface TrainingTabProps {
training: boolean; projectId: number | null;
handleTrainingToggle: () => void;
trainingDataList: {
epoch: number;
box_loss: number;
cls_loss: number;
dfl_loss: number;
fitness: number;
}[];
projectId: string | null; // projectId를 프랍으로 받음
} }
export default function TrainingTab({ training, handleTrainingToggle, trainingDataList, projectId }: TrainingTabProps) { export default function TrainingTab({ projectId }: TrainingTabProps) {
const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null;
const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, resetTrainingData } =
useModelStore((state) => ({
isTrainingByProject: state.isTrainingByProject,
setIsTraining: state.setIsTraining,
selectedModelByProject: state.selectedModelByProject,
setSelectedModel: state.setSelectedModel,
resetTrainingData: state.resetTrainingData,
}));
const isTraining = isTrainingByProject[numericProjectId?.toString() || ''] || false;
const selectedModel = selectedModelByProject[numericProjectId?.toString() || ''];
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
const handleTrainingStart = (trainData: ModelTrainRequest) => {
if (!isTraining && selectedModel !== null) {
setIsTraining(numericProjectId?.toString() || '', true);
startTraining(trainData);
}
};
const handleTrainingStop = () => {
if (isTraining) {
setIsTraining(numericProjectId?.toString() || '', false);
resetTrainingData(numericProjectId?.toString() || '');
}
};
return ( return (
<div className="grid gap-8 md:grid-cols-2"> <div className="grid gap-8 md:grid-cols-2">
<div className="flex flex-col gap-6"> <TrainingSettings
<SettingsForm projectId={projectId} /> projectId={numericProjectId}
<Button selectedModel={selectedModel}
variant={training ? 'destructive' : 'outlinePrimary'} setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)}
size="lg" handleTrainingStart={handleTrainingStart}
onClick={handleTrainingToggle} handleTrainingStop={handleTrainingStop}
> isTraining={isTraining}
{training ? '학습 중단' : '학습 시작'} />
</Button>
</div> <TrainingGraph
projectId={numericProjectId}
<div className="flex flex-col justify-center"> selectedModel={selectedModel}
<ModelLineChart
data={trainingDataList.map((data) => ({
epoch: data.epoch.toString(),
loss1: data.box_loss,
loss2: data.cls_loss,
loss3: data.dfl_loss,
fitness: data.fitness,
}))}
/> />
</div>
</div> </div>
); );
} }

View File

@ -1,28 +1,11 @@
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
import { useState } from 'react';
import { useParams } from 'react-router-dom'; import { useParams } from 'react-router-dom';
import useTrainWebSocket from '@/hooks/useTrainPolling';
import useTrainStore from '@/stores/useTrainStore';
import TrainingTab from './TrainingTab'; import TrainingTab from './TrainingTab';
import EvaluationTab from './EvaluationTab'; import EvaluationTab from './EvaluationTab';
export default function ModelManage() { export default function ModelManage() {
const { projectId } = useParams<{ projectId?: string }>(); const { projectId } = useParams<{ projectId?: string }>();
const [training, setTraining] = useState(false); const numericProjectId = projectId ? parseInt(projectId, 10) : null;
const [selectedModel, setSelectedModel] = useState<string | null>(null);
const numericProjectId = projectId ?? null;
useTrainWebSocket(training, numericProjectId);
const { trainingDataList } = useTrainStore((state) => ({
trainingDataList: numericProjectId ? state.trainingDataByProject[numericProjectId] || [] : [],
}));
const handleTrainingToggle = () => {
setTraining((prev) => !prev);
};
return ( return (
<div className="grid h-screen w-full"> <div className="grid h-screen w-full">
@ -41,22 +24,12 @@ export default function ModelManage() {
<TabsTrigger value="results"> </TabsTrigger> <TabsTrigger value="results"> </TabsTrigger>
</TabsList> </TabsList>
{/* 학습 탭 */}
<TabsContent value="train"> <TabsContent value="train">
<TrainingTab <TrainingTab projectId={numericProjectId} />
training={training}
handleTrainingToggle={handleTrainingToggle}
trainingDataList={trainingDataList}
projectId={numericProjectId}
/>
</TabsContent> </TabsContent>
{/* 평가 탭 */}
<TabsContent value="results"> <TabsContent value="results">
<EvaluationTab <EvaluationTab projectId={numericProjectId} />
selectedModel={selectedModel}
setSelectedModel={setSelectedModel}
/>
</TabsContent> </TabsContent>
</Tabs> </Tabs>
</main> </main>

View File

@ -1,49 +0,0 @@
// 임시 가짜 훅
import { useEffect, useRef, useCallback } from 'react';
import axios from 'axios';
import useTrainStore from '@/stores/useTrainStore';
export default function useTrainPolling(start: boolean, projectId?: string | null) {
const { addTrainingData, resetTrainingData } = useTrainStore((state) => ({
addTrainingData: state.addTrainingData,
resetTrainingData: state.resetTrainingData,
}));
const intervalIdRef = useRef<number | null>(null);
// 함수 api 후 교체 예정
const fetchTrainingData = useCallback(async () => {
if (projectId) {
try {
const response = await axios.get(`/api/바보=${projectId}`);
const data = response.data;
addTrainingData(projectId, {
epoch: data.epoch,
total_epochs: data.total_epochs,
box_loss: data.box_loss,
cls_loss: data.cls_loss,
dfl_loss: data.dfl_loss,
fitness: data.fitness,
epoch_time: data.epoch_time,
left_second: data.left_second,
});
} catch (error) {
console.error('Fetching error:', error);
}
}
}, [projectId, addTrainingData]);
useEffect(() => {
if (start && projectId) {
resetTrainingData(projectId);
intervalIdRef.current = window.setInterval(fetchTrainingData, 5000);
}
return () => {
if (intervalIdRef.current) {
clearInterval(intervalIdRef.current);
intervalIdRef.current = null;
}
};
}, [start, projectId, fetchTrainingData, resetTrainingData]);
}

View File

@ -0,0 +1,10 @@
import { useSuspenseQuery } from '@tanstack/react-query';
import { getModelReports } from '@/api/modelApi';
import { ReportResponse } from '@/types';
export default function useModelReportsQuery(projectId: number, modelId: number) {
return useSuspenseQuery<ReportResponse[]>({
queryKey: ['modelReports', projectId, modelId],
queryFn: () => getModelReports(projectId, modelId),
});
}

View File

@ -0,0 +1,10 @@
import { useSuspenseQuery } from '@tanstack/react-query';
import { getModelResults } from '@/api/modelApi';
import { ResultResponse } from '@/types';
export default function useModelResultsQuery(modelId: number) {
return useSuspenseQuery<ResultResponse[]>({
queryKey: ['modelResults', modelId],
queryFn: () => getModelResults(modelId),
});
}

View File

@ -0,0 +1,12 @@
import { useQuery } from '@tanstack/react-query';
import { getModelReports } from '@/api/modelApi';
import { ReportResponse } from '@/types';
export default function usePollingModelReportsQuery(projectId: number, modelId: number, enabled: boolean) {
return useQuery<ReportResponse[]>({
queryKey: ['pollingModelReports', projectId, modelId],
queryFn: () => getModelReports(projectId, modelId),
refetchInterval: 5000,
enabled,
});
}

View File

@ -1,8 +1,9 @@
import { useMutation } from '@tanstack/react-query'; import { useMutation } from '@tanstack/react-query';
import { trainModel } from '@/api/modelApi'; import { trainModel } from '@/api/modelApi';
import { ModelTrainRequest } from '@/types';
export default function useTrainModelQuery(projectId: number) { export default function useTrainModelQuery(projectId: number) {
return useMutation({ return useMutation({
mutationFn: () => trainModel(projectId), mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData),
}); });
} }

View File

@ -0,0 +1,56 @@
import { create } from 'zustand';
import { ReportResponse } from '@/types';
interface ModelStoreState {
trainingDataByProject: Record<string, ReportResponse[]>;
isTrainingByProject: Record<string, boolean>;
selectedModelByProject: Record<string, number | null>;
setIsTraining: (projectId: string, status: boolean) => void;
saveTrainingData: (projectId: string, data: ReportResponse[]) => void;
setSelectedModel: (projectId: string, modelId: number | null) => void;
resetTrainingData: (projectId: string) => void;
}
const useModelStore = create<ModelStoreState>((set) => ({
trainingDataByProject: {},
isTrainingByProject: {},
selectedModelByProject: {},
setIsTraining: (projectId, status) =>
set((state) => ({
isTrainingByProject: {
...state.isTrainingByProject,
[projectId]: status,
},
})),
saveTrainingData: (projectId, data) =>
set((state) => ({
trainingDataByProject: {
...state.trainingDataByProject,
[projectId]: data,
},
})),
setSelectedModel: (projectId, modelId) =>
set((state) => ({
selectedModelByProject: {
...state.selectedModelByProject,
[projectId]: modelId,
},
})),
resetTrainingData: (projectId) =>
set((state) => ({
trainingDataByProject: {
...state.trainingDataByProject,
[projectId]: [],
},
selectedModelByProject: {
...state.selectedModelByProject,
[projectId]: null,
},
isTrainingByProject: {
...state.isTrainingByProject,
[projectId]: false,
},
})),
}));
export default useModelStore;

View File

@ -1,40 +0,0 @@
import { create } from 'zustand';
interface TrainingData {
epoch: number;
total_epochs: number;
box_loss: number;
cls_loss: number;
dfl_loss: number;
fitness: number;
epoch_time: number;
left_second: number;
}
interface StoreState {
trainingDataByProject: { [projectId: string]: TrainingData[] };
addTrainingData: (projectId: string, data: TrainingData) => void;
resetTrainingData: (projectId: string) => void;
}
const useTrainStore = create<StoreState>((set) => ({
trainingDataByProject: {},
addTrainingData: (projectId: string, data: TrainingData) =>
set((state) => ({
trainingDataByProject: {
...state.trainingDataByProject,
[projectId]: [...(state.trainingDataByProject[projectId] || []), data],
},
})),
resetTrainingData: (projectId: string) =>
set((state) => ({
trainingDataByProject: {
...state.trainingDataByProject,
[projectId]: [],
},
})),
}));
export default useTrainStore;

View File

@ -277,6 +277,25 @@ export interface ImageFolderRequest {
parentId: number; parentId: number;
files: File[]; files: File[];
} }
export interface LabelCategoryResponse {
id: number;
name: string;
}
// 카테고리 요청 DTO
export interface LabelCategoryRequest {
labelCategoryList: number[];
}
// 카테고리 응답 DTO
export interface LabelCategoryResponse {
id: number;
name: string;
}
// 모델 카테고리 응답 DTO
export interface ModelCategoryResponse {
id: number;
name: string;
}
// 모델 요청 DTO (API로 전달할 데이터 타입) // 모델 요청 DTO (API로 전달할 데이터 타입)
export interface ModelRequest { export interface ModelRequest {
@ -289,22 +308,41 @@ export interface ModelResponse {
name: string; name: string;
} }
// 모델 카테고리 응답 DTO
export interface ModelCategoryResponse {
id: number;
name: string;
}
// 프로젝트 모델 리스트 응답 DTO // 프로젝트 모델 리스트 응답 DTO
export interface ProjectModelsResponse extends Array<ModelResponse> {} export interface ProjectModelsResponse extends Array<ModelResponse> {}
// 모델 훈련 요청 DTO
// 카테고리 요청 DTO export interface ModelTrainRequest {
export interface LabelCategoryRequest { modelId: number;
labelCategoryList: number[]; ratio: number;
epochs: number;
batch: number;
lr0: number;
lrf: number;
optimizer: 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP';
} }
export interface ResultResponse {
// 카테고리 응답 DTO
export interface LabelCategoryResponse {
id: number; id: number;
name: string; precision: number;
recall: number;
fitness: number;
ratio: number;
epochs: number;
batch: number;
lr0: number;
lrf: number;
optimizer: 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP';
map50: number;
map5095: number;
}
export interface ReportResponse {
modelId: number;
totalEpochs: number;
epoch: number;
boxLoss: number;
clsLoss: number;
dflLoss: number;
fitness: number;
epochTime: number;
leftSecond: number;
} }