# Copyright (c) 2023 Adrian Niemann Dmitry Puzyrev
#
# This file is part of ParticleDetection.
# ParticleDetection is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ParticleDetection is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ParticleDetection. If not, see <http://www.gnu.org/licenses/>.
"""
Function(s) to train a new model using the Detectron2 framework.
**Author:** Adrian Niemann (adrian.niemann@ovgu.de)\n
**Date:** 10.08.2022
"""
import os
import cv2
import random
import pickle
from typing import Union, List
import matplotlib.pyplot as plt
import numpy as np
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.utils.logger import setup_logger
from detectron2.config import CfgNode
import detectron2.data.transforms as T
from ParticleDetection.modelling.datasets import load_custom_data
import ParticleDetection.utils.datasets as ds
import ParticleDetection.modelling.detectron as custom
from ParticleDetection.modelling.configs import write_configs
[docs]
def run_training(train_set: ds.DataSet,
                 configuration: Union[str, CfgNode],
                 val_set: Union[ds.DataSet, None] = None,
                 output_dir: str = "./", log_name: str = "training.log",
                 resume: bool = True, visualize: bool = False,
                 img_augmentations: Union[List[T.Augmentation], None] = None,
                 freeze_layers: Union[List[str], None] = None):
    """Runs the training of a model with the given training data.

    Runs the training of a model which is defined by the given configuration.
    The training can be resumed and further specification of layers to
    train/not to train is possible. During training the different model
    performance metrics are logged in the Tensorboard format.
    Additional COCO metrics are available only if a validation dataset is
    given. These metrics are also logged in the Tensorboard format.

    When ``resume`` is set and ``output_dir`` already contains a
    ``config.yaml``, that file replaces the given ``configuration`` and is
    used unmodified. Likewise, an existing ``augmentations.pkl`` in
    ``output_dir`` replaces the given ``img_augmentations``.

    Parameters
    ----------
    train_set : DataSet
        A DataSet already registered to the Detectron2 framework, that is used
        for training the model.
    configuration : Union[str, CfgNode]
        Configuration for the Detectron2 model with training settings given as
        a ``CfgNode`` or path to a ``*.yaml`` file in the Detectron2
        configuration format.
    val_set : DataSet, optional
        A :class:`.DataSet` already registered to the Detectron2 framework,
        that is used for testing the model during training.\n
        By default ``None``.
    output_dir : str, optional
        Path to the intended output directory. Its parent directory must exist
        prior to running this function.\n
        By default ``"./"``.
    log_name : str, optional
        Filename for logging output in the output directory.\n
        By default ``"training.log"``.
    resume : bool, optional
        Flag to continue with previous training progress in the output
        folder.\n
        By default ``True``.
    visualize : bool, optional
        Flag for allowing visualization of one randomly selected image from the
        given training dataset with up to 10 randomly chosen annotations
        overlaid on the image.\n
        By default ``False``.
    img_augmentations : List[Augmentation], optional
        Image augmentations to be used during training.\n
        By default ``None``.
    freeze_layers : List[str], optional
        Layers/layer collections to be frozen during training. The model's
        layer names are obtained using ``model.named_parameters()``. A
        parameter is frozen if any of the given strings is contained in its
        name.\n
        By default ``None``.
    """
    setup_logger(os.path.join(output_dir, log_name))

    if visualize:
        # Visualize annotations of a randomly selected sample in the
        # training set
        meta_data = MetadataCatalog.get(train_set.name)
        dataset_dicts = load_custom_data(train_set)
        # Clamp the sample sizes, random.sample() raises a ValueError when
        # asked for more elements than are available.
        for d in random.sample(dataset_dicts, min(1, len(dataset_dicts))):
            img = cv2.imread(d["file_name"])
            # OpenCV loads BGR, the Visualizer expects RGB
            visualizer = Visualizer(img[:, :, ::-1], metadata=meta_data,
                                    scale=0.5)
            annotations = d["annotations"]
            d["annotations"] = random.sample(annotations,
                                             min(10, len(annotations)))
            out = visualizer.draw_dataset_dict(d)
            plt.figure()
            # The drawn image is already RGB, which is what matplotlib
            # expects, so no further channel flip is applied.
            plt.imshow(out.get_image())
            plt.show()

    resume_with_config = False
    # Try to load previously defined *.yaml configuration
    if resume:
        previous_config = os.path.join(output_dir, "config.yaml")
        previous_augment = os.path.join(output_dir, "augmentations.pkl")
        if os.path.exists(previous_config):
            configuration = CfgNode(CfgNode.load_yaml_with_base(
                previous_config))
            resume_with_config = True
        if os.path.exists(previous_augment):
            # NOTE: unpickling executes arbitrary code, only resume from
            # output directories whose content is trusted.
            with open(previous_augment, "rb") as f:
                img_augmentations = pickle.load(f)

    if not resume_with_config:
        # Load configuration, if needed
        if isinstance(configuration, str):
            configuration = CfgNode(CfgNode.load_yaml_with_base(configuration))

        # Adaptation/Double-checking the given configuration
        configuration.OUTPUT_DIR = os.path.abspath(output_dir)
        configuration.DATASETS.TRAIN = (train_set.name,)
        if val_set:
            test_set_missing = not configuration.DATASETS.TEST
            if test_set_missing:
                configuration.DATASETS.TEST = (val_set.name,)
            if (test_set_missing
                    or not configuration.TEST.DETECTIONS_PER_IMAGE):
                # Determine the maximum number of instances to predict per
                # image, with 50% headroom over the largest count found in
                # the validation set
                counts = ds.get_object_counts(val_set)
                configuration.TEST.DETECTIONS_PER_IMAGE = int(
                    1.5 * np.max(counts))

        # Create output directory and save configuration
        os.makedirs(output_dir, exist_ok=True)
        write_configs(configuration, output_dir, img_augmentations)

    # Training
    if img_augmentations:
        # The augmentations are handed over through a class attribute of the
        # trainer, so this must happen before the trainer is instantiated.
        custom.CustomTrainer.augmentations = img_augmentations
    trainer = custom.CustomTrainer(configuration)

    # Freeze layers (prevent weight/bias updates, might not work for
    # decay/momentum)
    if freeze_layers:
        for layer, params in trainer.model.named_parameters():
            if any(to_freeze in layer for to_freeze in freeze_layers):
                params.requires_grad = False

    trainer.resume_or_load(resume=resume)
    trainer.train()