# Copyright (c) 2023 Adrian Niemann Dmitry Puzyrev
#
# This file is part of RodTracker.
# RodTracker is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RodTracker is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with RodTracker. If not, see <http://www.gnu.org/licenses/>.
"""**TBD**"""
import os
import logging
from typing import List
from functools import partial
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import matplotlib.backends.backend_qtagg as b_qt
from PyQt5 import QtWidgets, QtCore
import ParticleDetection.utils.data_loading as dl
from RodTracker.backend.reconstruction import Plotter, Tracker, Reconstructor
from RodTracker.backend import reconstruction
import RodTracker.ui.mainwindow_layout as mw_l
import RodTracker.backend.logger as lg
from RodTracker.ui.dialogs import show_warning
_logger = logging.getLogger(__name__)
[docs]def init_reconstruction(ui: mw_l.Ui_MainWindow):
"""Initialize the functionality of reconstructing 3D particle positions.
Parameters
----------
ui : Ui_MainWindow
UI object of the main window of the application, i.e. also containing
the UI tab/objects for 3D reconstruction tasks.
Returns
-------
None | ReconstructorUI
Returns ``None``, if the system requirements for 3D particle position
reconstruction are not met. Otherwise the ``ReconstructorUI`` object
handling particle reconstructions is returned.
"""
return ReconstructorUI(ui.tab_reconstruct)
[docs]class ReconstructorUI(QtWidgets.QWidget):
"""A custom ``QWidget`` to provide access to the reconstruction of 3D
particle coordinates.
This widget interfaces with the :mod:`ParticleDetection.reconstruct_3D`
library and provides these functionalities to the GUI, i.e. reconstruction
of 3D particle coordinates and tracking of particles over multiple frames.
Parameters
----------
ui : QWidget
Widget containing the tab that is the GUI for the reconstruction and
tracking functionality.
*args : iterable
Positional arguments for the ``QWidget`` superclass.
**kwargs: dict
Keyword arguments for the ``QWidget`` superclass.
Attributes
----------
used_colors : List[str]
Selected colors for reconstruction/tracking/plotting.\n
By default ``[]``.
start_frame : int
Lower bound of the frame range selected for
reconstruction/tracking/plotting. The bound is inclusive.\n
By default ``0``.
end_frame : int
Upper bound of the frame range selected for
reconstruction/tracking/plotting. The bound is inclusive.\n
By default ``0``.
.. admonition:: Signals
- :attr:`request_data`
- :attr:`updated_data`
.. admonition:: Slots
- :meth:`add_plot`
- :meth:`data_loaded`
- :meth:`data_update`
- :meth:`set_calibration`
- :meth:`set_cam_ids`
- :meth:`set_transformation`
- :meth:`switch_plot_page`
- :meth:`update_frames`
- :meth:`update_settings`
"""
position_scaling: float = 1.0
"""float : Scale factor to scale the loaded data for display (is usually
kept as ``1.0``).
Default is ``1.0``.
"""
request_data = QtCore.pyqtSignal([list, list])
"""pyqtSignal(list, list) : Request a portion of the *main* dataset defined
by
| [0]: a list of frames, and
| [1]: a list of colors.
"""
updated_data = QtCore.pyqtSignal(pd.DataFrame)
"""pyqtSignal(DataFrame) : Sends an updated slice of the *main* dataset,
that has been (re-)tracked or its 3D coordinates updated.
This signal is emitted once for every color during the
reconstruction/tracking process. The ``DataFrame`` in the payload is
effectively an updated slice of the *main* dataset and does not contain
new rows.
"""
is_busy = QtCore.pyqtSignal(bool)
"""pyqtSignal(bool) : Notifies when a background task is started/finished.
"""
data: pd.DataFrame = None
"""DataFrame : Slice of the *main* ``DataFrame`` that is used for
reconstruction/tracking.
Default is ``None``.
"""
cam_ids: List[str] = ["", ""]
"""List[str] : IDs of the two cameras intended for reconstruction of 3D
coordinates.
The IDs are used to identify the 2D data columns during the reconstruction
process. If at least one of them is an empty string, the process of
reconstruction or tracking will be immediatly aborted, because the there
either is not enough data or the data is not identifiable.
Default is ``["", ""]``.
"""
_calibration = QtCore.pyqtSignal([dict])
_progress_val: float = 0.
_colors_to_solve: int = 0
def __init__(self, ui: QtWidgets.QWidget, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ui = ui
self._threads = QtCore.QThreadPool()
self._calibration = None
self._transformation = None
self._custom_solver = None
self.reprojection_errs = None
self.used_colors = []
self.start_frame = 0
self.end_frame = 0
self.first_update = False
# Signal connections
tb_calibration = ui.findChild(QtWidgets.QToolButton, "tb_calibration")
le_calibration = ui.findChild(QtWidgets.QLineEdit, "le_calibration")
tb_calibration.clicked.connect(
lambda: choose_calibration(le_calibration, self.set_calibration)
)
tb_transformation = ui.findChild(QtWidgets.QToolButton,
"tb_transformation")
le_transformation = ui.findChild(QtWidgets.QLineEdit,
"le_transformation")
tb_transformation.clicked.connect(
lambda: choose_calibration(le_transformation,
self.set_transformation)
)
start_f = ui.findChild(QtWidgets.QSpinBox, "start_frame")
start_f.valueChanged.connect(self._change_start_frame)
end_f = ui.findChild(QtWidgets.QSpinBox, "end_frame")
end_f.valueChanged.connect(self._change_end_frame)
for cb in ui.findChildren(QtWidgets.QCheckBox):
if "tracking" in cb.objectName():
continue
cb.stateChanged.connect(self._toggle_color)
self.stacked_plots = ui.findChild(QtWidgets.QStackedWidget,
"stacked_plots")
self.lbl_current_plot = ui.findChild(QtWidgets.QLabel,
"lbl_current_plot")
tb_left = self.ui.findChild(QtWidgets.QToolButton, "tb_plots_left")
tb_left.clicked.connect(lambda: self.switch_plot_page(-1))
tb_right = self.ui.findChild(QtWidgets.QToolButton, "tb_plots_right")
tb_right.clicked.connect(lambda: self.switch_plot_page(1))
self.pb_plots = self.ui.findChild(QtWidgets.QPushButton,
"pb_update_plots")
self.pb_plots.clicked.connect(self.update_plots)
self.pb_solve = ui.findChild(QtWidgets.QPushButton, "pb_solve")
self.pb_solve.clicked.connect(self.solve)
# Disable not implemented features
ui.findChild(QtWidgets.QToolButton, "tb_solver").setEnabled(False)
ui.findChild(QtWidgets.QLineEdit, "le_solver").setEnabled(False)
ui.findChild(QtWidgets.QLabel, "lbl_solver").setEnabled(False)
self.pb_solve.setEnabled(False)
self.progress = ui.findChild(QtWidgets.QProgressBar,
"progress_reconstruction")
self.progress.setValue(100)
[docs] @QtCore.pyqtSlot(str)
def set_calibration(self, path: str):
"""Attempts to load a new set of stereo camera calibration data.
Attempt to load calibration data from the file given in ``path`` and
activates the **Solve** button if both, calibration and transformation
data, have been loaded. Additionally, the updating of plots is
(re-)enabled upon successful loading of the calibration data.
Parameters
----------
path : str
Path to the stereo camera calibration data that shall be loaded
here.
Returns
-------
None
"""
self._calibration = dl.load_camera_calibration(path)
if self._calibration and self._transformation:
self.ui.findChild(
QtWidgets.QPushButton, "pb_solve").setEnabled(True)
if self.data is not None:
self.pb_plots.setEnabled(True)
def _change_start_frame(self, new_val: int):
"""Callback for the ``QSpinBox`` handling the start frame selection."""
self.start_frame = new_val
self.select_data()
self.pb_plots.setEnabled(True)
def _change_end_frame(self, new_val: int):
"""Callback for the ``QSpinBox`` handling the end frame selection."""
self.end_frame = new_val
self.select_data()
self.pb_plots.setEnabled(True)
@property
def solver(self):
"""**Not Implemented.**"""
raise NotImplementedError
@solver.setter
def solver(self, path: str):
raise NotImplementedError
[docs] @QtCore.pyqtSlot(pd.DataFrame)
def data_update(self, data: pd.DataFrame):
"""Accepts a new dataset that shall be used for
reconstruction/tracking.
Accepts a new dataset and, depending on whether this is the first time
data is given here, updates the plots or (re-)enables the button for
updating the plots in the UI.
Parameters
----------
data : pd.DataFrame
New data that shall be used for plotting, 3D coordinate
reconstruction and tracking of particles.
"""
self.data = data
if any([cam == "" for cam in self.cam_ids]):
# insufficient data for plotting given
_logger.info("Insufficient data for plotting given. Skipping...")
return
if self.first_update:
# The automatic updating of plots has been disabled due to a memory
# issue with not garbage collected plot residues. See the
# "Known issues" section in the docs to learn more about the
# problem.
# self.update_plots()
# self.pb_plots.setEnabled(False)
self.pb_plots.setEnabled(True)
self.first_update = False
else:
self.pb_plots.setEnabled(True)
[docs] def select_data(self):
"""Request data defined by the selections in the UI.
Requests a portion of the *main* data that is defined by the selections
of the user in the reconstruction tab, i.e. start/end frame and colors
of particles to include.
Returns
-------
None
.. hint::
**Emits**:
- :attr:`request_data`
"""
self.request_data.emit(list(range(self.start_frame,
self.end_frame + 1)), self.used_colors)
[docs] def solve(self):
"""(Re-)Starts the reconstruction/tracking of particles.
Starts either the reconstruction or tracking of particles, depending on
the state of a ``QCheckBox``. One process/thread for every selected
color (:attr:`used_colors`) is started, that will (re-)calculate the
3D values (and particle IDs) for all frames in between
:attr:`start_frame` and :attr:`end_frame`.
Returns
-------
None
"""
track = self.ui.findChild(
QtWidgets.QCheckBox, "cb_tracking").isChecked()
if (self.data is None or len(self.data) == 0 or
any([cam == "" for cam in self.cam_ids])):
# insufficient data for 3D reconstruction given
_logger.info("Insufficient data for 3D reconstruction given.")
return
frames = list(range(self.start_frame, self.end_frame + 1))
self._progress_val = 0.
self.progress.setValue(0)
num_colors = len(self.used_colors)
self._colors_to_solve = num_colors
self.is_busy.emit(True)
for i in range(num_colors):
color = self.used_colors[i]
tmp = self.data.loc[self.data.color == color]
if track:
tracker = Tracker(tmp, frames, self._calibration,
self._transformation, self.cam_ids, color)
else:
tracker = Reconstructor(tmp, frames, self._calibration,
self._transformation, self.cam_ids,
color)
tracker.signals.progress.connect(
lambda val: self._progress_update(val / num_colors))
tracker.signals.error.connect(
lambda ret: lg.exception_logger(*ret))
tracker.signals.error.connect(
lambda ret: self._solver_result(None))
tracker.signals.error.connect(
lambda ret: partial(self._notify_error, color))
tracker.signals.error.connect(partial(self._notify_error, color))
tracker.signals.result.connect(self._solver_result)
self._threads.start(tracker)
self.pb_solve.setText("Abort")
self.pb_solve.clicked.disconnect()
self.pb_solve.clicked.connect(self._abort_reconstruction)
def _abort_reconstruction(self):
self.pb_solve.setEnabled(False)
reconstruction.lock.lockForWrite()
reconstruction.abort_reconstruction = True
reconstruction.lock.unlock()
def _notify_error(self, color: str):
show_warning(f"Something went wrong during 3D reconstruction of "
f"'{color}' particles.\n"
f"Please consult the logs for more information.")
[docs] def _solver_result(self, result: pd.DataFrame):
"""Hook to handle the result of each reconstruction process.
Updates the count of active reconstruction/tracking processes/threads
and resets the UI for further reconstruction/tracking tasks after all
have finished.
Propagates the results of the finished task.
Parameters
----------
result : pd.DataFrame
The ``DataFrame`` containing the result of the process, usually
updated data of one color only but for all frames used during
the finished process.
.. hint::
**Emits:**
- :attr:`updated_data`
"""
self._colors_to_solve -= 1
if self._colors_to_solve == 0:
if self._threads.activeThreadCount() == 0:
self.is_busy.emit(False)
self.pb_solve.setText("Solve")
self.pb_solve.clicked.disconnect()
self.pb_solve.clicked.connect(self.solve)
self.pb_solve.setEnabled(True)
self.progress.setValue(100)
reconstruction.lock.lockForWrite()
reconstruction.abort_reconstruction = False
reconstruction.lock.unlock()
if result is None:
return
self.data.update(result)
self.updated_data.emit(result)
self.select_data()
def _progress_update(self, update: float):
"""Update the progressbar during the reconstruction/tracking process.
Parameters
----------
update : float
Value to add upon the combined progress :math:`\\in [0, 1]`.
Returns
-------
None
"""
self._progress_val += 100 * update
self.progress.setValue(int(self._progress_val))
[docs] @QtCore.pyqtSlot(int)
def switch_plot_page(self, direction: int):
"""Switch the displayed plot page relative to the currently displayed
one.
Parameters
----------
direction : int
Direction of the plot to display next. Its the index relative to
the currently displayed plot.\n
a) ``direction = 3``-> displays the plot three positions
further\n
b) ``direction = -1``-> displays the previous plot\n
c) ``direction = 0``-> stays on the current plot
"""
idx_max = self.stacked_plots.count() - 1
idx_new = self.stacked_plots.currentIndex() + direction
if idx_new > idx_max:
idx_new = 0
elif idx_new < 0:
idx_new = idx_max
self.stacked_plots.setCurrentIndex(idx_new)
self.lbl_current_plot.setText(f"({self.stacked_plots.currentIndex()+1}"
f"/{self.stacked_plots.count()})")
[docs] @QtCore.pyqtSlot(str, str)
def set_cam_ids(self, cam1: str, cam2: str):
"""Setter function for :attr:`cam_ids`.
Parameters
----------
cam1 : str
ID for the first camera of the stereo camera setup.
cam2 : str
ID for the second camera of the stereo camera setup.
Returns
-------
None
"""
self.cam_ids = [cam1, cam2]
[docs] @QtCore.pyqtSlot(int, int, list)
def data_loaded(self, f_min: int, f_max: int, colors: List[str]):
"""Hook to updated the available frame range and colors for
reconstruction/tracking.
This function is intended as a slot for the
:attr:`~.RodData.data_loaded` signal. The range of available frames, as
well as the available colors in the loaded dataset is updated and
presented to users in the UI.
Parameters
----------
f_min : int
Lowest frame currently available in the particle position dataset.
f_max : int
Highest frame currently available in the particle position dataset.
colors : List[str]
Colors currently available in the particle position dataset.
Returns
-------
None
"""
start = self.ui.findChild(QtWidgets.QSpinBox, "start_frame")
start.setRange(f_min, f_max)
start.setValue(f_min)
end = self.ui.findChild(QtWidgets.QSpinBox, "end_frame")
end.setRange(f_min, f_max)
end.setValue(f_max)
self._update_colors(colors)
self.first_update = True
self.select_data()
def _update_colors(self, colors: List[str]):
"""Update the checkable colors displayed in the UI."""
color_group = self.ui.findChildren(QtWidgets.QGroupBox)[0]
color_cbs = color_group.findChildren(QtWidgets.QCheckBox)
old_colors = [cb.text().lower() for cb in color_cbs]
group_layout: QtWidgets.QGridLayout = color_group.layout()
for cb in color_cbs:
group_layout.removeWidget(cb)
if cb.text().lower() not in colors:
cb.hide()
cb.deleteLater()
row = 0
col = 0
for color in colors:
try:
cb = color_cbs[old_colors.index(color)]
except ValueError:
cb = QtWidgets.QCheckBox(text=color.lower())
cb.setObjectName(f"cb_{color}")
cb.stateChanged.connect(self._toggle_color)
cb.setChecked(True)
group_layout.addWidget(cb, row, col)
if col == 1:
col = 0
row += 1
else:
col = 1
def _toggle_color(self, _: int):
"""Update whether to use an available color.
Parameters
----------
_ : int
"""
self.used_colors = []
for cb in self.ui.findChildren(QtWidgets.QCheckBox):
if "tracking" in cb.objectName():
continue
if cb.checkState():
self.used_colors.append(cb.objectName().split("_")[1])
self.pb_plots.setEnabled(True)
[docs] @QtCore.pyqtSlot(int, int)
def update_frames(self, start: int, end: int):
"""Update the selected frame range for reconstruction/tracking.
Parameters
----------
start : int
Lowest selected frame.
end : int
Highest selected frame.
"""
spb_start = self.ui.findChild(QtWidgets.QSpinBox, "start_frame")
spb_end = self.ui.findChild(QtWidgets.QSpinBox, "end_frame")
spb_start.setValue(start)
spb_end.setValue(end)
[docs] def update_plots(self):
"""(Re-)Generate plots for evaluation of the 3D data in the dataset.
Starts a thread that generates the evaluation plots for the data
selected by the state of UI, i.e. selected colors
(:attr:`self.used_colors`) and frame range (:attr:`start_frame`,
:attr:`end_frame`).
Returns
-------
None
"""
while self.stacked_plots.count():
self.stacked_plots.removeWidget(self.stacked_plots.currentWidget())
plt.close('all')
if self.data is None or len(self.data) == 0:
return
data_plt = self.data.loc[(self.data["frame"] >= self.start_frame) &
(self.data["frame"] <= self.end_frame) &
(self.data["color"].isin(self.used_colors))]
self.is_busy.emit(True)
plotter = Plotter(
data_plt.copy(), colors=self.used_colors,
start_frame=self.start_frame, end_frame=self.end_frame,
position_scaling=self.position_scaling, cam_ids=self.cam_ids,
calibration=self._calibration, transformation=self._transformation
)
plotter.signals.result_plot.connect(self.add_plot)
plotter.signals.error.connect(lambda ret: lg.exception_logger(*ret))
self._threads.start(plotter)
self.pb_plots.setEnabled(False)
[docs] @QtCore.pyqtSlot(Figure)
def add_plot(self, fig: Figure):
"""Add a figure to the display section.
Attempts to add the given ``Figure`` as a new *page* for display in the
UI.
Parameters
----------
fig : Figure
``Figure`` to be added to the stacked plots for display in the UI.
Returns
-------
None
"""
canvas = b_qt.FigureCanvasQTAgg(fig)
nav_bar = b_qt.NavigationToolbar2QT(canvas, None)
widget = QtWidgets.QWidget()
widget.setLayout(QtWidgets.QVBoxLayout())
widget.layout().setContentsMargins(0, 0, 0, 0)
widget.layout().addWidget(canvas)
widget.layout().addWidget(nav_bar)
self.stacked_plots.insertWidget(self.stacked_plots.count(), widget)
fig.tight_layout()
self.lbl_current_plot.setText(f"({self.stacked_plots.currentIndex()+1}"
f"/{self.stacked_plots.count()})")
if self._threads.activeThreadCount() == 0:
self.is_busy.emit(False)
return
[docs] @QtCore.pyqtSlot(dict)
def update_settings(self, settings: dict) -> None:
"""Catches updates of the settings from a :class:`.Settings` class.
Checks for the keys relevant to itself and updates the corresponding
attributes. Redraws itself with the new settings in place.
Parameters
----------
settings : dict
Returns
-------
None
"""
settings_changed = False
if "position_scaling" in settings:
settings_changed = True
self.position_scaling = settings["position_scaling"]
if settings_changed and self.data is not None:
self.pb_plots.setEnabled(True)
[docs]def choose_calibration(line_edit: QtWidgets.QLineEdit,
destination_func: callable):
"""Let a user select a calibration/transformation file and load it.
Lets a user select a ``*.json`` file that should contain one kind of
calibration, i.e. stereo camera calibration or transformtion to
world/experiment coordinates. The chosen file is then passed to the given
loading function for further processing.
Parameters
----------
line_edit : QLineEdit
Display object for the desired calibration file.
destination_func : callable
Loading function for the desired calibration file.
Returns
-------
None
"""
# check for a directory
ui_dir = line_edit.text()
# opens directory to select image
kwargs = {}
# handle file path issue when running on linux as a snap
if 'SNAP' in os.environ:
kwargs["options"] = QtWidgets.QFileDialog.DontUseNativeDialog
chosen_file, _ = QtWidgets.QFileDialog.getOpenFileName(
line_edit, 'Open a calibration', ui_dir, '*.json',
**kwargs)
if chosen_file == "":
# File selection was aborted
return None
else:
destination_func(chosen_file)
line_edit.setText(chosen_file)