# Copyright (c) 2023 Adrian Niemann Dmitry Puzyrev
#
# This file is part of ParticleDetection.
# ParticleDetection is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ParticleDetection is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ParticleDetection. If not, see <http://www.gnu.org/licenses/>.
"""
Collection of custom Detectron2 objects to provide a customized training
process with more sophisticated outputs.
**Author:** Adrian Niemann (adrian.niemann@ovgu.de)\n
**Date:** 31.10.2022
"""
import time
import datetime
import logging
import copy
from typing import List
import numpy as np
import torch
from detectron2.engine.hooks import HookBase
from detectron2.config.config import CfgNode
from detectron2.engine.defaults import DefaultTrainer
from detectron2.evaluation import COCOEvaluator, DatasetEvaluators
import detectron2.data.detection_utils as utils
import detectron2.data.transforms as T
from detectron2.utils.events import EventWriter, get_event_storage, \
CommonMetricPrinter
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import DatasetMapper, build_detection_test_loader, \
build_detection_train_loader
import detectron2.utils.comm as comm
[docs]
class CustomTrainer(DefaultTrainer):
augmentations: List[T.Augmentation] = []
[docs]
@classmethod
def build_evaluator(cls, cfg: CfgNode, dataset_name: str):
"""Build a custom evaluator depending on the detection task.
Parameters
----------
cfg : CfgNode
Detectron2 network configuration with allowed TASK field of:\n
``"None"`` -> will be changed to ``"segm"``\n
``"segm"`` -> results in an evaluator for a segmentation task\n
``"keypoints"`` -> results in an evaluator for a keypoint detection
task
dataset_name : str
Name of a dataset that is registered in the Detectron2 framework,
that is used as the ``test`` dataset of the constructed evaluator.
Returns
-------
DatasetEvaluators
"""
# TODO: What exactly is max_dets_per_image controlling/ how does it
# influence the metrics?
try:
tasks = [cfg.TASKS]
if "None" in tasks:
return DatasetEvaluators([])
except AttributeError:
tasks = ["segm"]
if "keypoints" in tasks:
sigmas = (np.array([0.25, 0.25]) / 10.0)
sigmas = sigmas.tolist()
else:
sigmas = []
dataset_evaluators = [
COCOEvaluator(dataset_name, output_dir=cfg.OUTPUT_DIR,
max_dets_per_image=cfg.TEST.DETECTIONS_PER_IMAGE,
tasks=tuple(tasks), kpt_oks_sigmas=sigmas)
]
return DatasetEvaluators(dataset_evaluators)
[docs]
def build_writers(self):
"""Builds additional/custom writers for use during training."""
return [CustomTensorboardWriter(self.cfg.OUTPUT_DIR, window_size=1),
CommonMetricPrinter(self.cfg.SOLVER.MAX_ITER)]
[docs]
def build_hooks(self):
"""
Build a list of hooks, including the ``DefaultTrainer`` default hooks
and a custom loss hook used during evaluation.
"""
hooks = super().build_hooks()
hooks.insert(-1, EvalLossHook(
self.cfg.TEST.EVAL_PERIOD, # 1,
self.model,
build_detection_test_loader(
self.cfg,
self.cfg.DATASETS.TEST[0],
DatasetMapper(self.cfg, True) # TODO: might need replacement
)
))
return hooks
[docs]
@classmethod
def build_train_loader(cls, cfg):
"""Custom loader for training data.
Parameters
----------
cfg : CfgNode
Detectron2 network configuration.
Returns
-------
Iterable
"""
# Code from DatasetMapper.from_config(cls, cfg, is_train: bool = True)
# to properly load settings from the configuration
is_train = True
import detectron2.data.detection_utils as du
augs = du.build_augmentation(cfg, is_train)
if cfg.INPUT.CROP.ENABLED:
augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE,
cfg.INPUT.CROP.SIZE))
recompute_boxes = cfg.MODEL.MASK_ON
else:
recompute_boxes = False
mapper_conf = {
"is_train": is_train,
"augmentations": augs,
"image_format": cfg.INPUT.FORMAT,
"use_instance_mask": cfg.MODEL.MASK_ON,
"instance_mask_format": cfg.INPUT.MASK_FORMAT,
"use_keypoint": cfg.MODEL.KEYPOINT_ON,
"recompute_boxes": recompute_boxes,
}
if cfg.MODEL.KEYPOINT_ON:
mapper_conf["keypoint_hflip_indices"] = \
utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
if cfg.MODEL.LOAD_PROPOSALS:
mapper_conf["precomputed_proposal_topk"] = (
cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
if is_train
else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
)
# Additional custom augmentations
if cls.augmentations:
mapper_conf["augmentations"].extend(cls.augmentations)
return build_detection_train_loader(
cfg, mapper=DatasetMapper(**mapper_conf))
# Currently not used
[docs]
class CompleteMapper(DatasetMapper): # pragma: no cover
"""Provides annotation data in training and testing context."""
[docs]
def __call__(self, dataset_dict):
"""
Args:
dataset_dict (dict): Metadata of one image, in Detectron2 dataset
format.
Returns:
dict: a format that builtin models in detectron2 accept
"""
dataset_dict = copy.deepcopy(
dataset_dict) # it will be modified by code below
# USER: Write your own image loading if it's not from a file
image = utils.read_image(dataset_dict["file_name"],
format=self.image_format)
utils.check_image_size(dataset_dict, image)
# USER: Remove if you don't do semantic/panoptic segmentation.
if "sem_seg_file_name" in dataset_dict:
sem_seg_gt = utils.read_image(
dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
else:
sem_seg_gt = None
aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
transforms = self.augmentations(aug_input)
image, sem_seg_gt = aug_input.image, aug_input.sem_seg
image_shape = image.shape[:2] # h, w
# Pytorch's dataloader is efficient on torch.Tensor due to
# shared-memory, but not efficient on large generic data structures due
# to the use of pickle & mp.Queue. Therefore it's important to use
# torch.Tensor.
dataset_dict["image"] = torch.as_tensor(
np.ascontiguousarray(image.transpose(2, 0, 1)))
if sem_seg_gt is not None:
dataset_dict["sem_seg"] = torch.as_tensor(
sem_seg_gt.astype("long"))
# USER: Remove if you don't use pre-computed proposals.
# Most users would not need this feature.
if self.proposal_topk is not None:
utils.transform_proposals(
dataset_dict, image_shape, transforms,
proposal_topk=self.proposal_topk
)
if "annotations" in dataset_dict:
self._transform_annotations(dataset_dict, transforms, image_shape)
return dataset_dict
[docs]
class EvalLossHook(HookBase):
"""Hook to compute different losses in the training process of a network.
This is the *copy* of the loss hook used by Detectron2 for the training
dataset. This hook is intended for evaluating the loss on the test dataset
during the training process.
"""
def __init__(self, eval_period, model, data_loader):
self._model = model
self._period = eval_period
self._data_loader = data_loader
def _do_loss_eval(self):
# Copying inference_on_dataset from evaluator.py
total = len(self._data_loader)
num_warmup = min(5, total - 1)
start_time = time.perf_counter()
total_compute_time = 0
losses = {}
for idx, inputs in enumerate(self._data_loader):
if idx == num_warmup:
start_time = time.perf_counter()
total_compute_time = 0
start_compute_time = time.perf_counter()
if torch.cuda.is_available():
torch.cuda.synchronize()
total_compute_time += time.perf_counter() - start_compute_time
iters_after_start = idx + 1 - num_warmup * int(
idx >= num_warmup)
seconds_per_img = total_compute_time / iters_after_start
if idx >= num_warmup * 2 or seconds_per_img > 5:
total_seconds_per_img = (
time.perf_counter() - start_time) / iters_after_start
eta = datetime.timedelta(
seconds=int(total_seconds_per_img * (total - idx - 1)))
log_every_n_seconds(
logging.INFO,
"Loss on Validation done {}/{}. {:.4f} s / img. ETA={}".format( # noqa: E501
idx + 1, total, seconds_per_img, str(eta)
),
n=5,
)
losses = self._get_loss(inputs, losses)
total_loss = 0
for loss, values in losses.items():
next_loss = np.mean(values)
self.trainer.storage.put_scalar("test/" + loss, next_loss)
total_loss += next_loss
self.trainer.storage.put_scalar('test/total_loss', total_loss)
comm.synchronize()
return losses
def _get_loss(self, data, losses: dict):
# How loss is calculated on train_loop
metrics_dict = self._model(data)
for k, v in metrics_dict.items():
float_v = 0.
if isinstance(v, torch.Tensor):
float_v = v.detach().cpu().item()
else:
float_v = float(v)
if k in losses.keys():
losses[k].append(float_v)
else:
losses[k] = [float_v]
return losses
[docs]
def after_step(self):
next_iter = self.trainer.iter + 1
is_final = next_iter == self.trainer.max_iter
if is_final or (self._period > 0 and next_iter % self._period == 0):
self._do_loss_eval()
self.trainer.storage.put_scalars(timetest=12)
[docs]
class CustomTensorboardWriter(EventWriter):
"""
Write all scalars to a tensorboard file.
"""
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
"""
Args:
log_dir (str):
The directory to save the output events
window_size (int):
The scalars will be median-smoothed by this window size
kwargs:
Other arguments passed to
``torch.utils.tensorboard.SummaryWriter()``
"""
self._window_size = window_size
from torch.utils.tensorboard import SummaryWriter
self._writer = SummaryWriter(log_dir + "/test", **kwargs)
self._writers = {
"test": SummaryWriter(log_dir + "/test"),
"train": SummaryWriter(log_dir + "/train")}
self._last_write = -1
self.train_kw = ["fast_rcnn"]
[docs]
def write(self):
storage = get_event_storage()
# custom stuff
new_last_write = self._last_write
for k, (v, iter) in storage.latest_with_smoothing_hint(
self._window_size).items():
if iter > self._last_write:
for id, writer in self._writers.items():
if id == "test":
if "test" in k:
if k == "timetest":
writer.add_scalar("time/time", v, iter)
else:
new_k = k.split("/")[-1]
if "loss" in new_k:
writer.add_scalar("loss/" + new_k, v, iter)
else:
writer.add_scalar(new_k, v, iter)
elif "segm/" in k:
writer.add_scalar(k, v, iter)
elif "keypoints/" in k:
writer.add_scalar(k, v, iter)
elif id == "train":
if "test" not in k:
if "segm/" in k:
continue
if "keypoints/" in k:
continue
writer.add_scalar("all/" + k, v, iter)
if "loss" in k:
writer.add_scalar("loss/" + k, v, iter)
elif "time" in k:
writer.add_scalar("time/" + k, v, iter)
new_last_write = max(new_last_write, iter)
self._last_write = new_last_write
# storage.put_{image,histogram} is only meant to be used by
# tensorboard writer. So we access its internal fields directly from
# here.
if len(storage._vis_data) >= 1:
for img_name, img, step_num in storage._vis_data:
self._writer.add_image(img_name, img, step_num)
# Storage stores all image data and rely on this writer to clear
# them. As a result it assumes only one writer will use its image
# data. An alternative design is to let storage store limited
# recent data (e.g. only the most recent image) that all writers
# can access. In that case a writer may not see all image data if
# its period is long.
storage.clear_images()
if len(storage._histograms) >= 1:
for params in storage._histograms:
self._writer.add_histogram_raw(**params)
storage.clear_histograms()
[docs]
def close(self):
# doesn't exist when the code fails at import
if hasattr(self, "_writer"):
self._writer.close()