Source code for mumott.output_handling.projection_viewer

""" Container for the class ProjectionViewer. """

import logging
from dataclasses import dataclass

import numpy as np
import matplotlib.pyplot as plt
import colorcet  # noqa

from colorspacious import cspace_converter
from matplotlib import cm
from matplotlib.colors import Colormap
from matplotlib.gridspec import GridSpec
from numpy.typing import NDArray
from typing import Union, Tuple

from .orientation_image_mapper import OrientationImageMapper
from mumott.data_handling import DataContainer


@dataclass
class ProjectionImage:
    """
    A dataclass for storing images for instances of
    class:`ProjectionViewer <mumott.output_handling.ProjectionViewer>`.

    Attributes
    ----------
    mean
        The mean value of the data.
    std
        The standard deviation of the data.
    angle_color_quad
        A quadruplet of RGBA colors mapping phase angles
        that indicate the orientation of the image.
    """
    mean: NDArray[float]
    std: NDArray[float]
    angle_color_quad: NDArray[float]


[docs]class ProjectionViewer: """ Class for viewing data and synthetic projections in a scrollable plot. The viewed projection can be changed with either the arrow keys or with the mouse scrolling wheel. Parameters ---------- data_container : DataContainer A :class:`DataContainer <mumott.data_handling.data_container.DataContainer>` object, containing the necessary dimension and rotation information to show orientations correctly. data : numpy.ndarray(dtype=numpy.float64), optional A flat array of data from e.g. reconstruction output. If not given, data will be loaded from the provided ``data_container``. Provide if you e.g. want to look at synthetic projections from a reconstruction. orientation_symmetry : ``'longitudinal'``, ``'transversal'``, optional Specifies the assumed three-dimensional symmetry of the reciprocal space map, which determines whether the minimum or the maximum of the ``cos(phi) ** 2`` component of the of the intensity is interpreted as the orientation angle. The default is ``'longitudinal'``, which interprets the angle of maximum scattering as the orientation, corresponding to an assumed three-dimensional symmetry of polar caps of intensity. The other option is ``transversal``, which corresponds to a symmetry of a great circle of intensity (also known as fiber symmetry), corresponding to a phase shift of ``90`` degrees. mean_colormap : str or matplotlib.colors.Colormap Specifies the colormap of the mean intensity across all angles. The default is ``'cet_linear_bmy_10_95_c71'``, a linear high-contrast with blue, magenta and yellow. std_colormap : str or matplotlib.colors.Colormap Specifies the colormap of the standar deviation of the data across all angles. The default is ``'cet_gouldian'``, a linear high-contrast map with blue, green and yellow. phase_colormap : str or matplotlib.colors.Colormap Specifies the colormap of the phase of the ``cos(detector_angles) ** 2`` fit of the data. The default is ``'cet_CET_C10'``, an isoluminant cyclic map. mean_range : tuple(float, float) Specifies the range of the colormap of the mean. std_range : tuple(float, float) Specifies the range of the colormap of the standard deviation. """ def __init__(self, data_container: DataContainer, data: NDArray[np.float64] = None, orientation_symmetry: str = 'longitudinal', mean_colormap: Union[str, Colormap] = 'cet_linear_bmy_10_95_c71', std_colormap: Union[str, Colormap] = 'cet_gouldian', phase_colormap: Union[str, Colormap] = 'cet_CET_C10', mean_range: Tuple[float, float] = None, std_range: Tuple[float, float] = None): detector_angles = data_container.geometry.detector_angles det_vector = \ (np.cos(detector_angles[np.newaxis, :]) * data_container.geometry.detector_direction_origin[:, np.newaxis] + np.sin(detector_angles[np.newaxis, :]) * data_container.geometry.detector_direction_positive_90[:, np.newaxis]) distance_j = np.dot(data_container.geometry.j_direction_0[np.newaxis, :], det_vector).squeeze() distance_k = np.dot(data_container.geometry.k_direction_0[np.newaxis, :], det_vector).squeeze() if not np.allclose(np.diff(detector_angles).reshape(1, -1), np.diff(detector_angles).reshape(-1, 1), rtol=2e-01, atol=8e-02): logging.warning('The provided `detector_angles` do not appear to be ' 'sorted and equally spaced to within an absolute tolerance ' 'of 5 degrees and a relative tolerance of twenty percent of the ' 'segment width.' '\nThe orientation angle may not behave as expected. Please be cautious' 'in interpreting the figure.') detector_angles = np.arctan2(distance_k, distance_j) if data is None: data = data_container.data if orientation_symmetry == 'transversal': detector_angles = np.arctan2(np.sin(detector_angles), np.cos(detector_angles)) elif orientation_symmetry == 'longitudinal': detector_angles = np.arctan2(-np.cos(detector_angles), np.sin(detector_angles)) else: logging.warning( 'Unknown `orientation_symmetry` option: "' + orientation_symmetry + '"\nUsing default value: "longitudinal".\n' 'Valid options are "transversal" and "longitudinal".') detector_angles = np.arctan2(-np.cos(detector_angles), np.sin(detector_angles)) direction = np.arctan2(np.sin(detector_angles[1] - detector_angles[0]), np.cos(detector_angles[1] - detector_angles[0])) a = float(np.round(direction >= 0)) b = np.exp((-2j) * detector_angles[0]) self._mean_colormap = mean_colormap self._std_colormap = std_colormap self._phase_colormap = phase_colormap self._cielab_converter = cspace_converter('sRGB255', 'JCh') self._srgb_converter = cspace_converter('JCh', 'sRGB255') self._orientation_image_mapper = OrientationImageMapper(phase_colormap) self._wheel_properties = self._orientation_image_mapper.wheel_properties self._phase_colormap_lut = cm.get_cmap(self._phase_colormap, 256) self._projection_number = len(data) self._clims_mean = np.percentile( a=data.mean(axis=2), q=(1., 99.9)) if mean_range is not None: if mean_range[0] is not None: self._clims_mean[0] = mean_range[0] if mean_range[1] is not None: self._clims_mean[1] = mean_range[1] self._clims_std = np.percentile( a=data.std(axis=2), q=(1., 99.9)) if std_range is not None: if std_range[0] is not None: self._clims_std[0] = std_range[0] if std_range[1] is not None: self._clims_std[1] = std_range[1] self._projection_images = [] for i, frame in enumerate(data): temp_mean = frame[:, :, :detector_angles.size].mean(axis=2).T temp_std = frame[:, :, :detector_angles.size].std(axis=2).T temp_fft = np.fft.rfft(frame[:, :, :detector_angles.size] ** 2, axis=2, norm='ortho') if a: temp_angle = (np.sqrt(temp_fft[:, :, 1] * b)).T else: temp_angle = (np.sqrt(np.conj(temp_fft[:, :, 1]) * b)).T temp_angle = np.angle(np.sign(temp_angle.imag) * temp_angle) temp_mask = np.round( np.sqrt( abs(temp_fft[:, :, 0])) > 0.5 * np.sqrt(abs(temp_fft[:, :, 0])).mean()).astype(float).T temp_aniso_factor = np.sqrt( abs(temp_fft[:, :, 1]) / (np.finfo(float).tiny + abs(temp_fft[:, :, 0]))).T temp_shader = temp_aniso_factor.clip(0., 1) phase_rgba = self._phase_colormap_lut(temp_angle / np.pi) phase_cielab = self._cielab_converter(phase_rgba[:, :, :-1]) phase_cielab[:, :, :2] *= 1.1 * ((1 - np.exp(-temp_shader)) / (1 - np.exp(-1))).reshape(temp_shader.shape + (1,)) phase_rgba[:, :, :-1] = self._srgb_converter(phase_cielab).clip(0., 1.) phase_rgba[:, :, -1] = temp_mask temp_image = ProjectionImage( mean=temp_mean, std=temp_std, angle_color_quad=phase_rgba) self._projection_images.append(temp_image) self._fig = plt.figure() self._grid_spec = GridSpec(6, 9, figure=self._fig) self._axes = [] self._axes.append(self._fig.add_subplot(self._grid_spec[0:4, 0:3])) self._axes.append(self._fig.add_subplot(self._grid_spec[0:4, 3:6])) self._axes.append(self._fig.add_subplot(self._grid_spec[0:4, 6:9])) self._axes.append(self._fig.add_subplot(self._grid_spec[4:6, 6:9], projection='polar')) self._colorbar_axes = [] self._colorbar_axes.append(self._fig.add_subplot(self._grid_spec[5, 0:3])) self._colorbar_axes.append(self._fig.add_subplot(self._grid_spec[5, 3:6])) for a in self._axes: a.grid(False) a.set_yticks([]) a.set_xticks([]) a.set_yticklabels([]) a.set_xticklabels([]) self._im1 = self._axes[0].imshow( self._projection_images[0].mean, origin='lower', cmap=self._mean_colormap) self._axes[0].set_title('Mean') self._fig.colorbar( self._im1, cax=self._colorbar_axes[0], orientation='horizontal') self._im2 = self._axes[1].imshow( self._projection_images[0].std, origin='lower', cmap=self._std_colormap) self._axes[1].set_title('Standard deviation') self._fig.colorbar(self._im2, cax=self._colorbar_axes[1], orientation='horizontal') self._im3 = self._axes[2].imshow(self._projection_images[0].angle_color_quad, origin='lower',) self._axes[2].set_title('Orientation phase') self._axes[2].set_facecolor((254 / 255, 254 / 255, 254 / 255)) self._wheel_rgba = self._wheel_properties[2] self._axes[3].set_facecolor((254 / 255, 254 / 255, 254 / 255)) wheel_shader = self._wheel_properties[1][:-1, :-1] ** 2 wheel_cielab = self._cielab_converter(self._wheel_rgba[:, :, :-1]) wheel_cielab[:, :, :2] *= 1.1 * ((1 - np.exp(-wheel_shader)) / (1 - np.exp(-1))).reshape(wheel_shader.shape + (1,)) wheel_rgb = self._srgb_converter(wheel_cielab).clip(0., 1.) self._wheel_rgba[:, :, :-1] = wheel_rgb self._axes[3].grid(False) self._color_wheel = self._axes[3].pcolormesh(self._wheel_properties[0], 1 - np.arccos(self._wheel_properties[1]) * 2 / np.pi, wheel_shader[:, :], facecolors=self._wheel_rgba.reshape(-1, 4), edgecolors=self._wheel_rgba.reshape(-1, 4)) self._axes[3].set_rmax(1.) self._axes[3].set_rmin(0.) self._projection_ind = 0
[docs] def change_projection(self, ind: int) -> None: """ Method for updating the projection shown when the projection slider changes. Parameters ---------- ind The index of the projection to be shown. """ if ind != self._projection_ind: self._projection_ind = ind ind = int(ind) self._im1.set_data(self._projection_images[ind].mean) self._im1.set_clim(*self._clims_mean) self._im2.set_data(self._projection_images[ind].std) self._im2.set_clim(*self._clims_std) self._im3.set_data(self._projection_images[ind].angle_color_quad) self._im3.autoscale() self._axes[0].relim() self._axes[1].relim() self._axes[2].relim() self._axes[0].autoscale_view() self._axes[1].autoscale_view() self._axes[2].autoscale_view() self._axes[0].set_aspect('equal') self._axes[1].set_aspect('equal') self._axes[2].set_aspect('equal') self._fig.canvas.draw_idle() self._fig.canvas.flush_events()