Merge branch 'fe/refactor/admin-model' into 'fe/develop'
Refactor: 학습, 평가 부분 api 연결 See merge request s11-s-project/S11P21S002!163
This commit is contained in:
commit
d743cf3ce9
@ -1,12 +1,20 @@
|
|||||||
import api from '@/api/axiosConfig';
|
import api from '@/api/axiosConfig';
|
||||||
import { ModelRequest, ModelResponse, ProjectModelsResponse, ModelCategoryResponse } from '@/types';
|
import {
|
||||||
|
ModelRequest,
|
||||||
|
ModelResponse,
|
||||||
|
ProjectModelsResponse,
|
||||||
|
ModelCategoryResponse,
|
||||||
|
ModelTrainRequest,
|
||||||
|
ResultResponse,
|
||||||
|
ReportResponse,
|
||||||
|
} from '@/types';
|
||||||
|
|
||||||
export async function updateModelName(projectId: number, modelId: number, modelData: ModelRequest) {
|
export async function updateModelName(projectId: number, modelId: number, modelData: ModelRequest) {
|
||||||
return api.put<ModelResponse>(`/projects/${projectId}/models/${modelId}`, modelData).then(({ data }) => data);
|
return api.put<ModelResponse>(`/projects/${projectId}/models/${modelId}`, modelData).then(({ data }) => data);
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function trainModel(projectId: number) {
|
export async function trainModel(projectId: number, trainData: ModelTrainRequest) {
|
||||||
return api.post(`/projects/${projectId}/train`).then(({ data }) => data);
|
return api.post(`/projects/${projectId}/train`, trainData).then(({ data }) => data);
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function getProjectModels(projectId: number) {
|
export async function getProjectModels(projectId: number) {
|
||||||
@ -20,3 +28,11 @@ export async function addProjectModel(projectId: number, modelData: ModelRequest
|
|||||||
export async function getModelCategories(modelId: number) {
|
export async function getModelCategories(modelId: number) {
|
||||||
return api.get<ModelCategoryResponse[]>(`/models/${modelId}/categories`).then(({ data }) => data);
|
return api.get<ModelCategoryResponse[]>(`/models/${modelId}/categories`).then(({ data }) => data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function getModelResults(modelId: number) {
|
||||||
|
return api.get<ResultResponse[]>(`/results/model/${modelId}`).then(({ data }) => data);
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getModelReports(projectId: number, modelId: number) {
|
||||||
|
return api.get<ReportResponse[]>(`/projects/${projectId}/reports/model/${modelId}`).then(({ data }) => data);
|
||||||
|
}
|
||||||
|
@ -1,55 +1,116 @@
|
|||||||
import { Label } from '@/components/ui/label';
|
import { Label } from '@/components/ui/label';
|
||||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
|
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
|
||||||
import ModelBarChart from '@/components/ModelBarChart';
|
import { useState } from 'react';
|
||||||
|
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
|
||||||
|
import useModelReportsQuery from '@/queries/models/useModelReportsQuery';
|
||||||
|
import useModelResultsQuery from '@/queries/models/useModelResultsQuery';
|
||||||
|
import ModelBarChart from './ModelBarChart';
|
||||||
|
import ModelLineChart from './ModelLineChart';
|
||||||
|
|
||||||
interface EvaluationTabProps {
|
interface EvaluationTabProps {
|
||||||
selectedModel: string | null;
|
projectId: number | null;
|
||||||
setSelectedModel: (model: string | null) => void;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function EvaluationTab({ selectedModel, setSelectedModel }: EvaluationTabProps) {
|
export default function EvaluationTab({ projectId }: EvaluationTabProps) {
|
||||||
|
const [selectedModel, setSelectedModel] = useState<number | null>(null);
|
||||||
|
|
||||||
|
const { data: models } = useProjectModelsQuery(projectId ?? 0);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
<div className="mb-4">
|
<ModelSelection
|
||||||
<Label htmlFor="select-model">모델 선택</Label>
|
models={models}
|
||||||
<Select onValueChange={setSelectedModel}>
|
setSelectedModel={setSelectedModel}
|
||||||
<SelectTrigger id="select-model">
|
/>
|
||||||
<SelectValue placeholder="모델을 선택하세요" />
|
|
||||||
</SelectTrigger>
|
|
||||||
<SelectContent>
|
|
||||||
<SelectItem value="genesis">Genesis</SelectItem>
|
|
||||||
<SelectItem value="explorer">Explorer</SelectItem>
|
|
||||||
<SelectItem value="quantum">Quantum</SelectItem>
|
|
||||||
</SelectContent>
|
|
||||||
</Select>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{selectedModel && (
|
{selectedModel && (
|
||||||
<div className="grid gap-8 md:grid-cols-2">
|
<ModelEvaluation
|
||||||
<div className="flex flex-col gap-6">
|
projectId={projectId as number}
|
||||||
<ModelBarChart
|
selectedModel={selectedModel}
|
||||||
data={[
|
|
||||||
{ name: 'precision', value: 0.734, fill: 'var(--color-precision)' },
|
|
||||||
{ name: 'recall', value: 0.75, fill: 'var(--color-recall)' },
|
|
||||||
{ name: 'mAP50', value: 0.995, fill: 'var(--color-map50)' },
|
|
||||||
{ name: 'mAP50_95', value: 0.97, fill: 'var(--color-map50-95)' },
|
|
||||||
{ name: 'fitness', value: 0.973, fill: 'var(--color-fitness)' },
|
|
||||||
]}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
|
||||||
<div className="flex flex-col justify-center">
|
|
||||||
<LabelingPreview />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
function LabelingPreview() {
|
interface ModelSelectionProps {
|
||||||
|
models: Array<{ id: number; name: string }> | undefined;
|
||||||
|
setSelectedModel: (modelId: number) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
function ModelSelection({ models, setSelectedModel }: ModelSelectionProps) {
|
||||||
return (
|
return (
|
||||||
<div className="flex items-center justify-center rounded-lg border bg-white p-4">
|
<div className="mb-4">
|
||||||
<p>레이블링 프리뷰</p>
|
<Label htmlFor="select-model">모델 선택</Label>
|
||||||
|
<Select onValueChange={(value) => setSelectedModel(parseInt(value))}>
|
||||||
|
<SelectTrigger id="select-model">
|
||||||
|
<SelectValue placeholder="모델을 선택하세요" />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent>
|
||||||
|
{models?.map((model) => (
|
||||||
|
<SelectItem
|
||||||
|
key={model.id}
|
||||||
|
value={model.id.toString()}
|
||||||
|
>
|
||||||
|
{model.name}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface ModelEvaluationProps {
|
||||||
|
projectId: number;
|
||||||
|
selectedModel: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
function ModelEvaluation({ projectId, selectedModel }: ModelEvaluationProps) {
|
||||||
|
const { data: reportData } = useModelReportsQuery(projectId, selectedModel);
|
||||||
|
const { data: resultData } = useModelResultsQuery(selectedModel);
|
||||||
|
|
||||||
|
if (!reportData || !resultData) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="grid gap-8 md:grid-cols-2">
|
||||||
|
<div className="flex flex-col gap-6">
|
||||||
|
<ModelBarChart
|
||||||
|
data={[
|
||||||
|
{ name: 'precision', value: resultData[0]?.precision, fill: 'var(--color-precision)' },
|
||||||
|
{ name: 'recall', value: resultData[0]?.recall, fill: 'var(--color-recall)' },
|
||||||
|
{ name: 'mAP50', value: resultData[0]?.map50, fill: 'var(--color-map50)' },
|
||||||
|
{ name: 'mAP50_95', value: resultData[0]?.map5095, fill: 'var(--color-map50-95)' },
|
||||||
|
{ name: 'fitness', value: resultData[0]?.fitness, fill: 'var(--color-fitness)' },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex flex-col gap-6">
|
||||||
|
<ModelLineChart
|
||||||
|
data={reportData.map((report) => ({
|
||||||
|
epoch: report.epoch.toString(),
|
||||||
|
loss1: report.boxLoss,
|
||||||
|
loss2: report.clsLoss,
|
||||||
|
loss3: report.dflLoss,
|
||||||
|
fitness: report.fitness,
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* <div className="flex flex-col justify-center">
|
||||||
|
<LabelingPreview />
|
||||||
|
</div> */}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// function LabelingPreview() {
|
||||||
|
// return (
|
||||||
|
// <div className="flex items-center justify-center rounded-lg border bg-white p-4">
|
||||||
|
// <p>레이블링 프리뷰</p>
|
||||||
|
// </div>
|
||||||
|
// );
|
||||||
|
// }
|
||||||
|
24
frontend/src/components/ModelManage/InputWithLabel.tsx
Normal file
24
frontend/src/components/ModelManage/InputWithLabel.tsx
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import { Label } from '@/components/ui/label';
|
||||||
|
import { Input } from '../ui/input';
|
||||||
|
interface InputWithLabelProps {
|
||||||
|
label: string;
|
||||||
|
id: string;
|
||||||
|
placeholder: string;
|
||||||
|
value: number;
|
||||||
|
onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) {
|
||||||
|
return (
|
||||||
|
<div className="grid gap-3">
|
||||||
|
<Label htmlFor={id}>{label}</Label>
|
||||||
|
<Input
|
||||||
|
id={id}
|
||||||
|
type="number"
|
||||||
|
placeholder={placeholder}
|
||||||
|
value={value}
|
||||||
|
onChange={onChange}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
@ -1,54 +1,74 @@
|
|||||||
'use client';
|
'use client';
|
||||||
|
|
||||||
import { TrendingUp } from 'lucide-react';
|
import { CartesianGrid, Line, LineChart, XAxis, YAxis, Tooltip, Legend } from 'recharts';
|
||||||
import { CartesianGrid, Line, LineChart, XAxis } from 'recharts';
|
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||||
|
import { ChartConfig, ChartContainer } from '@/components/ui/chart';
|
||||||
import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from '@/components/ui/card';
|
|
||||||
import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart';
|
|
||||||
|
|
||||||
interface MetricData {
|
interface MetricData {
|
||||||
epoch: string;
|
epoch: string;
|
||||||
loss1: number;
|
loss1?: number;
|
||||||
loss2: number;
|
loss2?: number;
|
||||||
loss3: number;
|
loss3?: number;
|
||||||
fitness: number;
|
fitness?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ModelLineChartProps {
|
interface ModelLineChartProps {
|
||||||
data: MetricData[];
|
data: MetricData[];
|
||||||
|
currentEpoch?: number;
|
||||||
|
totalEpochs?: number;
|
||||||
|
remainingTime?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
const chartConfig = {
|
const chartConfig = {
|
||||||
loss1: {
|
loss1: {
|
||||||
label: 'Loss 1',
|
label: 'Loss 1',
|
||||||
color: '#FF6347', // 토마토색
|
color: '#FF6347',
|
||||||
},
|
},
|
||||||
loss2: {
|
loss2: {
|
||||||
label: 'Loss 2',
|
label: 'Loss 2',
|
||||||
color: '#1E90FF', // 다저블루색
|
color: '#1E90FF',
|
||||||
},
|
},
|
||||||
loss3: {
|
loss3: {
|
||||||
label: 'Loss 3',
|
label: 'Loss 3',
|
||||||
color: '#32CD32', // 라임색
|
color: '#32CD32',
|
||||||
},
|
},
|
||||||
fitness: {
|
fitness: {
|
||||||
label: 'Fitness',
|
label: 'Fitness',
|
||||||
color: '#FFD700', // 골드색
|
color: '#FFD700',
|
||||||
},
|
},
|
||||||
} satisfies ChartConfig;
|
} satisfies ChartConfig;
|
||||||
|
|
||||||
export default function ModelLineChart({ data }: ModelLineChartProps) {
|
export default function ModelLineChart({ data, currentEpoch, totalEpochs, remainingTime }: ModelLineChartProps) {
|
||||||
|
const emptyData = Array.from({ length: totalEpochs || 0 }, (_, i) => ({
|
||||||
|
epoch: (i + 1).toString(),
|
||||||
|
loss1: null,
|
||||||
|
loss2: null,
|
||||||
|
loss3: null,
|
||||||
|
fitness: null,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const filledData = emptyData.map((item, index) => ({
|
||||||
|
...item,
|
||||||
|
...(data[index] || {}),
|
||||||
|
}));
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle>Model Training Metrics</CardTitle>
|
<CardTitle>Model Training Metrics</CardTitle>
|
||||||
<CardDescription>Loss and Fitness over Epochs</CardDescription>
|
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
|
{currentEpoch !== undefined && totalEpochs !== undefined && remainingTime !== undefined && (
|
||||||
|
<div className="mb-4 flex justify-between">
|
||||||
|
<p>현재 에포크: {currentEpoch}</p>
|
||||||
|
<p>총 에포크: {totalEpochs}</p>
|
||||||
|
<p>예상 남은시간: {remainingTime}</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
<ChartContainer config={chartConfig}>
|
<ChartContainer config={chartConfig}>
|
||||||
<LineChart
|
<LineChart
|
||||||
accessibilityLayer
|
accessibilityLayer
|
||||||
data={data}
|
data={filledData}
|
||||||
margin={{
|
margin={{
|
||||||
left: 12,
|
left: 12,
|
||||||
right: 12,
|
right: 12,
|
||||||
@ -62,10 +82,9 @@ export default function ModelLineChart({ data }: ModelLineChartProps) {
|
|||||||
tickMargin={8}
|
tickMargin={8}
|
||||||
tickFormatter={(value) => `Epoch ${value}`}
|
tickFormatter={(value) => `Epoch ${value}`}
|
||||||
/>
|
/>
|
||||||
<ChartTooltip
|
<YAxis />
|
||||||
cursor={false}
|
<Tooltip />
|
||||||
content={<ChartTooltipContent />}
|
<Legend />
|
||||||
/>
|
|
||||||
<Line
|
<Line
|
||||||
dataKey="loss1"
|
dataKey="loss1"
|
||||||
type="monotone"
|
type="monotone"
|
||||||
@ -97,18 +116,6 @@ export default function ModelLineChart({ data }: ModelLineChartProps) {
|
|||||||
</LineChart>
|
</LineChart>
|
||||||
</ChartContainer>
|
</ChartContainer>
|
||||||
</CardContent>
|
</CardContent>
|
||||||
<CardFooter>
|
|
||||||
<div className="flex w-full items-start gap-2 text-sm">
|
|
||||||
<div className="grid gap-2">
|
|
||||||
<div className="flex items-center gap-2 font-medium leading-none">
|
|
||||||
Trending up by 5.2% this epoch <TrendingUp className="h-4 w-4" />
|
|
||||||
</div>
|
|
||||||
<div className="text-muted-foreground flex items-center gap-2 leading-none">
|
|
||||||
Showing training loss and fitness for the current model
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</CardFooter>
|
|
||||||
</Card>
|
</Card>
|
||||||
);
|
);
|
||||||
}
|
}
|
42
frontend/src/components/ModelManage/SelectWithLabel.tsx
Normal file
42
frontend/src/components/ModelManage/SelectWithLabel.tsx
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
|
||||||
|
import { Label } from '@/components/ui/label';
|
||||||
|
|
||||||
|
interface SelectWithLabelOption {
|
||||||
|
label: string;
|
||||||
|
value: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SelectWithLabelProps {
|
||||||
|
label: string;
|
||||||
|
id: string;
|
||||||
|
options: SelectWithLabelOption[];
|
||||||
|
placeholder: string;
|
||||||
|
value: string;
|
||||||
|
onChange: (value: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function SelectWithLabel({ label, id, options, placeholder, value, onChange }: SelectWithLabelProps) {
|
||||||
|
return (
|
||||||
|
<div className="grid gap-3">
|
||||||
|
<Label htmlFor={id}>{label}</Label>
|
||||||
|
<Select
|
||||||
|
value={value}
|
||||||
|
onValueChange={onChange}
|
||||||
|
>
|
||||||
|
<SelectTrigger id={id}>
|
||||||
|
<SelectValue placeholder={placeholder} />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent>
|
||||||
|
{options.map((option) => (
|
||||||
|
<SelectItem
|
||||||
|
key={option.value}
|
||||||
|
value={option.value}
|
||||||
|
>
|
||||||
|
{option.label}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
@ -1,189 +0,0 @@
|
|||||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
|
|
||||||
import { Input } from '@/components/ui/input';
|
|
||||||
import { Label } from '@/components/ui/label';
|
|
||||||
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
|
|
||||||
import { useState } from 'react';
|
|
||||||
|
|
||||||
interface SettingsFormProps {
|
|
||||||
projectId: string | null; // projectId를 프랍으로 받음
|
|
||||||
onSubmit?: (data: SettingsFormData) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface SettingsFormData {
|
|
||||||
projectId: number | null;
|
|
||||||
selectedModel: string | null;
|
|
||||||
ratio: number;
|
|
||||||
epochs: number;
|
|
||||||
batchSize: number;
|
|
||||||
optimizer: string;
|
|
||||||
lr0: number;
|
|
||||||
lrf: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
export default function SettingsForm({ projectId, onSubmit }: SettingsFormProps) {
|
|
||||||
const numericProjectId = projectId ? parseInt(projectId, 10) : null;
|
|
||||||
|
|
||||||
const { data: models } = useProjectModelsQuery(numericProjectId ?? 0);
|
|
||||||
const [selectedModel, setSelectedModel] = useState<string | null>(null);
|
|
||||||
const [ratio, setRatio] = useState<number>(0.8);
|
|
||||||
const [epochs, setEpochs] = useState<number>(50);
|
|
||||||
const [batchSize, setBatchSize] = useState<number>(32);
|
|
||||||
const [optimizer, setOptimizer] = useState<string>('SGD');
|
|
||||||
const [lr0, setLr0] = useState<number>(0.01);
|
|
||||||
const [lrf, setLrf] = useState<number>(0.001);
|
|
||||||
|
|
||||||
const handleSubmit = () => {
|
|
||||||
if (onSubmit) {
|
|
||||||
onSubmit({
|
|
||||||
projectId: numericProjectId,
|
|
||||||
selectedModel,
|
|
||||||
ratio,
|
|
||||||
epochs,
|
|
||||||
batchSize,
|
|
||||||
optimizer,
|
|
||||||
lr0,
|
|
||||||
lrf,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<form
|
|
||||||
className="grid w-full gap-6"
|
|
||||||
onSubmit={handleSubmit}
|
|
||||||
>
|
|
||||||
<fieldset className="grid gap-6 rounded-lg border p-4">
|
|
||||||
<legend className="-ml-1 px-1 text-sm font-medium">모델 설정</legend>
|
|
||||||
|
|
||||||
{/* 모델 선택 */}
|
|
||||||
<div className="grid gap-3">
|
|
||||||
<Label htmlFor="model">모델 선택</Label>
|
|
||||||
<Select onValueChange={setSelectedModel}>
|
|
||||||
<SelectTrigger id="model">
|
|
||||||
<SelectValue placeholder="모델을 선택하세요" />
|
|
||||||
</SelectTrigger>
|
|
||||||
<SelectContent>
|
|
||||||
{models?.map((model) => (
|
|
||||||
<SelectItem
|
|
||||||
key={model.id}
|
|
||||||
value={model.name}
|
|
||||||
>
|
|
||||||
{model.name}
|
|
||||||
</SelectItem>
|
|
||||||
))}
|
|
||||||
</SelectContent>
|
|
||||||
</Select>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* 훈련/검증 비율 및 학습 파라미터 */}
|
|
||||||
<div className="grid grid-cols-2 gap-4">
|
|
||||||
<InputWithLabel
|
|
||||||
label="훈련/검증 비율"
|
|
||||||
placeholder="예: 0.8 (80% 훈련, 20% 검증)"
|
|
||||||
id="ratio"
|
|
||||||
value={ratio}
|
|
||||||
onChange={(e) => setRatio(parseFloat(e.target.value))}
|
|
||||||
/>
|
|
||||||
<InputWithLabel
|
|
||||||
label="에포크 수"
|
|
||||||
placeholder="예: 50 (총 반복 횟수)"
|
|
||||||
id="epochs"
|
|
||||||
value={epochs}
|
|
||||||
onChange={(e) => setEpochs(parseInt(e.target.value, 10))}
|
|
||||||
/>
|
|
||||||
<InputWithLabel
|
|
||||||
label="Batch 크기"
|
|
||||||
placeholder="예: 32 (한번에 처리할 샘플 수)"
|
|
||||||
id="batch"
|
|
||||||
value={batchSize}
|
|
||||||
onChange={(e) => setBatchSize(parseInt(e.target.value, 10))}
|
|
||||||
/>
|
|
||||||
<SelectWithLabel
|
|
||||||
label="옵티마이저"
|
|
||||||
id="optimizer"
|
|
||||||
options={['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']}
|
|
||||||
value={optimizer}
|
|
||||||
onChange={setOptimizer}
|
|
||||||
placeholder="옵티마이저 선택"
|
|
||||||
/>
|
|
||||||
<InputWithLabel
|
|
||||||
label="학습률(LR0)"
|
|
||||||
placeholder="예: 0.01 (초기 학습률)"
|
|
||||||
id="lr0"
|
|
||||||
value={lr0}
|
|
||||||
onChange={(e) => setLr0(parseFloat(e.target.value))}
|
|
||||||
/>
|
|
||||||
<InputWithLabel
|
|
||||||
label="최종 학습률(LRF)"
|
|
||||||
placeholder="예: 0.001 (최종 학습률)"
|
|
||||||
id="lrf"
|
|
||||||
value={lrf}
|
|
||||||
onChange={(e) => setLrf(parseFloat(e.target.value))}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<button
|
|
||||||
type="submit"
|
|
||||||
className="btn-primary"
|
|
||||||
>
|
|
||||||
설정 저장
|
|
||||||
</button>
|
|
||||||
</fieldset>
|
|
||||||
</form>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
interface InputWithLabelProps {
|
|
||||||
label: string;
|
|
||||||
id: string;
|
|
||||||
placeholder: string;
|
|
||||||
value: number;
|
|
||||||
onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
function InputWithLabel({ label, id, placeholder, value, onChange }: InputWithLabelProps) {
|
|
||||||
return (
|
|
||||||
<div className="grid gap-3">
|
|
||||||
<Label htmlFor={id}>{label}</Label>
|
|
||||||
<Input
|
|
||||||
id={id}
|
|
||||||
type="number"
|
|
||||||
placeholder={placeholder}
|
|
||||||
value={value}
|
|
||||||
onChange={onChange}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
interface SelectWithLabelProps {
|
|
||||||
label: string;
|
|
||||||
id: string;
|
|
||||||
options: string[];
|
|
||||||
placeholder: string;
|
|
||||||
value: string;
|
|
||||||
onChange: (value: string) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
function SelectWithLabel({ label, id, options, placeholder, onChange }: SelectWithLabelProps) {
|
|
||||||
return (
|
|
||||||
<div className="grid gap-3">
|
|
||||||
<Label htmlFor={id}>{label}</Label>
|
|
||||||
<Select onValueChange={onChange}>
|
|
||||||
<SelectTrigger id={id}>
|
|
||||||
<SelectValue placeholder={placeholder} />
|
|
||||||
</SelectTrigger>
|
|
||||||
<SelectContent>
|
|
||||||
{options.map((option) => (
|
|
||||||
<SelectItem
|
|
||||||
key={option}
|
|
||||||
value={option}
|
|
||||||
>
|
|
||||||
{option}
|
|
||||||
</SelectItem>
|
|
||||||
))}
|
|
||||||
</SelectContent>
|
|
||||||
</Select>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
59
frontend/src/components/ModelManage/TrainingGraph.tsx
Normal file
59
frontend/src/components/ModelManage/TrainingGraph.tsx
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import { useEffect, useMemo } from 'react';
|
||||||
|
import ModelLineChart from './ModelLineChart';
|
||||||
|
import usePollingModelReportsQuery from '@/queries/models/usePollingModelReportsQuery';
|
||||||
|
import useModelStore from '@/stores/useModelStore';
|
||||||
|
|
||||||
|
interface TrainingGraphProps {
|
||||||
|
projectId: number | null;
|
||||||
|
selectedModel: number | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function TrainingGraph({ projectId, selectedModel }: TrainingGraphProps) {
|
||||||
|
const { isTrainingByProject, setIsTraining, resetTrainingData } = useModelStore((state) => ({
|
||||||
|
isTrainingByProject: state.isTrainingByProject,
|
||||||
|
setIsTraining: state.setIsTraining,
|
||||||
|
resetTrainingData: state.resetTrainingData,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const isTraining = isTrainingByProject[projectId?.toString() || ''] || false;
|
||||||
|
|
||||||
|
const { data: trainingDataList } = usePollingModelReportsQuery(
|
||||||
|
projectId as number,
|
||||||
|
selectedModel ?? 0,
|
||||||
|
isTraining && !!projectId && !!selectedModel
|
||||||
|
);
|
||||||
|
|
||||||
|
const latestData = useMemo(() => {
|
||||||
|
return (
|
||||||
|
trainingDataList?.[trainingDataList.length - 1] || {
|
||||||
|
epoch: 0,
|
||||||
|
totalEpochs: 0,
|
||||||
|
leftSecond: 0,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}, [trainingDataList]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (latestData.epoch === latestData.totalEpochs && latestData.totalEpochs > 0) {
|
||||||
|
setIsTraining(projectId?.toString() || '', false);
|
||||||
|
resetTrainingData(projectId?.toString() || '');
|
||||||
|
}
|
||||||
|
}, [latestData.epoch, latestData.totalEpochs, setIsTraining, resetTrainingData, projectId]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ModelLineChart
|
||||||
|
data={
|
||||||
|
trainingDataList?.map((data) => ({
|
||||||
|
epoch: data.epoch.toString(),
|
||||||
|
loss1: data.boxLoss,
|
||||||
|
loss2: data.clsLoss,
|
||||||
|
loss3: data.dflLoss,
|
||||||
|
fitness: data.fitness,
|
||||||
|
})) || []
|
||||||
|
}
|
||||||
|
currentEpoch={latestData.epoch}
|
||||||
|
totalEpochs={latestData.totalEpochs}
|
||||||
|
remainingTime={latestData.leftSecond}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
138
frontend/src/components/ModelManage/TrainingSettings.tsx
Normal file
138
frontend/src/components/ModelManage/TrainingSettings.tsx
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
import SelectWithLabel from './SelectWithLabel';
|
||||||
|
import InputWithLabel from './InputWithLabel';
|
||||||
|
import { Button } from '@/components/ui/button';
|
||||||
|
import useProjectModelsQuery from '@/queries/models/useProjectModelsQuery';
|
||||||
|
import { ModelTrainRequest } from '@/types';
|
||||||
|
import { useState } from 'react';
|
||||||
|
|
||||||
|
interface TrainingSettingsProps {
|
||||||
|
projectId: number | null;
|
||||||
|
selectedModel: number | null;
|
||||||
|
setSelectedModel: (model: number | null) => void;
|
||||||
|
handleTrainingStart: (trainData: ModelTrainRequest) => void;
|
||||||
|
handleTrainingStop: () => void;
|
||||||
|
isTraining: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function TrainingSettings({
|
||||||
|
projectId,
|
||||||
|
selectedModel,
|
||||||
|
setSelectedModel,
|
||||||
|
handleTrainingStart,
|
||||||
|
handleTrainingStop,
|
||||||
|
isTraining,
|
||||||
|
}: TrainingSettingsProps) {
|
||||||
|
const { data: models } = useProjectModelsQuery(projectId ?? 0);
|
||||||
|
|
||||||
|
const [ratio, setRatio] = useState<number>(0.8);
|
||||||
|
const [epochs, setEpochs] = useState<number>(50);
|
||||||
|
const [batchSize, setBatchSize] = useState<number>(32);
|
||||||
|
const [optimizer, setOptimizer] = useState<'SGD' | 'AUTO' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP'>('AUTO');
|
||||||
|
const [lr0, setLr0] = useState<number>(0.01);
|
||||||
|
const [lrf, setLrf] = useState<number>(0.001);
|
||||||
|
|
||||||
|
const handleSubmit = () => {
|
||||||
|
if (isTraining) {
|
||||||
|
handleTrainingStop();
|
||||||
|
} else if (selectedModel !== null) {
|
||||||
|
const trainData: ModelTrainRequest = {
|
||||||
|
modelId: selectedModel,
|
||||||
|
ratio,
|
||||||
|
epochs,
|
||||||
|
batch: batchSize,
|
||||||
|
optimizer,
|
||||||
|
lr0,
|
||||||
|
lrf,
|
||||||
|
};
|
||||||
|
handleTrainingStart(trainData);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<fieldset
|
||||||
|
className="grid gap-6 rounded-lg border p-4"
|
||||||
|
disabled={isTraining}
|
||||||
|
>
|
||||||
|
<legend className="-ml-1 px-1 text-sm font-medium">모델 설정</legend>
|
||||||
|
|
||||||
|
<div className="grid gap-3">
|
||||||
|
<SelectWithLabel
|
||||||
|
label="모델 선택"
|
||||||
|
id="model"
|
||||||
|
options={
|
||||||
|
models?.map((model) => ({
|
||||||
|
label: model.name,
|
||||||
|
value: model.id.toString(),
|
||||||
|
})) || []
|
||||||
|
}
|
||||||
|
placeholder="모델을 선택하세요"
|
||||||
|
value={selectedModel ? selectedModel.toString() : ''}
|
||||||
|
onChange={(value) => setSelectedModel(parseInt(value, 10))}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="grid grid-cols-2 gap-4">
|
||||||
|
<InputWithLabel
|
||||||
|
label="훈련/검증 비율"
|
||||||
|
placeholder="예: 0.8 (80% 훈련, 20% 검증)"
|
||||||
|
id="ratio"
|
||||||
|
value={ratio}
|
||||||
|
onChange={(e) => setRatio(parseFloat(e.target.value))}
|
||||||
|
/>
|
||||||
|
<InputWithLabel
|
||||||
|
label="에포크 수"
|
||||||
|
placeholder="예: 50 (총 반복 횟수)"
|
||||||
|
id="epochs"
|
||||||
|
value={epochs}
|
||||||
|
onChange={(e) => setEpochs(parseInt(e.target.value, 10))}
|
||||||
|
/>
|
||||||
|
<InputWithLabel
|
||||||
|
label="Batch 크기"
|
||||||
|
placeholder="예: 32 (한번에 처리할 샘플 수)"
|
||||||
|
id="batch"
|
||||||
|
value={batchSize}
|
||||||
|
onChange={(e) => setBatchSize(parseInt(e.target.value, 10))}
|
||||||
|
/>
|
||||||
|
<SelectWithLabel
|
||||||
|
label="옵티마이저"
|
||||||
|
id="optimizer"
|
||||||
|
options={[
|
||||||
|
{ label: 'AUTO', value: 'AUTO' },
|
||||||
|
{ label: 'SGD', value: 'SGD' },
|
||||||
|
{ label: 'ADAM', value: 'ADAM' },
|
||||||
|
{ label: 'ADAMW', value: 'ADAMW' },
|
||||||
|
{ label: 'NADAM', value: 'NADAM' },
|
||||||
|
{ label: 'RADAM', value: 'RADAM' },
|
||||||
|
{ label: 'RMSPROP', value: 'RMSPROP' },
|
||||||
|
]}
|
||||||
|
placeholder="옵티마이저 선택"
|
||||||
|
value={optimizer}
|
||||||
|
onChange={(value) => setOptimizer(value as 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP')}
|
||||||
|
/>
|
||||||
|
<InputWithLabel
|
||||||
|
label="학습률(LR0)"
|
||||||
|
placeholder="예: 0.01 (초기 학습률)"
|
||||||
|
id="lr0"
|
||||||
|
value={lr0}
|
||||||
|
onChange={(e) => setLr0(parseFloat(e.target.value))}
|
||||||
|
/>
|
||||||
|
<InputWithLabel
|
||||||
|
label="최종 학습률(LRF)"
|
||||||
|
placeholder="예: 0.001 (최종 학습률)"
|
||||||
|
id="lrf"
|
||||||
|
value={lrf}
|
||||||
|
onChange={(e) => setLrf(parseFloat(e.target.value))}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
variant="outlinePrimary"
|
||||||
|
size="lg"
|
||||||
|
onClick={handleSubmit}
|
||||||
|
disabled={!selectedModel || isTraining}
|
||||||
|
>
|
||||||
|
{isTraining ? '학습 중단' : '학습 시작'}
|
||||||
|
</Button>
|
||||||
|
</fieldset>
|
||||||
|
);
|
||||||
|
}
|
@ -1,45 +1,58 @@
|
|||||||
import { Button } from '@/components/ui/button';
|
import useTrainModelQuery from '@/queries/models/useTrainModelQuery';
|
||||||
import ModelLineChart from '@/components/ModelLineChart';
|
import useModelStore from '@/stores/useModelStore';
|
||||||
import SettingsForm from './SettingsForm';
|
import TrainingSettings from './TrainingSettings';
|
||||||
|
import TrainingGraph from './TrainingGraph';
|
||||||
|
import { ModelTrainRequest } from '@/types';
|
||||||
|
|
||||||
interface TrainingTabProps {
|
interface TrainingTabProps {
|
||||||
training: boolean;
|
projectId: number | null;
|
||||||
handleTrainingToggle: () => void;
|
|
||||||
trainingDataList: {
|
|
||||||
epoch: number;
|
|
||||||
box_loss: number;
|
|
||||||
cls_loss: number;
|
|
||||||
dfl_loss: number;
|
|
||||||
fitness: number;
|
|
||||||
}[];
|
|
||||||
projectId: string | null; // projectId를 프랍으로 받음
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function TrainingTab({ training, handleTrainingToggle, trainingDataList, projectId }: TrainingTabProps) {
|
export default function TrainingTab({ projectId }: TrainingTabProps) {
|
||||||
|
const numericProjectId = projectId ? parseInt(projectId.toString(), 10) : null;
|
||||||
|
const { isTrainingByProject, setIsTraining, selectedModelByProject, setSelectedModel, resetTrainingData } =
|
||||||
|
useModelStore((state) => ({
|
||||||
|
isTrainingByProject: state.isTrainingByProject,
|
||||||
|
setIsTraining: state.setIsTraining,
|
||||||
|
selectedModelByProject: state.selectedModelByProject,
|
||||||
|
setSelectedModel: state.setSelectedModel,
|
||||||
|
resetTrainingData: state.resetTrainingData,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const isTraining = isTrainingByProject[numericProjectId?.toString() || ''] || false;
|
||||||
|
const selectedModel = selectedModelByProject[numericProjectId?.toString() || ''];
|
||||||
|
|
||||||
|
const { mutate: startTraining } = useTrainModelQuery(numericProjectId as number);
|
||||||
|
|
||||||
|
const handleTrainingStart = (trainData: ModelTrainRequest) => {
|
||||||
|
if (!isTraining && selectedModel !== null) {
|
||||||
|
setIsTraining(numericProjectId?.toString() || '', true);
|
||||||
|
startTraining(trainData);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleTrainingStop = () => {
|
||||||
|
if (isTraining) {
|
||||||
|
setIsTraining(numericProjectId?.toString() || '', false);
|
||||||
|
resetTrainingData(numericProjectId?.toString() || '');
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="grid gap-8 md:grid-cols-2">
|
<div className="grid gap-8 md:grid-cols-2">
|
||||||
<div className="flex flex-col gap-6">
|
<TrainingSettings
|
||||||
<SettingsForm projectId={projectId} />
|
projectId={numericProjectId}
|
||||||
<Button
|
selectedModel={selectedModel}
|
||||||
variant={training ? 'destructive' : 'outlinePrimary'}
|
setSelectedModel={(modelId) => setSelectedModel(numericProjectId?.toString() || '', modelId)}
|
||||||
size="lg"
|
handleTrainingStart={handleTrainingStart}
|
||||||
onClick={handleTrainingToggle}
|
handleTrainingStop={handleTrainingStop}
|
||||||
>
|
isTraining={isTraining}
|
||||||
{training ? '학습 중단' : '학습 시작'}
|
/>
|
||||||
</Button>
|
|
||||||
</div>
|
<TrainingGraph
|
||||||
|
projectId={numericProjectId}
|
||||||
<div className="flex flex-col justify-center">
|
selectedModel={selectedModel}
|
||||||
<ModelLineChart
|
|
||||||
data={trainingDataList.map((data) => ({
|
|
||||||
epoch: data.epoch.toString(),
|
|
||||||
loss1: data.box_loss,
|
|
||||||
loss2: data.cls_loss,
|
|
||||||
loss3: data.dfl_loss,
|
|
||||||
fitness: data.fitness,
|
|
||||||
}))}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -1,28 +1,11 @@
|
|||||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
|
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
|
||||||
import { useState } from 'react';
|
|
||||||
import { useParams } from 'react-router-dom';
|
import { useParams } from 'react-router-dom';
|
||||||
|
|
||||||
import useTrainWebSocket from '@/hooks/useTrainPolling';
|
|
||||||
import useTrainStore from '@/stores/useTrainStore';
|
|
||||||
import TrainingTab from './TrainingTab';
|
import TrainingTab from './TrainingTab';
|
||||||
import EvaluationTab from './EvaluationTab';
|
import EvaluationTab from './EvaluationTab';
|
||||||
|
|
||||||
export default function ModelManage() {
|
export default function ModelManage() {
|
||||||
const { projectId } = useParams<{ projectId?: string }>();
|
const { projectId } = useParams<{ projectId?: string }>();
|
||||||
const [training, setTraining] = useState(false);
|
const numericProjectId = projectId ? parseInt(projectId, 10) : null;
|
||||||
const [selectedModel, setSelectedModel] = useState<string | null>(null);
|
|
||||||
|
|
||||||
const numericProjectId = projectId ?? null;
|
|
||||||
|
|
||||||
useTrainWebSocket(training, numericProjectId);
|
|
||||||
|
|
||||||
const { trainingDataList } = useTrainStore((state) => ({
|
|
||||||
trainingDataList: numericProjectId ? state.trainingDataByProject[numericProjectId] || [] : [],
|
|
||||||
}));
|
|
||||||
|
|
||||||
const handleTrainingToggle = () => {
|
|
||||||
setTraining((prev) => !prev);
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="grid h-screen w-full">
|
<div className="grid h-screen w-full">
|
||||||
@ -41,22 +24,12 @@ export default function ModelManage() {
|
|||||||
<TabsTrigger value="results">모델 평가</TabsTrigger>
|
<TabsTrigger value="results">모델 평가</TabsTrigger>
|
||||||
</TabsList>
|
</TabsList>
|
||||||
|
|
||||||
{/* 학습 탭 */}
|
|
||||||
<TabsContent value="train">
|
<TabsContent value="train">
|
||||||
<TrainingTab
|
<TrainingTab projectId={numericProjectId} />
|
||||||
training={training}
|
|
||||||
handleTrainingToggle={handleTrainingToggle}
|
|
||||||
trainingDataList={trainingDataList}
|
|
||||||
projectId={numericProjectId}
|
|
||||||
/>
|
|
||||||
</TabsContent>
|
</TabsContent>
|
||||||
|
|
||||||
{/* 평가 탭 */}
|
|
||||||
<TabsContent value="results">
|
<TabsContent value="results">
|
||||||
<EvaluationTab
|
<EvaluationTab projectId={numericProjectId} />
|
||||||
selectedModel={selectedModel}
|
|
||||||
setSelectedModel={setSelectedModel}
|
|
||||||
/>
|
|
||||||
</TabsContent>
|
</TabsContent>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
</main>
|
</main>
|
||||||
|
@ -1,49 +0,0 @@
|
|||||||
// 임시 가짜 훅
|
|
||||||
import { useEffect, useRef, useCallback } from 'react';
|
|
||||||
import axios from 'axios';
|
|
||||||
import useTrainStore from '@/stores/useTrainStore';
|
|
||||||
|
|
||||||
export default function useTrainPolling(start: boolean, projectId?: string | null) {
|
|
||||||
const { addTrainingData, resetTrainingData } = useTrainStore((state) => ({
|
|
||||||
addTrainingData: state.addTrainingData,
|
|
||||||
resetTrainingData: state.resetTrainingData,
|
|
||||||
}));
|
|
||||||
|
|
||||||
const intervalIdRef = useRef<number | null>(null);
|
|
||||||
// 함수 api 후 교체 예정
|
|
||||||
const fetchTrainingData = useCallback(async () => {
|
|
||||||
if (projectId) {
|
|
||||||
try {
|
|
||||||
const response = await axios.get(`/api/바보=${projectId}`);
|
|
||||||
const data = response.data;
|
|
||||||
|
|
||||||
addTrainingData(projectId, {
|
|
||||||
epoch: data.epoch,
|
|
||||||
total_epochs: data.total_epochs,
|
|
||||||
box_loss: data.box_loss,
|
|
||||||
cls_loss: data.cls_loss,
|
|
||||||
dfl_loss: data.dfl_loss,
|
|
||||||
fitness: data.fitness,
|
|
||||||
epoch_time: data.epoch_time,
|
|
||||||
left_second: data.left_second,
|
|
||||||
});
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Fetching error:', error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, [projectId, addTrainingData]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (start && projectId) {
|
|
||||||
resetTrainingData(projectId);
|
|
||||||
intervalIdRef.current = window.setInterval(fetchTrainingData, 5000);
|
|
||||||
}
|
|
||||||
|
|
||||||
return () => {
|
|
||||||
if (intervalIdRef.current) {
|
|
||||||
clearInterval(intervalIdRef.current);
|
|
||||||
intervalIdRef.current = null;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}, [start, projectId, fetchTrainingData, resetTrainingData]);
|
|
||||||
}
|
|
10
frontend/src/queries/models/useModelReportsQuery.ts
Normal file
10
frontend/src/queries/models/useModelReportsQuery.ts
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import { useSuspenseQuery } from '@tanstack/react-query';
|
||||||
|
import { getModelReports } from '@/api/modelApi';
|
||||||
|
import { ReportResponse } from '@/types';
|
||||||
|
|
||||||
|
export default function useModelReportsQuery(projectId: number, modelId: number) {
|
||||||
|
return useSuspenseQuery<ReportResponse[]>({
|
||||||
|
queryKey: ['modelReports', projectId, modelId],
|
||||||
|
queryFn: () => getModelReports(projectId, modelId),
|
||||||
|
});
|
||||||
|
}
|
10
frontend/src/queries/models/useModelResultsQuery.ts
Normal file
10
frontend/src/queries/models/useModelResultsQuery.ts
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import { useSuspenseQuery } from '@tanstack/react-query';
|
||||||
|
import { getModelResults } from '@/api/modelApi';
|
||||||
|
import { ResultResponse } from '@/types';
|
||||||
|
|
||||||
|
export default function useModelResultsQuery(modelId: number) {
|
||||||
|
return useSuspenseQuery<ResultResponse[]>({
|
||||||
|
queryKey: ['modelResults', modelId],
|
||||||
|
queryFn: () => getModelResults(modelId),
|
||||||
|
});
|
||||||
|
}
|
12
frontend/src/queries/models/usePollingModelReportsQuery.ts
Normal file
12
frontend/src/queries/models/usePollingModelReportsQuery.ts
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import { useQuery } from '@tanstack/react-query';
|
||||||
|
import { getModelReports } from '@/api/modelApi';
|
||||||
|
import { ReportResponse } from '@/types';
|
||||||
|
|
||||||
|
export default function usePollingModelReportsQuery(projectId: number, modelId: number, enabled: boolean) {
|
||||||
|
return useQuery<ReportResponse[]>({
|
||||||
|
queryKey: ['pollingModelReports', projectId, modelId],
|
||||||
|
queryFn: () => getModelReports(projectId, modelId),
|
||||||
|
refetchInterval: 5000,
|
||||||
|
enabled,
|
||||||
|
});
|
||||||
|
}
|
@ -1,8 +1,9 @@
|
|||||||
import { useMutation } from '@tanstack/react-query';
|
import { useMutation } from '@tanstack/react-query';
|
||||||
import { trainModel } from '@/api/modelApi';
|
import { trainModel } from '@/api/modelApi';
|
||||||
|
import { ModelTrainRequest } from '@/types';
|
||||||
|
|
||||||
export default function useTrainModelQuery(projectId: number) {
|
export default function useTrainModelQuery(projectId: number) {
|
||||||
return useMutation({
|
return useMutation({
|
||||||
mutationFn: () => trainModel(projectId),
|
mutationFn: (trainData: ModelTrainRequest) => trainModel(projectId, trainData),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
56
frontend/src/stores/useModelStore.ts
Normal file
56
frontend/src/stores/useModelStore.ts
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import { create } from 'zustand';
|
||||||
|
import { ReportResponse } from '@/types';
|
||||||
|
|
||||||
|
interface ModelStoreState {
|
||||||
|
trainingDataByProject: Record<string, ReportResponse[]>;
|
||||||
|
isTrainingByProject: Record<string, boolean>;
|
||||||
|
selectedModelByProject: Record<string, number | null>;
|
||||||
|
setIsTraining: (projectId: string, status: boolean) => void;
|
||||||
|
saveTrainingData: (projectId: string, data: ReportResponse[]) => void;
|
||||||
|
setSelectedModel: (projectId: string, modelId: number | null) => void;
|
||||||
|
resetTrainingData: (projectId: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
const useModelStore = create<ModelStoreState>((set) => ({
|
||||||
|
trainingDataByProject: {},
|
||||||
|
isTrainingByProject: {},
|
||||||
|
selectedModelByProject: {},
|
||||||
|
setIsTraining: (projectId, status) =>
|
||||||
|
set((state) => ({
|
||||||
|
isTrainingByProject: {
|
||||||
|
...state.isTrainingByProject,
|
||||||
|
[projectId]: status,
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
saveTrainingData: (projectId, data) =>
|
||||||
|
set((state) => ({
|
||||||
|
trainingDataByProject: {
|
||||||
|
...state.trainingDataByProject,
|
||||||
|
[projectId]: data,
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
setSelectedModel: (projectId, modelId) =>
|
||||||
|
set((state) => ({
|
||||||
|
selectedModelByProject: {
|
||||||
|
...state.selectedModelByProject,
|
||||||
|
[projectId]: modelId,
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
resetTrainingData: (projectId) =>
|
||||||
|
set((state) => ({
|
||||||
|
trainingDataByProject: {
|
||||||
|
...state.trainingDataByProject,
|
||||||
|
[projectId]: [],
|
||||||
|
},
|
||||||
|
selectedModelByProject: {
|
||||||
|
...state.selectedModelByProject,
|
||||||
|
[projectId]: null,
|
||||||
|
},
|
||||||
|
isTrainingByProject: {
|
||||||
|
...state.isTrainingByProject,
|
||||||
|
[projectId]: false,
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
}));
|
||||||
|
|
||||||
|
export default useModelStore;
|
@ -1,40 +0,0 @@
|
|||||||
import { create } from 'zustand';
|
|
||||||
|
|
||||||
interface TrainingData {
|
|
||||||
epoch: number;
|
|
||||||
total_epochs: number;
|
|
||||||
box_loss: number;
|
|
||||||
cls_loss: number;
|
|
||||||
dfl_loss: number;
|
|
||||||
fitness: number;
|
|
||||||
epoch_time: number;
|
|
||||||
left_second: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface StoreState {
|
|
||||||
trainingDataByProject: { [projectId: string]: TrainingData[] };
|
|
||||||
addTrainingData: (projectId: string, data: TrainingData) => void;
|
|
||||||
resetTrainingData: (projectId: string) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
const useTrainStore = create<StoreState>((set) => ({
|
|
||||||
trainingDataByProject: {},
|
|
||||||
|
|
||||||
addTrainingData: (projectId: string, data: TrainingData) =>
|
|
||||||
set((state) => ({
|
|
||||||
trainingDataByProject: {
|
|
||||||
...state.trainingDataByProject,
|
|
||||||
[projectId]: [...(state.trainingDataByProject[projectId] || []), data],
|
|
||||||
},
|
|
||||||
})),
|
|
||||||
|
|
||||||
resetTrainingData: (projectId: string) =>
|
|
||||||
set((state) => ({
|
|
||||||
trainingDataByProject: {
|
|
||||||
...state.trainingDataByProject,
|
|
||||||
[projectId]: [],
|
|
||||||
},
|
|
||||||
})),
|
|
||||||
}));
|
|
||||||
|
|
||||||
export default useTrainStore;
|
|
@ -277,6 +277,25 @@ export interface ImageFolderRequest {
|
|||||||
parentId: number;
|
parentId: number;
|
||||||
files: File[];
|
files: File[];
|
||||||
}
|
}
|
||||||
|
export interface LabelCategoryResponse {
|
||||||
|
id: number;
|
||||||
|
name: string;
|
||||||
|
}
|
||||||
|
// 카테고리 요청 DTO
|
||||||
|
export interface LabelCategoryRequest {
|
||||||
|
labelCategoryList: number[];
|
||||||
|
}
|
||||||
|
|
||||||
|
// 카테고리 응답 DTO
|
||||||
|
export interface LabelCategoryResponse {
|
||||||
|
id: number;
|
||||||
|
name: string;
|
||||||
|
}
|
||||||
|
// 모델 카테고리 응답 DTO
|
||||||
|
export interface ModelCategoryResponse {
|
||||||
|
id: number;
|
||||||
|
name: string;
|
||||||
|
}
|
||||||
|
|
||||||
// 모델 요청 DTO (API로 전달할 데이터 타입)
|
// 모델 요청 DTO (API로 전달할 데이터 타입)
|
||||||
export interface ModelRequest {
|
export interface ModelRequest {
|
||||||
@ -289,22 +308,41 @@ export interface ModelResponse {
|
|||||||
name: string;
|
name: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 모델 카테고리 응답 DTO
|
|
||||||
export interface ModelCategoryResponse {
|
|
||||||
id: number;
|
|
||||||
name: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 프로젝트 모델 리스트 응답 DTO
|
// 프로젝트 모델 리스트 응답 DTO
|
||||||
export interface ProjectModelsResponse extends Array<ModelResponse> {}
|
export interface ProjectModelsResponse extends Array<ModelResponse> {}
|
||||||
|
// 모델 훈련 요청 DTO
|
||||||
// 카테고리 요청 DTO
|
export interface ModelTrainRequest {
|
||||||
export interface LabelCategoryRequest {
|
modelId: number;
|
||||||
labelCategoryList: number[];
|
ratio: number;
|
||||||
|
epochs: number;
|
||||||
|
batch: number;
|
||||||
|
lr0: number;
|
||||||
|
lrf: number;
|
||||||
|
optimizer: 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP';
|
||||||
}
|
}
|
||||||
|
export interface ResultResponse {
|
||||||
// 카테고리 응답 DTO
|
|
||||||
export interface LabelCategoryResponse {
|
|
||||||
id: number;
|
id: number;
|
||||||
name: string;
|
precision: number;
|
||||||
|
recall: number;
|
||||||
|
fitness: number;
|
||||||
|
ratio: number;
|
||||||
|
epochs: number;
|
||||||
|
batch: number;
|
||||||
|
lr0: number;
|
||||||
|
lrf: number;
|
||||||
|
optimizer: 'AUTO' | 'SGD' | 'ADAM' | 'ADAMW' | 'NADAM' | 'RADAM' | 'RMSPROP';
|
||||||
|
map50: number;
|
||||||
|
map5095: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ReportResponse {
|
||||||
|
modelId: number;
|
||||||
|
totalEpochs: number;
|
||||||
|
epoch: number;
|
||||||
|
boxLoss: number;
|
||||||
|
clsLoss: number;
|
||||||
|
dflLoss: number;
|
||||||
|
fitness: number;
|
||||||
|
epochTime: number;
|
||||||
|
leftSecond: number;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user