Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • empyre/empyre
  • weber/empyre
  • wessels/empyre
  • bryan/empyre
4 results
Show changes
Showing
with 2882 additions and 10 deletions
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for simple text format."""
import logging
import ast
import numpy as np
from ...fields.field import Field
_log = logging.getLogger(__name__)
file_extensions = ('.txt',) # Recognised file extensions
def reader(filename, scale=None, vector=None, **kwargs):
_log.debug('Call reader')
if vector is None:
vector = False
assert vector is False, 'Only scalar 2D fields can currently be read with this file reader!'
with open(filename, 'r') as load_file: # Read data:
empyre_format = load_file.readline().startswith('EMPYRE-FORMAT')
if empyre_format: # File has EMPyRe structure:
scale = ast.literal_eval(load_file.readline()[8:-4]) # [8:-4] takes just the scale string!
data = np.loadtxt(filename, delimiter='\t', skiprows=2) # skips header!
else: # Try default with provided kwargs:
scale = 1.0 if scale is None else scale # Set default if not provided!
data = np.loadtxt(filename, **kwargs)
return Field(data, scale=scale, vector=False)
def writer(filename, field, with_header=True, **kwargs):
_log.debug('Call writer')
assert not field.vector, 'Vector fields can currently not be saved to text!'
assert len(field.dim) == 2, 'Only 2D fields can currenty be saved to text!'
if with_header: # write header:
with open(filename, 'w') as save_file:
save_file.write('EMPYRE-FORMAT\n')
save_file.write(f'scale = {field.scale} nm\n')
save_kwargs = {'fmt': '%7.6e', 'delimiter': '\t'}
else:
save_kwargs = kwargs
with open(filename, 'ba') as save_file: # the 'a' is for append!
np.savetxt(save_file, field.data, **save_kwargs)
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for simple text format."""
import logging
from numbers import Number
import numpy as np
from ...fields.field import Field
from ...utils.misc import interp_to_regular_grid, restrict_points
from ...vis import colors
_log = logging.getLogger(__name__)
file_extensions = ('.vtk',) # Recognised file extensions
def reader(filename, scale=None, vector=None, bounds=None, **kwargs):
"""More infos at:
overview: https://docs.enthought.com/mayavi/mayavi/data.html
writing: https://vtk.org/Wiki/VTK/Writing_VTK_files_using_python
format: https://www.vtk.org/wp-content/uploads/2015/04/file-formats.pdf
"""
_log.debug('Calling reader')
try:
from tvtk.api import tvtk
except ImportError:
_log.error('This extension recquires the tvtk package!')
return
if vector is None:
vector = True
# Setting up reader:
reader = tvtk.DataSetReader(file_name=filename, read_all_scalars=True, read_all_vectors=vector)
reader.update()
# Getting output:
output = reader.output
assert output is not None, 'File reader could not find data or file "{}"!'.format(filename)
# Reading points and vectors:
if isinstance(output, tvtk.ImageData): # tvtk.StructuredPoints is a subclass of tvtk.ImageData!
# Connectivity: implicit; described by: 3D data array and spacing along each axis!
_log.info('geometry: ImageData')
# Load relevant information from output (reverse to get typical Python order z,y,x):
dim = output.dimensions[::-1]
origin = output.origin[::-1]
spacing = output.spacing[::-1]
_log.info(f'dim: {dim}, origin: {origin}, spacing: {spacing}')
assert len(dim) == 3, 'Data has to be three-dimensional!'
if scale is None:
scale = tuple(spacing)
if vector: # Extract vector compontents and create magnitude array:
vector_array = np.asarray(output.point_data.vectors, dtype=np.float)
x_mag, y_mag, z_mag = vector_array.T
data = np.stack((x_mag.reshape(dim), y_mag.reshape(dim), z_mag.reshape(dim)), axis=-1)
else: # Extract scalar data and create magnitude array:
scalar_array = np.asarray(output.point_data.scalars, dtype=np.float)
data = scalar_array.reshape(dim)
elif isinstance(output, tvtk.RectilinearGrid):
# Connectivity: implicit; described by: 3D data array and 1D array of spacing for each axis!
_log.info('geometry: RectilinearGrid')
raise NotImplementedError('RectilinearGrid is currently not supported!')
elif isinstance(output, tvtk.StructuredGrid):
# Connectivity: implicit; described by: 3D data array and 3D position arrays for each axis!
_log.info('geometry: StructuredGrid')
raise NotImplementedError('StructuredGrid is currently not supported!')
elif isinstance(output, tvtk.PolyData):
# Connectivity: explicit; described by: x, y, z positions of vertices and arrays of surface cells!
_log.info('geometry: PolyData')
raise NotImplementedError('PolyData is currently not supported!')
elif isinstance(output, tvtk.UnstructuredGrid):
# Connectivity: explicit; described by: x, y, z positions of vertices and arrays of volume cells!
_log.info('geometry: UnstructuredGrid')
# Load relevant information from output:
point_array = np.asarray(output.points, dtype=np.float)
if vector:
data_array = np.asarray(output.point_data.vectors, dtype=np.float)
else:
data_array = np.asarray(output.point_data.scalars, dtype=np.float)
if scale is None:
raise ValueError('For the interpolation of unstructured grids, the `scale` parameter is required!')
elif isinstance(scale, Number): # Scale is the same for each dimension x, y, z!
scale = (scale,) * 3
elif isinstance(scale, tuple):
assert len(scale) == 3, f'Each dimension (z, y, x) needs a scale, but {scale} was given!'
# Crop data to required range, if necessary
if bounds is not None:
_log.info('Restrict data')
point_array, data_array = restrict_points(point_array, data_array, bounds)
data = interp_to_regular_grid(point_array, data_array, scale, **kwargs)
else:
raise TypeError('Data type of {} not understood!'.format(output))
return Field(data, scale, vector=True)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
assert len(field.dim) == 3, 'Currently only 3D fields can be saved to vtk!'
try:
from tvtk.api import tvtk, write_data
except ImportError:
_log.error('This extension recquires the tvtk package!')
return
# Create dataset:
origin = (0, 0, 0)
spacing = (field.scale[2], field.scale[1], field.scale[0])
dimensions = (field.dim[2], field.dim[1], field.dim[0])
sp = tvtk.StructuredPoints(origin=origin, spacing=spacing, dimensions=dimensions)
# Fill with data from field:
if field.vector: # Handle vector fields:
# Put vector components in corresponding array:
vectors = field.data.reshape(-1, 3)
sp.point_data.vectors = vectors
sp.point_data.vectors.name = 'vectors'
# Calculate colors:
x_mag, y_mag, z_mag = field.comp
magvec = np.asarray((x_mag.data.ravel(), y_mag.data.ravel(), z_mag.data.ravel()))
cmap = kwargs.pop('cmap', None)
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(magvec)
point_colors = tvtk.UnsignedIntArray()
point_colors.number_of_components = 3
point_colors.name = 'colors'
point_colors.from_array(rgb)
sp.point_data.scalars = point_colors
sp.point_data.scalars.name = 'colors'
else: # Handle scalar fields:
scalars = field.data.ravel()
sp.point_data.scalars = scalars
sp.point_data.scalars.name = 'scalars'
# Write the data to file:
write_data(sp, filename)
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO functionality for Field objects."""
import logging
import os
from ..fields.field import Field
from .field_plugins import plugin_list
__all__ = ['load_field', 'save_field']
_log = logging.getLogger(__name__)
def load_field(filename, scale=None, vector=None, **kwargs):
"""Load supported file into a :class:`~..fields.Field` instance.
The function loads the file according to the extension:
SCALAR???
- hdf5 for HDF5. # TODO: You can use comp_pos here!!!
- EMD Electron Microscopy Dataset format (also HDF5).
- npy or npz for numpy formats.
PHASEMAP
- hdf5 for HDF5.
- rpl for Ripple (useful to export to Digital Micrograph).
- dm3 and dm4 for Digital Micrograph files.
- unf for SEMPER unf binary format.
- txt format.
- npy or npz for numpy formats.
- Many image formats such as png, tiff, jpeg...
VECTOR
- hdf5 for HDF5.
- EMD Electron Microscopy Dataset format (also HDF5).
- llg format.
- ovf format.
- npy or npz for numpy formats.
Any extra keyword is passed to the corresponsing reader. For available options see their individual documentation.
Parameters
----------
filename: str
The filename to be loaded.
scale: tuple of float, optional
Scaling along the dimensions of the underlying data. Default is 1.
vector: bool, optional
True if the field should be a vector field, False if it should be interpreted as a scalar field (default).
Returns
-------
field: `Field`
A `Field` object containing the loaded data.
Notes
-----
Falls back to HyperSpy routines for loading data, make sure it is installed if you need the full capabilities.
"""
_log.debug('Calling load_field')
extension = os.path.splitext(filename)[1]
for plugin in plugin_list: # Iterate over all plugins:
if extension in plugin.file_extensions: # Check if extension is recognised:
return plugin.reader(filename, scale=scale, vector=vector, **kwargs)
# If nothing was found, try HyperSpy
_log.debug('Using HyperSpy')
try:
import hyperspy.api as hs
except ImportError:
_log.error('This extension recquires the hyperspy package!')
return
comp_pos = kwargs.pop('comp_pos', -1)
return Field.from_signal(hs.load(filename, **kwargs), scale=scale, vector=vector, comp_pos=comp_pos)
def save_field(filename, field, **kwargs):
"""Saves the Field in the specified format.
The function gets the format from the extension:
- hdf5 for HDF5.
- EMD Electron Microscopy Dataset format (also HDF5).
- npy or npz for numpy formats.
If no extension is provided, 'hdf5' is used. Most formats are saved with the HyperSpy package (internally the field
is first converted to a HyperSpy Signal.
Each format accepts a different set of parameters. For details see the specific format documentation.
Parameters
----------
filename : str, optional
Name of the file which the Field is saved into. The extension determines the saving procedure.
"""
"""Saves the phasemap in the specified format.
The function gets the format from the extension:
- hdf5 for HDF5.
- rpl for Ripple (useful to export to Digital Micrograph).
- unf for SEMPER unf binary format.
- txt format.
- Many image formats such as png, tiff, jpeg...
If no extension is provided, 'hdf5' is used. Most formats are
saved with the HyperSpy package (internally the phasemap is first
converted to a HyperSpy Signal.
Each format accepts a different set of parameters. For details
see the specific format documentation.
Parameters
----------
filename: str, optional
Name of the file which the phasemap is saved into. The extension
determines the saving procedure.
save_mask: boolean, optional
If True, the `mask` is saved, too. For all formats, except HDF5, a separate file will
be created. HDF5 always saves the `mask` in the metadata, independent of this flag. The
default is False.
save_conf: boolean, optional
If True, the `confidence` is saved, too. For all formats, except HDF5, a separate file
will be created. HDF5 always saves the `confidence` in the metadata, independent of
this flag. The default is False
pyramid_format: boolean, optional
Only used for saving to '.txt' files. If this is True, the grid spacing is saved
in an appropriate header. Otherwise just the phase is written with the
corresponding `kwargs`.
"""
_log.debug('Calling save_field')
extension = os.path.splitext(filename)[1]
for plugin in plugin_list: # Iterate over all plugins:
if extension in plugin.file_extensions: # Check if extension is recognised:
plugin.writer(filename, field, **kwargs)
return
# If nothing was found, try HyperSpy:
_log.debug('Using HyperSpy')
field.to_signal().save(filename, **kwargs)
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Subpackage containing utility functionality."""
from .quaternion import *
__all__ = []
__all__.extend(quaternion.__all__)
del quaternion
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides the miscellaneous helper functions."""
import logging
import itertools
from time import time
import numpy as np
from tqdm import tqdm
from scipy.spatial import cKDTree, qhull
from scipy.interpolate import LinearNDInterpolator
__all__ = ['levi_civita', 'interp_to_regular_grid', 'restrict_points']
_log = logging.getLogger(__name__)
def levi_civita(i, j, k):
_log.debug('Calling levi_civita')
return (j-i) * (k-i) * (k-j) / (np.abs(j-i) * np.abs(k-i) * np.abs(k-j))
def interp_to_regular_grid(points, values, scale, scale_factor=1, step=1, convex=True, distance_upper_bound=None):
"""Interpolate values on points to regular grid.
Parameters
----------
points : np.ndarray, (N, 3)
Array of points, describing the location of the values that should be interpolated. Three columns x, y, z!
values : np.ndarray, (N, c)
Array of values that should be interpolated to the new grid. `c` is the number of components (`1` for scalar
fields, `3` for normal 3D vector fields).
scale : tuple of 3 ints
Scale along each of the 3 spatial dimensions. Usually given in nm.
scale_factor : float, optional
Additional scaling factor that should be used if the original points are not described on a nm-scale. Use this
to convert from nm of `scale` to the unit of the points. By default 1.
step : int, optional
If this is bigger than 1 (the default), only every `step` point is taken into account. Can speed up calculation,
but you'll lose accuracy of your interpolation.
convex : bool, optional
By default True. If this is set to False, additional measures are taken to find holes in the point cloud.
WARNING: this is an experimental feature that should be used with caution!
distance_upper_bound: float, optional
Only used if `convex=False`. Set the upper bound, determining if a point of the new (interpolated) grid is too
far away from any original point. They are assumed to be in a "hole" and their values are set to zero. Set this
value in nm, it will be converted to the local unit of the original points internally. If not set and
`convex=True`, double of the the mean of `scale` is calculated and used (can be troublesome if the scales vary
drastically).
Returns
-------
interpolation: np.ndarray
Interpolated grid with shape `(zdim, ydim, xdim)` for scalar and `(zdim, ydim, xdim, ncomp)` for vector field.
"""
_log.debug('Calling interpolate_to_regular_grid')
z_uniq = np.unique(points[:, 2])
_log.info(f'unique positions along z: {len(z_uniq)}')
# Translate scale in local units (not necessarily nm), taken care of with `scale_factor`:
scale = tuple([s * scale_factor for s in scale])
# Determine the size of the point cloud of irregular coordinates:
x_min, x_max = points[:, 0].min(), points[:, 0].max()
y_min, y_max = points[:, 1].min(), points[:, 1].max()
z_min, z_max = points[:, 2].min(), points[:, 2].max()
x_diff, y_diff, z_diff = np.ptp(points[:, 0]), np.ptp(points[:, 1]), np.ptp(points[:, 2])
_log.info(f'x-range: {x_min:.2g} <-> {x_max:.2g} ({x_diff:.2g})')
_log.info(f'y-range: {y_min:.2g} <-> {y_max:.2g} ({y_diff:.2g})')
_log.info(f'z-range: {z_min:.2g} <-> {z_max:.2g} ({z_diff:.2g})')
# Determine dimensions from given grid spacing a:
dim = tuple(np.round(np.asarray((z_diff/scale[0], y_diff/scale[1], x_diff/scale[2]), dtype=int)))
assert all(dim) > 0, f'All dimensions of dim={dim} need to be > 0, please adjust the scale accordingly!'
z = z_min + scale[0] * (np.arange(dim[0]) + 0.5) # +0.5: shift to pixel center!
y = y_min + scale[1] * (np.arange(dim[1]) + 0.5) # +0.5: shift to pixel center!
x = x_min + scale[2] * (np.arange(dim[2]) + 0.5) # +0.5: shift to pixel center!
# Create points for new Euclidian grid; fliplr for (x, y, z) order:
points_euc = np.fliplr(np.asarray(list(itertools.product(z, y, x))))
# Make values 2D (if not already); double .T so that a new axis is added at the END (n, 1):
values = np.atleast_2d(values.T).T
# Prepare interpolated grid:
interpolation = np.empty(dim+(values.shape[-1],), dtype=np.float)
_log.info(f'Dimensions of new grid: {interpolation.shape}')
# Calculate the Delaunay triangulation (same for every component of multidim./vector fields):
_log.info('Start Delaunay triangulation...')
tick = time()
triangulation = qhull.Delaunay(points[::step])
tock = time()
_log.info(f'Delaunay triangulation complete (took {tock-tick:.2f} s)!')
# Perform the interpolation for each column of `values`:
for i in tqdm(range(values.shape[-1])):
# Create interpolator for the given triangulation and the values of the current column:
interpolator = LinearNDInterpolator(triangulation, values[::step, i], fill_value=0)
# Interpolate:
interpolation[..., i] = interpolator(points_euc).reshape(dim)
# If NOT convex, we have to check for additional holes in the structure (EXPERIMENTAL):
if not convex: # Only necessary if the user expects holes in the (-> nonconvex) distribution:
# Create k-dimensional tree for queries:
_log.info('Create cKDTree...')
tick = time()
tree = cKDTree(points)
tock = time()
_log.info(f'cKDTree creation complete (took {tock-tick:.2f} s)!')
# Query the tree for nearest neighbors, x: points to query, k: number of neighbors, p: norm
# to use (here: 2 - Euclidean), distance_upper_bound: maximum distance that is searched!
if distance_upper_bound is None: # Take the mean of the scale for the upper bound:
distance_upper_bound = 2 * np.mean(scale) # NOTE: could be problematic for wildly varying scale numbers.
else: # Convert to local scale:
distance_upper_bound *= scale_factor
_log.info('Start cKDTree neighbour query...')
tick = time()
data, leafsize = tree.query(x=points_euc, k=1, p=2, distance_upper_bound=distance_upper_bound)
tock = time()
_log.info(f'cKDTree neighbour query complete (took {tock-tick:.2f} s)!')
# Create boolean mask that determines which interpolation points have no neighbor near enough:
mask = np.isinf(data).reshape(dim) # Points further away than upper bound were marked 'inf'!
_log.info(f'{np.sum(mask)} of {points_euc.shape[0]} points were assumed to be in holes of the point cloud!')
# Set these points to zero (NOTE: This can take a looooong time):
interpolation[mask, :] = 0
return np.squeeze(interpolation)
def restrict_points(point_array, data_array, bounds):
"""Restrict range of point_array and data_array
Parameters
----------
points_array : np.ndarray, (N, 3)
Array of points, describing the location of the values that should be interpolated. Three columns x, y, z!
data_array : np.ndarray, (N, c)
Array of values that should be interpolated to the new grid. `c` is the number of components (`1` for scalar
fields, `3` for normal 3D vector fields).
bounds : tuple of 6 values
Restrict data range to given bounds, x0, x1, y0, y1, z0, z1.
Returns
-------
point_restricted: np.ndarray
Cut out of the array of points inside the bounds, describing the location of the values that should be
interpolated. Three columns x, y, z!
value_restricted: np.ndarray
Cut out of the array of values inside the bounds, describing the location of the values that should be
interpolated. Three columns x, y, z!
"""
point_restricted = []
data_restricted = []
for i, pos in enumerate(point_array):
if bounds[0] <= pos[0] <= bounds[1]:
if bounds[2] <= pos[1] <= bounds[3]:
if bounds[4] <= pos[2] <= bounds[5]:
point_restricted.append(pos)
data_restricted.append(data_array[i])
point_restricted = np.array(point_restricted)
data_restricted = np.array(data_restricted)
return point_restricted, data_restricted
# -*- coding: utf-8 -*-
# Copyright 2014 by Forschungszentrum Juelich GmbH
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides the :class:`~.Quaternion` class which can be used for rotations."""
import logging
import numpy as np
__all__ = ['Quaternion']
class Quaternion(object):
"""Class representing a rotation expressed by a quaternion.
R"""Class representing a rotation expressed by a quaternion.
A quaternion is a four-dimensional description of a rotation which can also be described by
a rotation vector (`v1`, `v2`, `v3`) and a rotation angle :math:`\theta`. The four components
......@@ -64,11 +66,18 @@ class Quaternion(object):
self._log.debug('Calling __mul__')
if isinstance(other, Quaternion): # Quaternion multiplication
return self.dot_quat(self, other)
elif len(other) == 3: # vector multiplication
elif len(other) == 3: # vector multiplication (Caution: normalises!)
q_vec = Quaternion((0,) + tuple(other))
q = self.dot_quat(self.dot_quat(self, q_vec), self.conj)
return q.values[1:]
def _normalize(self):
self._log.debug('Calling _normalize')
mag2 = np.sum([n ** 2 for n in self.values])
if abs(mag2 - 1.0) > self.NORM_TOLERANCE:
mag = np.sqrt(mag2)
self.values = tuple(n / mag for n in self.values)
def dot_quat(self, q1, q2):
"""Multiply two :class:`~.Quaternion` objects to create a new one (always normalized).
......@@ -92,13 +101,6 @@ class Quaternion(object):
z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
return Quaternion((w, x, y, z))
def _normalize(self):
self._log.debug('Calling _normalize')
mag2 = np.sum(n ** 2 for n in self.values)
if abs(mag2 - 1.0) > self.NORM_TOLERANCE:
mag = np.sqrt(mag2)
self.values = tuple(n / mag for n in self.values)
@classmethod
def from_axisangle(cls, vector, theta):
"""Create a quaternion from an axis-angle representation
......
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Subpackage containing functionality for visualisation of multidimensional fields."""
from . import colors
from .plot2d import *
from .decorators import *
from .tools import *
from .plot3d import *
__all__ = ['colors']
__all__.extend(plot2d.__all__)
__all__.extend(decorators.__all__)
__all__.extend(tools.__all__)
__all__.extend(plot3d.__all__)
del plot2d
del decorators
del tools
del plot3d
# -*- coding: utf-8 -*-
# Copyright 2014 by Forschungszentrum Juelich GmbH
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
# TODO: Own small package? Use viscm (with colorspacious)?
# TODO: Also add cmoceaon "phase" colormap? Make optional (try importing, fall back to RdBu!)
"""This module provides a number of custom colormaps, which also have capabilities for 3D plotting.
If this is the case, the :class:`~.Colormap3D` colormap class is a parent class. In `cmaps`, a
number of specialised colormaps is available for convenience. If the default for circular colormaps
(used for 3D plotting) should be changed, set it via `CMAP_CMAP_ANGULAR_DEFAULT`.
number of specialised colormaps is available for convenience.
For general questions about colors see:
http://www.poynton.com/PDFs/GammaFAQ.pdf
http://www.poynton.com/PDFs/ColorFAQ.pdf
"""
import logging
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter as FuncForm
from matplotlib.ticker import MaxNLocator, IndexLocator, FixedLocator
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1 import ImageGrid
from matplotlib import gridspec
from matplotlib.patches import Circle
import colorsys
import abc
import numpy as np
from PIL import Image
from matplotlib import colors
import matplotlib.pyplot as plt
from matplotlib.ticker import FixedLocator
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
from matplotlib.patches import Circle
from skimage import color as skcolor
import colorsys
import abc
from .tools import use_style
from . import plottools
__all__ = ['Colormap3D', 'ColormapCubehelix', 'ColormapPerception', 'ColormapHLS', 'ColormapClassic',
'ColormapTransparent', 'cmaps', 'interpolate_color']
__all__ = ['Colormap3D', 'ColormapCubehelix', 'ColormapPerception', 'ColormapHLS',
'ColormapClassic', 'ColormapTransparent', 'cmaps', 'CMAP_CIRCULAR_DEFAULT',
'ColorspaceCIELab', 'ColorspaceCIELuv', 'ColorspaceCIExyY', 'ColorspaceYPbPr',
'interpolate_color', 'rgb_to_brightness', 'colormap_brightness_comparison']
_log = logging.getLogger(__name__)
# TODO: DOCSTRINGS!!!
class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta):
"""Colormap subclass for encoding directions with colors.
......@@ -64,7 +48,7 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta):
_log = logging.getLogger(__name__ + '.Colormap3D')
def rgb_from_vector(self, vector):
def rgb_from_vector(self, vector, vmax=None):
"""Construct a hls tuple from three coordinates representing a 3D direction.
Parameters
......@@ -80,54 +64,56 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta):
"""
self._log.debug('Calling rgb_from_vector')
x, y, z = np.asarray(vector)
# Calculate spherical coordinates:
r = np.sqrt(x ** 2 + y ** 2 + z ** 2)
x, y, z = vector
R = np.sqrt(x ** 2 + y ** 2 + z ** 2)
R_max = vmax if vmax is not None else R.max() + 1E-30
# FIRST color dimension: HUE (1D ring/angular direction)
phi = np.asarray(np.arctan2(y, x))
phi[phi < 0] += 2 * np.pi
theta = np.arccos(z / (r + 1E-30))
# Calculate color deterministics:
hue = phi / (2 * np.pi)
lum = 1 - theta / np.pi
sat = r / (r.max() + 1E-30)
# Calculate RGB from hue with colormap:
rgba = np.asarray(self(hue))
r, g, b = rgba[..., 0], rgba[..., 1], rgba[..., 2]
# Interpolate saturation:
# SECOND color dimension: SATURATION (2D, in-plane)
rho = np.sqrt(x ** 2 + y ** 2)
sat = rho / R_max
r, g, b = interpolate_color(sat, (0.5, 0.5, 0.5), np.stack((r, g, b), axis=-1))
# Interpolate luminance:
lum_target = np.where(lum < 0.5, 0, 1)
lum_target = np.stack([lum_target] * 3, axis=-1)
fraction = np.where(lum < 0.5, 1 - 2 * lum, 2 * (lum - 0.5))
# THIRD color dimension: LUMINANCE (3D, color sphere)
theta = np.arccos(z / R_max)
lum = 1 - theta / np.pi # goes from 0 (black) over 0.5 (grey) to 1 (white)!
lum_target = np.where(lum < 0.5, 0, 1) # Separate upper(white)/lower(black) hemispheres!
lum_target = np.stack([lum_target] * 3, axis=-1) # [0, 0, 0] -> black / [1, 1, 1] -> white!
fraction = 2 * np.abs(lum - 0.5) # 0.5: difference from grey, 2: scale to range (0, 1)!
r, g, b = interpolate_color(fraction, np.stack((r, g, b), axis=-1), lum_target)
# Return RGB:
return np.asarray(255 * np.stack((r, g, b), axis=-1), dtype=np.uint8)
def make_colorwheel(self, size=256, alpha=1, bgcolor=None):
# TODO: Strange arrows are not straight...
def make_colorwheel(self, size=64):
"""Construct a color wheel as an :class:`~PIL.Image` object.
Parameters
----------
size : int, optional
Diameter of the color wheel along both axes in pixels, by default 64.
Returns
-------
img: :class:`~PIL.Image`
The resulting image.
"""
self._log.debug('Calling make_colorwheel')
# Construct the colorwheel:
yy, xx = (np.indices((size, size)) - size/2 + 0.5)
rr = np.hypot(xx, yy)
xx = np.where(rr <= size/2-2, xx, 0)
yy = np.where(rr <= size/2-2, yy, 0)
zz = np.where(rr <= size/2-2, 0, -1) # color inside, black outside
aa = np.where(rr >= size/2-2, 255*alpha, 255).astype(dtype=np.uint8)
xx = np.where(rr <= size/2-3, xx, 0)
yy = np.where(rr <= size/2-3, yy, 0)
zz = np.zeros((size, size))
aa = np.where(rr >= size/2-3, 0, 255).astype(dtype=np.uint8)
rgba = np.dstack((self.rgb_from_vector(np.asarray((xx, yy, zz))), aa))
if bgcolor:
if bgcolor == 'w': # TODO: Matplotlib get color tuples from string?
bgcolor = (1, 1, 1)
if len(bgcolor) == 3 and not isinstance(bgcolor, str): # Only you have tuple!
r, g, b = rgba[..., 0], rgba[..., 1], rgba[..., 2]
r = np.where(rr <= size / 2 - 2, r, 255*bgcolor[0]).astype(dtype=np.uint8)
g = np.where(rr <= size / 2 - 2, g, 255*bgcolor[1]).astype(dtype=np.uint8)
b = np.where(rr <= size / 2 - 2, b, 255*bgcolor[2]).astype(dtype=np.uint8)
rgba[..., 0], rgba[..., 1], rgba[..., 2] = r, g, b
# Create color wheel:
return Image.fromarray(rgba)
def plot_colorwheel(self, axis=None, size=512, alpha=1, arrows=False, greyscale=False,
figsize=(4, 4), bgcolor=None, **kwargs):
def plot_colorwheel(self, axis=None, size=64, arrows=True, grayscale=False, **kwargs):
"""Display a color wheel to illustrate the color coding of vector gradient directions.
Parameters
......@@ -137,42 +123,30 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta):
Returns
-------
None
img: :class:`matplotlib.image.AxesImage`
The resulting colorwheel.
"""
self._log.debug('Calling plot_colorwheel')
# Construct the colorwheel:
color_wheel = self.make_colorwheel(size=size, alpha=alpha, bgcolor=bgcolor)
if greyscale:
color_wheel = color_wheel.convert('L')
color_wheel = self.make_colorwheel(size=size)
if grayscale:
color_wheel = color_wheel.convert('LA')
# Plot the color wheel:
if axis is None:
fig = plt.figure(figsize=figsize)
axis = fig.add_subplot(1, 1, 1, aspect='equal')
axis.imshow(color_wheel, origin='lower', **kwargs)
axis.add_artist(Circle(xy=(size/2-0.5, size/2-0.5), radius=size/2-2, linewidth=2,
edgecolor='k', facecolor='none'))
if arrows:
plt.tick_params(axis='both', which='both', labelleft='off', labelbottom='off',
left='off', right='off', top='off', bottom='off')
axis.arrow(size/2, size/2, 0, 0.15*size, head_width=9, head_length=20,
fc='k', ec='k', lw=1, width=2)
axis.arrow(size/2, size/2, 0, -0.15*size, head_width=9, head_length=20,
fc='k', ec='k', lw=1, width=2)
axis.arrow(size/2, size/2, 0.15*size, 0, head_width=9, head_length=20,
fc='k', ec='k', lw=1, width=2)
axis.arrow(size/2, size/2, -0.15*size, 0, head_width=9, head_length=20,
fc='k', ec='k', lw=1, width=2)
# Return axis:
axis.xaxis.set_visible(False)
axis.yaxis.set_visible(False)
for tic in axis.xaxis.get_major_ticks():
tic.tick1On = tic.tick2On = False
tic.label1On = tic.label2On = False
for tic in axis.yaxis.get_major_ticks():
tic.tick1On = tic.tick2On = False
tic.label1On = tic.label2On = False
return axis
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
fig = plt.figure()
axis = fig.add_subplot(1, 1, 1, aspect='equal')
# Plot:
im = axis.imshow(color_wheel, **kwargs)
xy = size/2 - 0.5, size/2 - 0.5
axis.add_patch(Circle(xy=xy, radius=size/2-2.5, linewidth=2, edgecolor='k', facecolor='none'))
if arrows:
axis.arrow(size/2, size/2+5, 0, 0.1*size, fc='k', ec='k', lw=1, width=2, alpha=0.15)
axis.arrow(size/2, size/2-5, 0, -0.1*size, fc='k', ec='k', lw=1, width=2, alpha=0.15)
axis.arrow(size/2+5, size/2, 0.1*size, 0, fc='k', ec='k', lw=1, width=2, alpha=0.15)
axis.arrow(size/2-5, size/2, -0.1*size, 0, fc='k', ec='k', lw=1, width=2, alpha=0.15)
return im
class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D):
......@@ -241,6 +215,7 @@ class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D):
changed name of `hue` parameter to `sat`.
2016-11 (@jan.caron) Added support for isoluminant cubehelices while making sure
`rot` works as intended. Decoded the plane-vectors a bit.
"""
_log = logging.getLogger(__name__ + '.ColormapCubehelix')
......@@ -264,15 +239,15 @@ class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D):
self.fract = self.fract**gamma
satar = np.linspace(minSat, maxSat, nlev)
amp = np.asarray(satar * self.fract * (1. - self.fract) / 2)
# Set RGB color coefficients (Luma is calculated in RGB Rec.601, so choose those),
# the original version of Dave green used (0.30, 0.59, 0.11) and REc.709 is
# c709 = (0.2126, 0.7152, 0.0722) but would not produce correct YPbPr Luma.
# Set RGB color coefficients (Luma is calculated in RGB Rec.601, so choose those), the original version of
# Dave Green used (0.30, 0.59, 0.11) and Rec.709 is c709 = (0.2126, 0.7152, 0.0722) but with eihter of those,
# this function would not produce the correct YPbPr Luma.
c601 = (0.299, 0.587, 0.114)
cr, cg, cb = c601
cw = -0.90649 # Chosen to comply with Dave Greens implementation.
k = -1.6158 / cr / cw # k has to balance out cw so nothing gets out of RGB gamut (> 1).
# Calculate the vectors v and w spanning the plane of constant perceived intensity.
# v and w have to solve v x w = k(cr, cg, cb) (normal vector of the described plane) and
# Calculate the vectors v and w spanning the plane of constant perceived intensity. v and w have to solve
# v x w = k(cr, cg, cb) (normal vector of the described plane) and
# v * w = 0 (scalar product, v and w have to be perpendicular).
# 6 unknown and 4 equations --> Chose wb = 0 and wg = cw (constant).
v = np.array((k * cr ** 2 * cb / (cw * (cr ** 2 + cg ** 2)),
......@@ -319,33 +294,31 @@ class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D):
"""
self._log.debug('Calling plot_helix')
if figsize is None:
figsize = plottools.FIGSIZE_DEFAULT
plt.figure(figsize=figsize)
gs = gridspec.GridSpec(2, 1, height_ratios=[8, 1])
# Main plot:
axis = plt.subplot(gs[0])
axis.plot(self.fract, 'k', linewidth=2)
axis.plot(self.red, 'r', linewidth=2)
axis.plot(self.grn, 'g', linewidth=2)
axis.plot(self.blu, 'b', linewidth=2)
axis.set_xlim(0, self.nlev)
axis.set_ylim(0, 1)
axis.set_title('Cubehelix', fontsize=18)
axis.set_xlabel('Color index', fontsize=15)
axis.set_ylabel('Brightness / RGB', fontsize=15)
axis.xaxis.set_major_locator(FixedLocator(locs=np.linspace(0, self.nlev, 5)))
axis.yaxis.set_major_locator(FixedLocator(locs=[0, 0.5, 1]))
# Colorbar horizontal:
caxis = plt.subplot(gs[1], sharex=axis)
rgb = self(np.linspace(0, 1, 256))[None, ...]
rgb = np.asarray(255.9999 * rgb, dtype=np.uint8)
rgb = np.repeat(rgb, 30, axis=0)
im = Image.fromarray(rgb)
caxis.imshow(im)
plt.tick_params(axis='both', which='both', labelleft='off', labelbottom='off',
left='off', right='off', top='on', bottom='on')
return plottools.format_axis(axis, scalebar=False, keep_labels=True, **kwargs)
with use_style('empyre-plot'):
fig = plt.figure(figsize=figsize, constrained_layout=True)
gs = fig.add_gridspec(2, 1, height_ratios=[8, 1])
# Main plot:
axis = plt.subplot(gs[0])
axis.plot(self.fract, 'k')
axis.plot(self.red, 'r')
axis.plot(self.grn, 'g')
axis.plot(self.blu, 'b')
axis.set_xlim(0, self.nlev)
axis.set_ylim(0, 1)
axis.set_title('Cubehelix')
axis.set_xlabel('Color index')
axis.set_ylabel('Brightness / RGB')
axis.xaxis.set_major_locator(FixedLocator(locs=np.linspace(0, self.nlev, 5)))
axis.yaxis.set_major_locator(FixedLocator(locs=[0, 0.5, 1]))
# Colorbar horizontal:
caxis = plt.subplot(gs[1])
rgb = self(np.linspace(0, 1, 256))[None, ...]
rgb = np.asarray(255.9999 * rgb, dtype=np.uint8)
rgb = np.repeat(rgb, 30, axis=0)
im = Image.fromarray(rgb)
caxis.imshow(im, aspect='auto')
caxis.tick_params(axis='both', which='both', labelleft=False, labelbottom=False,
left=False, right=False, top=False, bottom=False)
class ColormapPerception(colors.LinearSegmentedColormap, Colormap3D):
......@@ -435,23 +408,23 @@ class ColormapClassic(colors.LinearSegmentedColormap, Colormap3D):
_log = logging.getLogger(__name__ + '.ColormapClassic')
CDICT = {'red': [(0.00, 1.0, 1.0),
(0.25, 0.0, 0.0),
(0.50, 0.0, 0.0),
(0.75, 1.0, 1.0),
(1.00, 1.0, 1.0)],
CDICT = {'red': [(0/4, 1.0, 1.0),
(1/4, 0.0, 0.0),
(2/4, 0.0, 0.0),
(3/4, 1.0, 1.0),
(4/4, 1.0, 1.0)],
'green': [(0.00, 0.0, 0.0),
(0.25, 0.0, 0.0),
(0.50, 1.0, 1.0),
(0.75, 1.0, 1.0),
(1.00, 0.0, 0.0)],
'green': [(0/4, 0.0, 0.0),
(1/4, 0.0, 0.0),
(2/4, 1.0, 1.0),
(3/4, 1.0, 1.0),
(4/4, 0.0, 0.0)],
'blue': [(0.00, 0.0, 0.0),
(0.25, 1.0, 1.0),
(0.50, 0.0, 0.0),
(0.75, 0.0, 0.0),
(1.00, 0.0, 0.0)]}
'blue': [(0/4, 0.0, 0.0),
(1/4, 1.0, 1.0),
(2/4, 0.0, 0.0),
(3/4, 0.0, 0.0),
(4/4, 0.0, 0.0)]}
def __init__(self):
self._log.debug('Calling __init__')
......@@ -494,609 +467,6 @@ class ColormapTransparent(colors.LinearSegmentedColormap):
self._log.debug('Created ' + str(self))
class ColorspaceCIELab(object): # TODO: Superclass?
"""Class representing the CIELab colorspace."""
_log = logging.getLogger(__name__ + '.ColorspaceCIELab')
def __init__(self, dim=(500, 500), extent=(-100, 100, -100, 100), cut_gamut=False, clip=True):
self._log.debug('Calling __init__')
self.dim = dim
self.extent = extent
self.cut_out_gamut = cut_gamut
self.clip = clip
self._log.debug('Created ' + str(self))
def plot(self, L=53.4, axis=None, figsize=None, **kwargs):
self._log.debug('Calling plot')
if figsize is None:
figsize = plottools.FIGSIZE_DEFAULT
dim, ext = self.dim, self.extent
# Create Lab colorspace:
a = np.linspace(ext[0], ext[1], dim[1])
b = np.linspace(ext[2], ext[3], dim[0])
aa, bb = np.meshgrid(a, b)
LL = np.full(dim, L, dtype=int)
Lab = np.stack((LL, aa, bb), axis=-1)
# Convert to XYZ colorspace:
XYZ = skcolor.lab2xyz(Lab)
# Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php:
rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ)
# Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!:
mask = rgb > 0.0031308
rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055
rgb[~mask] *= 12.92
# Determine gamut:
gamut = np.logical_or(rgb < 0, rgb > 1)
gamut = np.sum(gamut, axis=-1, dtype=bool)
gamut_mask = np.stack((gamut, gamut, gamut), axis=-1)
# Cut out gamut (set out of bound colors to gray) if necessary:
if self.cut_out_gamut:
rgb[gamut_mask] = 0.5
# Clip out of gamut colors:
if self.clip:
rgb[rgb < 0] = 0
rgb[rgb > 1] = 1
# Plot colorspace:
if axis is None:
fig = plt.figure(figsize=figsize)
axis = fig.add_subplot(1, 1, 1, aspect='equal')
axis.imshow(rgb, origin='lower', interpolation='none', extent=(0, dim[0], 0, dim[1]))
axis.contour(gamut, levels=[0], colors='k', linewidths=1.5)
axis.set_xlabel('a', fontsize=15)
axis.set_ylabel('b', fontsize=15)
axis.set_title('CIELab (L = {:g})'.format(L), fontsize=18)
axis.xaxis.set_major_locator(FixedLocator(np.linspace(0, dim[1], 5)))
axis.yaxis.set_major_locator(FixedLocator(np.linspace(0, dim[0], 5)))
fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0]))
axis.xaxis.set_major_formatter(fx)
fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2]))
axis.yaxis.set_major_formatter(fy)
plottools.format_axis(axis, scalebar=False, keep_labels=True, **kwargs)
def plot_colormap(self, cmap, N=256, L='auto', figsize=None, cbar_lim=None, brightness=True,
input_rec=None):
self._log.debug('Calling plot_colormap')
dim, ext = self.dim, self.extent
# Calculate rgb values:
rgb = cmap(np.linspace(0, 1, N))[None, :, :3] # These are R'G'B' values!
if input_rec == 601:
rgb = RGBConverter('Rec601', 'Rec709')(rgb)
# Convert to Lab space:
Lab = np.squeeze(skcolor.rgb2lab(rgb))
LL, aa, bb = Lab.T
aa = (aa - ext[0]) / (ext[1] - ext[0]) * dim[1]
bb = (bb - ext[2]) / (ext[3] - ext[2]) * dim[0]
# Determine number of images / luma levels:
LL_min, LL_max = np.round(np.min(LL), 1), np.round(np.max(LL), 1)
if L == 'auto':
if LL_max - LL_min < 0.1: # Just one image:
L = LL_min
else: # Two images:
L = np.asarray((LL_max, np.mean(LL), LL_min))
L_list = np.atleast_1d(L)
# Determine colorbar limits:
if cbar_lim is not None: # Overwrite limits!
LL_min, LL_max = cbar_lim
elif not brightness or LL_max - LL_min < 0.1: # Just one value, full range for colormap:
LL_min, LL_max = 0, 1
# Creat grid:
if figsize is None:
figsize = (len(L_list) * 5 + 2, 7)
fig = plt.figure(figsize=figsize)
grid = ImageGrid(fig, 111, nrows_ncols=(1, len(L_list)), axes_pad=0.4, share_all=False,
cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25)
# Plot:
if brightness:
c = LL
cmap = 'gray'
else:
c = np.linspace(0, 1, N)
for i, axis in enumerate(grid):
self.plot(L=L_list[i], axis=axis)
im = axis.scatter(aa, bb, c=c, cmap=cmap, edgecolors='none',
vmin=LL_min, vmax=LL_max)
axis.set_xlim(0, self.dim[1])
axis.set_ylim(0, self.dim[0])
axis.cax.colorbar(im, ticks=np.linspace(LL_min, LL_max, 9))
def plot3d(self, N=9):
self._log.debug('Calling plot3d')
dim, ext = self.dim, self.extent
# Create Lab colorspace:
a = np.linspace(ext[0], ext[1], dim[1])
b = np.linspace(ext[2], ext[3], dim[0])
aa, bb = np.meshgrid(a, b)
import visvis # TODO: If VisPy is ever ready, switch every plot to that!
for i in range(1, N):
LL = np.full(dim, i * 100 / N, dtype=int)
Lab = np.stack((LL, aa, bb), axis=-1)
# Convert to XYZ colorspace:
XYZ = skcolor.lab2xyz(Lab)
# Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php:
rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ)
# Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!:
mask = rgb > 0.0031308
rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055
rgb[~mask] *= 12.92
# Determine gamut:
gamut = np.logical_or(rgb < 0, rgb > 1)
gamut = np.sum(gamut, axis=-1, dtype=bool)
# Alpha:
alpha = 1.
a = np.full(dim + (1,), alpha)
a *= np.logical_not(gamut[..., None])
rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8)
# Visvis plot:
obj = visvis.functions.surf(aa, bb, i * 100. / N * np.ones_like(aa), rgba, aa=0)
obj.parent.light0.ambient = 1.
obj.parent.light0.diffuse = 0.
class ColorspaceCIELuv(object):
"""Class representing the CIELuv colorspace."""
_log = logging.getLogger(__name__ + '.ColorspaceCIELuv')
def __init__(self, dim=(500, 500), extent=(-100, 100, -100, 100), cut_gamut=False, clip=True):
self._log.debug('Calling __init__')
self.dim = dim
self.extent = extent
self.cut_out_gamut = cut_gamut
self.clip = clip
self._log.debug('Created ' + str(self))
def plot(self, L=53.4, axis=None, figsize=None, **kwargs):
self._log.debug('Calling plot')
if figsize is None:
figsize = plottools.FIGSIZE_DEFAULT
dim, ext = self.dim, self.extent
# Create Lab colorspace:
u = np.linspace(ext[0], ext[1], dim[1])
v = np.linspace(ext[2], ext[3], dim[0])
uu, vv = np.meshgrid(u, v)
LL = np.full(dim, L, dtype=int)
Luv = np.stack((LL, uu, vv), axis=-1)
# Convert to XYZ colorspace:
XYZ = skcolor.luv2xyz(Luv)
# Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php:
rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ)
# Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!:
mask = rgb > 0.0031308
rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055
rgb[~mask] *= 12.92
# Determine gamut:
gamut = np.logical_or(rgb < 0, rgb > 1)
gamut = np.sum(gamut, axis=-1, dtype=bool)
gamut_mask = np.stack((gamut, gamut, gamut), axis=-1)
# Cut out gamut (set out of bound colors to gray) if necessary:
if self.cut_out_gamut:
rgb[gamut_mask] = 0.5
# Clip out of gamut colors:
if self.clip:
rgb[rgb < 0] = 0
rgb[rgb > 1] = 1
# Plot colorspace:
if axis is None:
fig = plt.figure(figsize=figsize)
axis = fig.add_subplot(1, 1, 1, aspect='equal')
axis.imshow(rgb, origin='lower', interpolation='none', extent=(0, dim[0], 0, dim[1]))
axis.contour(gamut, levels=[0], colors='k', linewidths=1.5)
axis.set_xlabel('u', fontsize=15)
axis.set_ylabel('v', fontsize=15)
axis.set_title('CIELuv (L = {:g})'.format(L), fontsize=18)
axis.xaxis.set_major_locator(FixedLocator(np.linspace(0, dim[1], 5)))
axis.yaxis.set_major_locator(FixedLocator(np.linspace(0, dim[0], 5)))
fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0]))
axis.xaxis.set_major_formatter(fx)
fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2]))
axis.yaxis.set_major_formatter(fy)
plottools.format_axis(axis, scalebar=False, keep_labels=True, **kwargs)
def plot_colormap(self, cmap, N=256, L='auto', figsize=None, cbar_lim=None, brightness=True,
input_rec=None):
self._log.debug('Calling plot_colormap')
dim, ext = self.dim, self.extent
# Calculate rgb values:
rgb = cmap(np.linspace(0, 1, N))[None, :, :3]
if input_rec == 601:
rgb = RGBConverter('Rec601', 'Rec709')(rgb)
# Convert to Lab space:
Luv = np.squeeze(skcolor.rgb2luv(rgb))
LL, uu, vv = Luv.T
uu = (uu - ext[0]) / (ext[1] - ext[0]) * dim[1]
vv = (vv - ext[2]) / (ext[3] - ext[2]) * dim[0]
# Determine number of images / luma levels:
LL_min, LL_max = np.round(np.min(LL), 1), np.round(np.max(LL), 1)
if L == 'auto':
if LL_max - LL_min < 0.1: # Just one image:
L = LL_min
else: # Two images:
L = np.asarray((LL_max, np.mean(LL), LL_min))
L_list = np.atleast_1d(L)
# Determine colorbar limits:
if cbar_lim is not None: # Overwrite limits!
LL_min, LL_max = cbar_lim
elif not brightness or LL_max - LL_min < 0.1: # Just one value, full range for colormap:
LL_min, LL_max = 0, 1
# Creat grid:
if figsize is None:
figsize = (len(L_list) * 5 + 2, 7)
fig = plt.figure(figsize=figsize)
grid = ImageGrid(fig, 111, nrows_ncols=(1, len(L_list)), axes_pad=0.4, share_all=False,
cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25)
# Plot:
if brightness:
c = LL
cmap = 'gray'
else:
c = np.linspace(0, 1, N)
for i, axis in enumerate(grid):
self.plot(L=L_list[i], axis=axis)
im = axis.scatter(uu, vv, c=c, cmap=cmap, edgecolors='none',
vmin=LL_min, vmax=LL_max)
axis.set_xlim(0, self.dim[1])
axis.set_ylim(0, self.dim[0])
axis.cax.colorbar(im, ticks=np.linspace(LL_min, LL_max, 9))
def plot3d(self, N=9):
self._log.debug('Calling plot3d')
dim, ext = self.dim, self.extent
# Create Lab colorspace:
u = np.linspace(ext[0], ext[1], dim[1])
v = np.linspace(ext[2], ext[3], dim[0])
uu, vv = np.meshgrid(u, v)
import visvis
for i in range(1, N):
LL = np.full(dim, i * 100 / N, dtype=int)
Luv = np.stack((LL, uu, vv), axis=-1)
# Convert to XYZ colorspace:
XYZ = skcolor.luv2xyz(Luv)
# Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php:
rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ)
# Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!:
mask = rgb > 0.0031308
rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055
rgb[~mask] *= 12.92
# Determine gamut:
gamut = np.logical_or(rgb < 0, rgb > 1)
gamut = np.sum(gamut, axis=-1, dtype=bool)
# Alpha:
alpha = 1.
a = np.full(dim + (1,), alpha)
a *= np.logical_not(gamut[..., None])
rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8)
# Visvis plot:
obj = visvis.functions.surf(uu, vv, i * 100. / N * np.ones_like(uu), rgba, aa=0)
obj.parent.light0.ambient = 1.
obj.parent.light0.diffuse = 0.
class ColorspaceCIExyY(object):
"""Class representing the CIExyY colorspace."""
_log = logging.getLogger(__name__ + '.ColorspaceCIExyY')
def __init__(self, dim=(500, 500), extent=(0, 0.8, 0, 0.8), cut_gamut=False, clip=True):
self._log.debug('Calling __init__')
self.dim = dim
self.extent = extent
self.cut_out_gamut = cut_gamut
self.clip = clip
self._log.debug('Created ' + str(self))
def plot(self, Y=0.214, axis=None, figsize=None, **kwargs):
self._log.debug('Calling plot')
if figsize is None:
figsize = plottools.FIGSIZE_DEFAULT
dim, ext = self.dim, self.extent
# Create Lab colorspace:
x = np.linspace(ext[0], ext[1], dim[1])
y = np.linspace(ext[2], ext[3], dim[0])
xx, yy = np.meshgrid(x, y)
YY = np.full(dim, Y)
# Convert to XYZ:
XX = YY / (yy + 1e-30) * xx
ZZ = YY / (yy + 1e-30) * (1 - xx - yy)
XYZ = np.stack((XX, YY, ZZ), axis=-1)
# Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php:
rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ)
# Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!:
mask = rgb > 0.0031308
rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055
rgb[~mask] *= 12.92
# Determine gamut:
gamut = np.logical_or(rgb < 0, rgb > 1)
gamut = np.sum(gamut, axis=-1, dtype=bool)
gamut_mask = np.stack((gamut, gamut, gamut), axis=-1)
# Cut out gamut (set out of bound colors to gray) if necessary:
if self.cut_out_gamut:
rgb[gamut_mask] = 0.5
# Clip out of gamut colors:
if self.clip:
rgb[rgb < 0] = 0
rgb[rgb > 1] = 1
# Plot colorspace:
if axis is None:
fig = plt.figure(figsize=figsize)
axis = fig.add_subplot(1, 1, 1, aspect='equal')
axis.imshow(rgb, origin='lower', interpolation='none', extent=(0, dim[0], 0, dim[1]))
axis.contour(gamut, levels=[0], colors='k', linewidths=1.5)
axis.set_xlabel('x', fontsize=15)
axis.set_ylabel('y', fontsize=15)
axis.set_title('CIExyY (Y = {:g})'.format(Y), fontsize=18)
axis.xaxis.set_major_locator(FixedLocator(np.linspace(0, dim[1], 5)))
axis.yaxis.set_major_locator(FixedLocator(np.linspace(0, dim[0], 5)))
fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0]))
axis.xaxis.set_major_formatter(fx)
fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2]))
axis.yaxis.set_major_formatter(fy)
plottools.format_axis(axis, scalebar=False, keep_labels=True, **kwargs)
def plot_colormap(self, cmap, N=256, Y='auto', figsize=None, cbar_lim=None, brightness=True,
input_rec=None):
self._log.debug('Calling plot_colormap')
dim, ext = self.dim, self.extent
# Calculate rgb values:
rgb = cmap(np.linspace(0, 1, N))[None, :, :3]
if input_rec == 601:
rgb = RGBConverter('Rec601', 'Rec709')(rgb)
# Convert to XYZ space:
XYZ = np.squeeze(skcolor.rgb2xyz(rgb))
XX, YY, ZZ = XYZ.T
# Convert to xyY space:
xx = XX / (XX + YY + ZZ)
yy = YY / (XX + YY + ZZ)
xx = (xx - ext[0]) / (ext[1] - ext[0]) * dim[1]
yy = (yy - ext[2]) / (ext[3] - ext[2]) * dim[0]
# Determine number of images / luma levels:
YY_min, YY_max = np.round(np.min(YY), 2), np.round(np.max(YY), 2)
if Y == 'auto':
if YY_max - YY_min < 0.01: # Just one image:
Y = YY_min
else: # Two images:
Y = np.asarray((YY_max, np.mean(YY), YY_min))
Y_list = np.atleast_1d(Y)
# Determine colorbar limits:
if cbar_lim is not None: # Overwrite limits!
YY_min, YY_max = cbar_lim
elif not brightness or YY_max - YY_min < 0.01: # Just one value, full range for colormap:
YY_min, YY_max = 0, 1
# Creat grid:
if figsize is None:
figsize = (len(Y_list) * 5 + 2, 7)
fig = plt.figure(figsize=figsize)
grid = ImageGrid(fig, 111, nrows_ncols=(1, len(Y_list)), axes_pad=0.4, share_all=False,
cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25)
# Plot:
if brightness:
c = YY
cmap = 'gray'
else:
c = np.linspace(0, 1, N)
for i, axis in enumerate(grid):
self.plot(Y=Y_list[i], axis=axis)
im = axis.scatter(xx, yy, c=c, cmap=cmap, edgecolors='none',
vmin=YY_min, vmax=YY_max)
axis.set_xlim(0, self.dim[1])
axis.set_ylim(0, self.dim[0])
axis.cax.colorbar(im, ticks=np.linspace(YY_min, YY_max, 9))
def plot3d(self, N=9):
self._log.debug('Calling plot3d')
dim, ext = self.dim, self.extent
# Create Lab colorspace:
x = np.linspace(ext[0], ext[1], dim[1])
y = np.linspace(ext[2], ext[3], dim[0])
xx, yy = np.meshgrid(x, y)
import visvis
for i in range(1, N):
YY = np.full(dim, i * 1. / N)
# Convert to XYZ:
XX = YY / (yy + 1e-30) * xx
ZZ = YY / (yy + 1e-30) * (1 - xx - yy)
XYZ = np.stack((XX, YY, ZZ), axis=-1)
# Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php:
rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ)
# Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!:
mask = rgb > 0.0031308
rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055
rgb[~mask] *= 12.92
# Determine gamut:
gamut = np.logical_or(rgb < 0, rgb > 1)
gamut = np.sum(gamut, axis=-1, dtype=bool)
# Alpha:
alpha = 1.
a = np.full(dim + (1,), alpha)
a *= np.logical_not(gamut[..., None])
rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8)
# Visvis plot:
obj = visvis.functions.surf(xx, yy, i / N * np.ones_like(xx), rgba, aa=0)
obj.parent.light0.ambient = 1.
obj.parent.light0.diffuse = 0.
class ColorspaceYPbPr(object):
"""Class representing the YPbPr colorspace."""
_log = logging.getLogger(__name__ + '.ColorspaceYPbPr')
def __init__(self, dim=(500, 500), extent=(-0.8, 0.8, -0.8, 0.8), cut_gamut=False, clip=True):
self._log.debug('Calling __init__')
self.dim = dim
self.extent = extent
self.cut_out_gamut = cut_gamut
self.clip = clip
self._log.debug('Created ' + str(self))
def plot(self, Y=0.5, axis=None, figsize=None, **kwargs):
self._log.debug('Calling plot')
if figsize is None:
figsize = plottools.FIGSIZE_DEFAULT
dim, ext = self.dim, self.extent
# Create YPbPr colorspace:
pb = np.linspace(ext[0], ext[1], dim[1])
pr = np.linspace(ext[2], ext[3], dim[0])
ppb, ppr = np.meshgrid(pb, pr)
YY = np.full(dim, Y) # This is luma, not relative luminance (Y', not Y)!
# Convert to RGB colorspace (this is the nonlinear R'G'B' space!):
rr = YY + 1.402 * ppr
gg = YY - 0.344136 * ppb - 0.7141136 * ppr
bb = YY + 1.772 * ppb
rgb = np.stack((rr, gg, bb), axis=-1)
# Determine gamut:
gamut = np.logical_or(rgb < 0, rgb > 1)
gamut = np.sum(gamut, axis=-1, dtype=bool)
gamut_mask = np.stack((gamut, gamut, gamut), axis=-1)
# Cut out gamut (set out of bound colors to gray) if necessary:
if self.cut_out_gamut:
rgb[gamut_mask] = 0.5
# Clip out of gamut colors:
if self.clip:
rgb[rgb < 0] = 0
rgb[rgb > 1] = 1
# Plot colorspace:
if axis is None:
fig = plt.figure(figsize=figsize)
axis = fig.add_subplot(1, 1, 1, aspect='equal')
axis.imshow(rgb, origin='lower', interpolation='none',
extent=(0, dim[0], 0, dim[1]))
axis.contour(gamut, levels=[0], colors='k', linewidths=1.5)
axis.set_xlabel('Pb', fontsize=15)
axis.set_ylabel('Pr', fontsize=15)
axis.set_title("Y'PbPr (Y' = {:g})".format(Y), fontsize=18)
axis.xaxis.set_major_locator(FixedLocator(np.linspace(0, dim[1], 5)))
axis.yaxis.set_major_locator(FixedLocator(np.linspace(0, dim[0], 5)))
fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0]))
axis.xaxis.set_major_formatter(fx)
fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2]))
axis.yaxis.set_major_formatter(fy)
plottools.format_axis(axis, scalebar=False, keep_labels=True, **kwargs)
def plot_colormap(self, cmap, N=256, Y='auto', figsize=None, cbar_lim=None, brightness=True,
input_rec=None):
self._log.debug('Calling plot_colormap')
dim, ext = self.dim, self.extent
# Calculate rgb values:
rgb = cmap(np.linspace(0, 1, N))[None, :, :3]
if input_rec == 709:
rgb = RGBConverter('Rec709', 'Rec601')(rgb)
rr, gg, bb = rgb.T
# Convert to YPbPr space:
k_r, k_g, k_b = 0.299, 0.587, 0.114 # Constants Rec.601!
YY = k_r * rr + k_g * gg + k_b * bb
ppb = (bb - YY) / (2 * (1 - k_b))
ppr = (rr - YY) / (2 * (1 - k_r))
ppb = (ppb - ext[0]) / (ext[1] - ext[0]) * dim[1]
ppr = (ppr - ext[2]) / (ext[3] - ext[2]) * dim[0]
# Determine number of images / luma levels:
YY_min, YY_max = np.round(np.min(YY), 2), np.round(np.max(YY), 2)
if Y == 'auto':
if YY_max - YY_min < 0.01: # Just one image:
Y = YY_min
else: # Two images:
Y = np.asarray((YY_max, np.mean(YY), YY_min))
Y_list = np.atleast_1d(Y)
# Determine colorbar limits:
if cbar_lim is not None: # Overwrite limits!
YY_min, YY_max = cbar_lim
elif not brightness or YY_max - YY_min < 0.01: # Just one value, full range for colormap:
YY_min, YY_max = 0, 1
# Creat grid:
if figsize is None:
figsize = (len(Y_list) * 5 + 2, 7)
fig = plt.figure(figsize=figsize)
grid = ImageGrid(fig, 111, nrows_ncols=(1, len(Y_list)), axes_pad=0.4, share_all=False,
cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25)
# Plot:
if brightness:
c = YY
cmap = 'gray'
else:
c = np.linspace(0, 1, N)
for i, axis in enumerate(grid):
self.plot(Y=Y_list[i], axis=axis)
im = axis.scatter(ppb, ppr, c=c, cmap=cmap, edgecolors='none',
vmin=YY_min, vmax=YY_max)
axis.set_xlim(0, self.dim[1])
axis.set_ylim(0, self.dim[0])
axis.cax.colorbar(im, ticks=np.linspace(YY_min, YY_max, 9))
def plot3d(self, N=9):
self._log.debug('Calling plot3d')
dim, ext = self.dim, self.extent
# Create YPbPr colorspace:
pb = np.linspace(ext[0], ext[1], dim[1])
pr = np.linspace(ext[2], ext[3], dim[0])
ppb, ppr = np.meshgrid(pb, pr)
import visvis
for i in range(1, N):
YY = np.full(dim, i * 1. / N) # This is luma, not relative luminance (Y', not Y)!
# Convert to RGB colorspace (this is the nonlinear R'G'B' space!):
rr = YY + 1.402 * ppr
gg = YY - 0.344136 * ppb - 0.7141136 * ppr
bb = YY + 1.772 * ppb
rgb = np.stack((rr, gg, bb), axis=-1)
# Determine gamut:
gamut = np.logical_or(rgb < 0, rgb > 1)
gamut = np.sum(gamut, axis=-1, dtype=bool)
# Alpha:
alpha = 1.
a = np.full(dim + (1,), alpha)
a *= np.logical_not(gamut[..., None])
rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8)
# Visvis plot:
obj = visvis.functions.surf(ppb, ppr, i / N * np.ones_like(ppb), rgba, aa=0)
obj.parent.light0.ambient = 1.
obj.parent.light0.diffuse = 0.
class RGBConverter(object):
"""Class for the conversion of RGB values from one RGB space to another.
Notes
-----
This operates only on NONLINEAR R'G'B' values, normalised to a range of [0, 1]!
Convert from linear RGB values beforehand, if necessary!
"""
rgb601_to_ypbpr = np.array([[+0.299000, +0.587000, +0.114000],
[-0.168736, -0.331264, +0.500000],
[+0.500000, -0.418688, -0.081312]])
ypbr_to_rgb709 = np.array([[1, +0.0000, +1.5701],
[1, -0.1870, -0.4664],
[1, +1.8556, +0.0000]])
rgb601_to_rgb709 = ypbr_to_rgb709.dot(rgb601_to_ypbpr)
rgb709_to_rgb601 = np.linalg.inv(rgb601_to_rgb709)
def __init__(self, source='Rec601', target='Rec709'):
if source == 'Rec601' and target == 'Rec709':
self.convert_matrix = self.rgb601_to_rgb709
elif source == 'Rec709' and target == 'Rec601':
self.convert_matrix = self.rgb709_to_rgb601
else:
raise KeyError('Conversion from {} to {} not found!'.format(source, target))
def __call__(self, rgb):
"""Convert from one RGB space to another.
Parameters
----------
rgb: :class:`~numpy.ndarray`
Numpy array containing the RGB source values (last dimension: 3).
Returns
-------
rgb_result: :class:`~numpy.ndarray`
The resulting RGB values in the target space.
"""
rgb_out = rgb.reshape((-1, 3)).T
rgb_out = self.convert_matrix.dot(rgb_out)
return rgb_out.T.reshape(rgb.shape)
def interpolate_color(fraction, start, end):
"""Interpolate linearly between two color tuples (e.g. RGB).
......@@ -1127,104 +497,24 @@ def interpolate_color(fraction, start, end):
return r1, r2, r3
def rgb_to_brightness(rgb, mode="Y'", input_rec=None):
import colorspacious # TODO: Use for everything!
c = {601: [0.299, 0.587, 0.114], 709: [0.2125, 0.7154, 0.0721]} # Image.convert('L') uses 601!
if input_rec is None: # Not specified, use in all cases:
rgbp601 = rgb
rgbp709 = rgb
elif input_rec == 601:
rgbp601 = rgb
rgbp709 = RGBConverter('Rec601', 'Rec709')(rgb)
elif input_rec == 709:
rgbp601 = RGBConverter('Rec601', 'Rec709')(rgb)
rgbp709 = rgb
else:
raise KeyError('Input RGB type {} not understood!'.format(input_rec))
if mode in ("Y'", 'Luma'):
rp601, gp601, bp601 = rgbp601.T
brightness = c[601][0] * rp601 + c[601][1] * gp601 + c[601][2] * bp601
elif mode in ('Y', 'Luminance'):
rgb709 = colorspacious.cspace_converter('sRGB1', 'sRGB1-linear')(rgbp709)
r709, g709, b709 = rgb709.T
brightness = c[709][0] * r709 + c[709][1] * g709 + c[709][2] * b709
elif mode in ('L*', 'LightnessLab'):
lab = colorspacious.cspace_converter('sRGB1', 'CIELab')(rgbp709)
brightness = lab[0, :, 0]
elif mode in ('I', 'Intensity', 'Average'):
brightness = np.mean(rgb, axis=-1)
elif mode in ('V', 'Value', 'Maximum'):
brightness = np.max(rgb, axis=-1)
elif mode in ('L', 'LightnessHSL'):
brightness = (np.max(rgb, axis=-1) + np.min(rgb, axis=-1)) / 2
else:
raise KeyError('Brightness request {} not understood!'.format(mode))
return brightness
def colormap_brightness_comparison(cmap, input_rec=None, figsize=(18, 8)):
# Create R'G'B' values from colormap:
x = np.linspace(0, 1, 1000)
rgbp = cmap(x)[None, :, :3]
# Calculate different brightness measures:
luma = rgb_to_brightness(rgbp, mode="Y'", input_rec=input_rec)
luminance = rgb_to_brightness(rgbp, mode='Y', input_rec=input_rec)
lightness_lab = rgb_to_brightness(rgbp, mode='L*', input_rec=input_rec)
intensity = rgb_to_brightness(rgbp, mode='I', input_rec=input_rec)
value = rgb_to_brightness(rgbp, mode='V', input_rec=input_rec)
lightness_hls = rgb_to_brightness(rgbp, mode='L', input_rec=input_rec)
# Plot:
fig, grid = plt.subplots(2, 3, figsize=figsize)
plt.title(cmap.name)
axis = grid[0, 0]
axis.scatter(x, luma, c=x, cmap=cmap, s=200, linewidths=0.)
axis.axhline(y=0.5, color='k', ls='--')
axis.set_xlim(0, 1)
axis.set_ylim(0, 1)
axis.set_title("Luma $Y$ '")
axis = grid[0, 1]
axis.scatter(x, luminance, c=x, cmap=cmap, s=200, linewidths=0.)
axis.axhline(y=0.214, color='k', ls='--')
axis.set_xlim(0, 1)
axis.set_ylim(0, 1)
axis.set_title('Relative Luminance $Y$')
axis = grid[0, 2]
axis.scatter(x, lightness_lab, c=x, cmap=cmap, s=200, linewidths=0.)
axis.axhline(y=53.39, color='k', ls='--')
axis.set_xlim(0, 1)
axis.set_ylim(0, 100)
axis.set_title('Lightness $L^*$ (CIELab)')
axis = grid[1, 0]
axis.scatter(x, intensity, c=x, cmap=cmap, s=200, linewidths=0.)
axis.axhline(y=53.39, color='k', ls='--')
axis.set_xlim(0, 1)
axis.set_ylim(0, 1)
axis.set_title('Intensity $I$ (HSI Component Average)')
axis = grid[1, 1]
axis.scatter(x, value, c=x, cmap=cmap, s=200, linewidths=0.)
axis.axhline(y=53.39, color='k', ls='--')
axis.set_xlim(0, 1)
axis.set_ylim(0, 1)
axis.set_title('Value $V$ (HSV Component Maximum)')
axis = grid[1, 2]
axis.scatter(x, lightness_hls, c=x, cmap=cmap, s=200, linewidths=0.)
axis.axhline(y=53.39, color='k', ls='--')
axis.set_xlim(0, 1)
axis.set_ylim(0, 1)
axis.set_title('Lightness $L$ (HSL Min-Max-Average)')
cmaps = {'cubehelix_standard': ColormapCubehelix(),
'cubehelix_reverse': ColormapCubehelix(reverse=True),
'cubehelix_circular': ColormapCubehelix(start=1, rot=1,
minLight=0.5, maxLight=0.5, sat=2),
'perception_circular': ColormapPerception(),
'hls_circular': ColormapHLS(),
'classic_circular': ColormapClassic(),
'transparent_black': ColormapTransparent(0, 0, 0, [0, 1.]),
'transparent_white': ColormapTransparent(1, 1, 1, [0, 1.]),
'transparent_confidence': ColormapTransparent(0.2, 0.3, 0.2, [0.75, 0.])}
CMAP_CIRCULAR_DEFAULT = cmaps['cubehelix_circular']
class CMapNamespace(object):
def __init__(self):
self.cubehelix = ColormapCubehelix()
self.cubehelix_r = ColormapCubehelix(reverse=True)
self.cyclic_cubehelix = ColormapCubehelix(start=1, rot=1, minLight=0.5, maxLight=0.5, sat=2)
self.cyclic_perception = ColormapPerception()
self.cyclic_hls = ColormapHLS()
self.cyclic_classic = ColormapClassic()
self.transparent_black = ColormapTransparent(0, 0, 0, [0, 1.])
self.transparent_white = ColormapTransparent(1, 1, 1, [0, 1.])
self.transparent_confidence = ColormapTransparent(0.2, 0.3, 0.2, [0.75, 0.])
def __getitem__(self, key):
return self.__dict__[key]
def add_cmap_dict(self, **kwargs):
self.__dict__.update(kwargs)
cmaps = CMapNamespace()
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides functions that decorate exisiting `matplotlib` plots."""
import logging
from collections.abc import Iterable
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patheffects
from matplotlib.patches import Circle
from matplotlib.offsetbox import TextArea, AnchoredOffsetbox
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from . import colors
from .tools import use_style
from ..fields.field import Field
__all__ = ['scalebar', 'colorwheel', 'annotate', 'quiverkey', 'coords', 'colorbar']
_log = logging.getLogger(__name__)
def scalebar(axis=None, unit='nm', loc='lower left', **kwargs):
"""Add a scalebar to the axis.
Parameters
----------
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the scalebar is added, by default None, which will pick the last used axis via `gca`.
unit: str, optional
String that determines the unit of the scalebar, defaults to 'nm'.
loc : str or pair of floats, optional
The location of the scalebar, defaults to 'lower left'. See `matplotlib.legend` for possible settings.
Returns
-------
aoffbox : :class:`~matplotlib.offsetbox.AnchoredOffsetbox`
The box containing the scalebar.
Notes
-----
Additional kwargs are passed to `mpl_toolkits.axes_grid1.anchored_artists.AnchoredSizeBar`.
"""
_log.debug('Calling scalebar')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
# Transform axis borders (1, 1) to data borders to get number of pixels in y and x:
transform = axis.transData
bb0 = axis.transLimits.inverted().transform((0, 0))
bb1 = axis.transLimits.inverted().transform((1, 1))
data_extent = (int(abs(bb1[1] - bb0[1])), int(abs(bb1[0] - bb0[0])))
# Calculate bar length:
bar_length = data_extent[1] / 4 # 25% of the data width!
thresholds = [1, 5, 10, 50, 100, 500, 1000]
for t in thresholds: # For larger grids (real images), multiples of threshold look better!
if bar_length > t: # Round down to the next lowest multiple of t:
bar_length = (bar_length // t) * t
# Set parameters for scale bar:
label = f'{bar_length:g} {unit}'
# Set defaults:
kwargs.setdefault('borderpad', 0.2)
kwargs.setdefault('pad', 0.2)
kwargs.setdefault('sep', 5)
kwargs.setdefault('color', 'w')
kwargs.setdefault('size_vertical', data_extent[0]*0.01)
kwargs.setdefault('frameon', False)
kwargs.setdefault('label_top', True)
kwargs.setdefault('fill_bar', True)
# Create scale bar:
scalebar = AnchoredSizeBar(transform, bar_length, label, loc, **kwargs)
scalebar.txt_label._text._color = 'w' # Overwrite AnchoredSizeBar color!
# Set stroke patheffect:
effect_txt = [patheffects.withStroke(linewidth=2, foreground='k')]
scalebar.txt_label._text.set_path_effects(effect_txt)
effect_bar = [patheffects.withStroke(linewidth=3, foreground='k')]
scalebar.size_bar._children[0].set_path_effects(effect_bar)
# Add scale bar to axis and return:
axis.add_artist(scalebar)
return scalebar
def colorwheel(axis=None, cmap=None, ax_size='20%', loc='upper right', **kwargs):
"""Add a colorwheel to the axis on the upper right corner.
Parameters
----------
axis : :class:`~matplotlib.axes.Axes`, optional
Axis to which the colorwheel is added, by default None, which will pick the last used axis via `gca`.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the colorwheel, defaults to `None`, which chooses the
`.colors.cmaps.cyclic_cubehelix` colormap. Needs to be a :class:`~.colors.Colormap3D` to work correctly.
ax_size : str or float, optional
String or float determining the size of the inset axis used, defaults to `20%`.
loc : str or pair of floats, optional
The location of the colorwheel, defaults to 'upper right'. See `matplotlib.legend` for possible settings.
Returns
-------
axis : :class:`~matplotlib.image.AxesImage`
The colorwheel image that was created.
Notes
-----
Additional kwargs are passed to :class:`~.colors.Colormap3D.plot_colorwheel` of the :class:`~.colors.Colormap3D`.
"""
_log.debug('Calling colorwheel')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
ins_axes = inset_axes(axis, width=ax_size, height=ax_size, loc=loc)
ins_axes.axis('off')
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
plt.sca(axis) # Set focus back to parent axis!
return cmap.plot_colorwheel(axis=ins_axes, **kwargs)
def annotate(label, axis=None, loc='upper left'):
"""Add an annotation to the axis on the upper left corner.
Parameters
----------
label : string
The text of the annotation.
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the annotation is added, by default None, which will pick the last used axis via `gca`.
loc : str or pair of floats, optional
The location of the annotation, defaults to 'upper left'. See `matplotlib.legend` for possible settings.
Returns
-------
aoffbox : :class:`~matplotlib.offsetbox.AnchoredOffsetbox`
The box containing the annotation.
"""
_log.debug('Calling annotate')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
# Create text:
txt = TextArea(label, textprops={'color': 'w'})
txt.set_clip_box(axis.bbox)
txt._text.set_path_effects([patheffects.withStroke(linewidth=2, foreground='k')])
# Pack into and add AnchoredOffsetBox:
aoffbox = AnchoredOffsetbox(loc=loc, pad=0.5, borderpad=0.1, child=txt, frameon=False)
axis.add_artist(aoffbox)
return aoffbox
def quiverkey(quiv, field, axis=None, unit='', loc='lower right', **kwargs):
"""Add a quiver key to an axis.
Parameters
----------
quiv : Quiver instance
The quiver instance returned by a call to quiver.
field : `Field` or ndarray
The vector data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the quiverkey is added, by default None, which will pick the last used axis via `gca`.
unit: str, optional
String that determines the unit of the quiverkey, defaults to ''.
loc : str or pair of floats, optional
The location of the quiverkey, defaults to 'lower right'. See `matplotlib.legend` for possible settings.
Returns
-------
qk: Quiverkey
The generated quiverkey.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.quiverkey`.
"""
_log.debug('Calling quiverkey')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=True)
length = field.amp.data.max()
shift = 1 / field.squeeze().dim[1] # equivalent to one pixel distance in axis coords!
label = f'{length:.3g} {unit}'
if loc in ('upper right', 1):
X, Y, labelpos = 0.95-shift, 0.95-shift/4, 'W'
elif loc in ('upper left', 2):
X, Y, labelpos = 0.05+shift, 0.95-shift/4, 'E'
elif loc in ('lower left', 3):
X, Y, labelpos = 0.05+shift, 0.05+shift/4, 'E'
elif loc in ('lower right', 4):
X, Y, labelpos = 0.95-shift, 0.05+shift/4, 'W'
else:
raise ValueError('Quiverkey can only be placed in one of the corners (number 1 - 4 or associated strings)!')
# Set defaults:
kwargs.setdefault('coordinates', 'axes')
kwargs.setdefault('facecolor', 'w')
kwargs.setdefault('edgecolor', 'k')
kwargs.setdefault('labelcolor', 'w')
kwargs.setdefault('linewidth', 1)
kwargs.setdefault('clip_box', axis.bbox)
kwargs.setdefault('clip_on', True)
# Plot:
qk = axis.quiverkey(quiv, X, Y, U=1, label=label, labelpos=labelpos, **kwargs)
qk.text.set_path_effects([patheffects.withStroke(linewidth=2, foreground='k')])
return qk
def coords(axis=None, coords=('x', 'y'), loc='lower left', **kwargs):
"""Add coordinate arrows to an axis.
Parameters
----------
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the coordinates are added, by default None, which will pick the last used axis via `gca`.
coords : tuple or int, optional
Tuple of strings determining the labels, by default ('x', 'y'). Can also be `2` or `3` which expands to
('x', 'y') or ('x', 'y', 'z'). The length of `coords` determines the number of arrows (2 or 3).
loc : str, optional
[description], by default 'lower left'
Returns
-------
ins_axes : :class:`~matplotlib.axes.Axes`
The created inset axes containing the coordinates.
"""
_log.debug('Calling coords')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
ins_ax = inset_axes(axis, width='5%', height='5%', loc=loc, borderpad=2.2)
if coords == 3:
coords = ('x', 'y', 'z')
elif coords == 2:
coords = ('x', 'y')
effects = [patheffects.withStroke(linewidth=2, foreground='k')]
kwargs.setdefault('fc', 'w')
kwargs.setdefault('ec', 'k')
kwargs.setdefault('head_width', 0.6)
kwargs.setdefault('head_length', 0.7)
kwargs.setdefault('linewidth', 1)
kwargs.setdefault('width', 0.2)
if len(coords) == 3:
ins_ax.arrow(x=0.5, y=0.5, dx=-1.05, dy=-0.75, clip_on=False, **kwargs)
ins_ax.arrow(x=0.5, y=0.5, dx=0.96, dy=-0.75, clip_on=False, **kwargs)
ins_ax.arrow(x=0.5, y=0.5, dx=0, dy=1.35, clip_on=False, **kwargs)
ins_ax.annotate(coords[0], xy=(0, 0), xytext=(-1.0, 0.3), path_effects=effects, color='w')
ins_ax.annotate(coords[1], xy=(0, 0), xytext=(1.7, 0.3), path_effects=effects, color='w')
ins_ax.annotate(coords[2], xy=(0, 0), xytext=(0.8, 1.5), path_effects=effects, color='w')
ins_ax.add_artist(Circle((0.5, 0.5), 0.12, fc='w', ec='k', linewidth=1, clip_on=False))
elif len(coords) == 2:
ins_ax.arrow(x=-0.5, y=-0.5, dx=1.5, dy=0, clip_on=False, **kwargs)
ins_ax.arrow(x=-0.5, y=-0.5, dx=0, dy=1.5, clip_on=False, **kwargs)
ins_ax.annotate(coords[0], xy=(0, 0), xytext=(1.3, -0.05), path_effects=effects, color='w')
ins_ax.annotate(coords[1], xy=(0, 0), xytext=(-0.1, 1.1), path_effects=effects, color='w')
ins_ax.add_artist(Circle((-0.5, -0.5), 0.12, fc='w', ec='k', linewidth=1, clip_on=False))
ins_ax.axis('off')
plt.sca(axis)
return coords
def colorbar(im, fig=None, cbar_axis=None, axes=None, position='right', pad=0.02, thickness=0.03, label=None,
constrain_ticklabels=True, ticks=None, ticklabels=None):
"""Creates a colorbar, aligned with figure axes.
Parameters
----------
im : matplotlib object, mappable
Mappable matplotlib object.
fig : matplotlib.figure object, optional
The figure object that contains the matplotlib axes and artists, by default None, which will pick the last used
figure via `gcf`.
axes : matplotlib.axes or list of matplotlib.axes
The axes object(s), where the colorbar is drawn, by default None, which will pick the last used axis via `gca`.
Only provide those axes, which the colorbar should span over.
position : str, optional
The position defines the location of the colorbar. One of 'top', 'bottom', 'left' or 'right' (default).
pad : float, optional
Defines the spacing between the axes and colorbar axis. Is given in figure fraction.
thickness : float, optional
Thickness of the colorbar given in figure fraction.
label : string, optional
Colorbar label, defaults to None.
constrain_ticklabels : bool, optional
Allows to slightly shift the outermost ticklabels, such that they do not exceed the cbar axis, defaults to True.
ticks : list, np.ndarray, optional
List of cbar ticks, defaults to None.
ticklabels : list, np.ndarray, optional
List of cbar ticklabels, defaults to None.
Returns
-------
cbar : :class:`~matplotlib.Colorbar`
The created colorbar.
Notes
-----
Based on a modified snippet by Florian Winkler. Note that this function TURNS OFF constrained layout, therefore it
should be the final command before finishing or saving a figure. The colorbar will be outside the original bounds
of your constructed figure. If you set the size, e.g. with `~empyre.vis.tools.new`, make sure to account for the
additional space by setting the `width_scale` to something smaller than 1 (e.g. 0.9).
"""
_log.debug('Calling colorbar')
assert position in ('left', 'right', 'top', 'bottom'), "position has to be 'left', 'right', 'top' or 'bottom'!"
if fig is None: # If no figure is set, find the current or create a new one:
fig = plt.gcf()
fig.canvas.draw() # Trigger a draw so that a potential constrained_layout is executed once!
fig.set_constrained_layout(False) # we don't want the layout to change after this point!
if axes is None: # If no axis is set, find the current or create a new one:
axes = plt.gca()
if not isinstance(axes, Iterable):
axes = (axes,) # Make sure it is an iterable (e.g. tuple)!
# Save previously active axis for later:
previous_axis = plt.gca()
if cbar_axis is None: # Construct a new cbar_axis:
x_coords, y_coords = [], []
# Find bounds of all individual axes:
for ax in np.ravel(axes): # ravel needed for arrays of axes:
points = ax.get_position().get_points()
x_coords.extend(points[:, 0])
y_coords.extend(points[:, 1])
# Find outer bounds of plotting area:
left, right = min(x_coords), max(x_coords)
bottom, top = min(y_coords), max(y_coords)
# Determine where the colorbar will be placed:
if position == 'right':
bounds = [right+pad, bottom, thickness, top-bottom]
elif position == 'left':
bounds = [left-pad-thickness, bottom, thickness, top-bottom]
if position == 'top':
bounds = [left, top+pad, right-left, thickness]
elif position == 'bottom':
bounds = [left, bottom-pad-thickness, right-left, thickness]
cbar_axis = fig.add_axes(bounds)
# Create the colorbar:
with use_style('empyre-image'):
if position in ('left', 'right'):
cb = plt.colorbar(im, cax=cbar_axis, orientation='vertical')
cb.ax.yaxis.set_ticks_position(position)
cb.ax.yaxis.set_label_position(position)
elif position in ('top', 'bottom'):
cb = plt.colorbar(im, cax=cbar_axis, orientation='horizontal')
cb.ax.xaxis.set_ticks_position(position)
cb.ax.xaxis.set_label_position(position)
# Colorbar label
if label is not None:
cb.set_label(f'{label}')
# Set ticks and ticklabels (if specified):
if ticks:
cb.set_ticks(ticks)
if ticklabels:
cb.set_ticklabels(ticklabels)
# Constrain tick labels (if wanted):
if constrain_ticklabels:
if position == 'top' or position == 'bottom':
t = cb.ax.get_xticklabels()
t[0].set_horizontalalignment('left')
t[-1].set_horizontalalignment('right')
elif position == 'left' or position == 'right':
t = cb.ax.get_yticklabels()
t[0].set_verticalalignment('bottom')
t[-1].set_verticalalignment('top')
# Set focus back from colorbar to previous axis and return colorbar:
plt.sca(previous_axis)
return cb
### MATPLOTLIB STYLESHEET FOR EMPYRE IMAGES
font.family : serif ## default font family (use serifs)
font.serif : cm ## Computer Modern (LaTeX font)
xtick.top : False ## draw ticks on the top side
xtick.bottom : False ## draw ticks on the bottom side
xtick.labeltop : False ## draw label on the top
xtick.labelbottom : False ## draw label on the bottom
ytick.left : False ## draw ticks on the left side
ytick.right : False ## draw ticks on the right side
ytick.labelleft : False ## draw tick labels on the left side
ytick.labelright : False ## draw tick labels on the right side
figure.figsize : 3.6, 3.6 ## figure size in inches
figure.dpi : 200 ## figure dots per inch
figure.constrained_layout.use : True ## use constrained layout
image.origin : lower ## lower | upper
### MATPLOTLIB STYLESHEET FOR EMPYRE PLOTS
font.family : serif ## default font family (use serifs)
font.serif : cm ## Computer Modern (LaTeX font)
figure.figsize : 3.6, 2.2 ## figure size in inches
figure.dpi : 200 ## figure dots per inch
figure.constrained_layout.use : True ## use constrained layout
### MATPLOTLIB STYLESHEET FOR SAVING EMPYRE IMAGES AND PLOTS
font.family : serif ## default font family (use serifs)
font.serif : cm ## Computer Modern (LaTeX font)
savefig.dpi : 200 ## figure dots per inch or 'figure'
savefig.facecolor : white ## figure facecolor when saving
savefig.edgecolor : white ## figure edgecolor when saving
savefig.format : pdf ## png, ps, pdf, svg
savefig.bbox : tight ## 'tight' or 'standard'.
savefig.pad_inches : 0.01 ## Padding to be used when bbox is set to 'tight'
savefig.jpeg_quality : 95 ## when a jpeg is saved, the default quality parameter.
savefig.directory : ~ ## default directory in savefig dialog box, leave empty to always use cwd
savefig.transparent : False ## controls whether figures are saved with transp. background by default
savefig.orientation : portrait ## Orientation of saved figure
# -*- coding: utf-8 -*-
# Copyright 2019 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides functions for 2D plots that often wrap functions from `maptlotlib.pyplot`."""
import logging
import warnings
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from PIL import Image
from . import colors
from .tools import use_style
from ..fields.field import Field
__all__ = ['imshow', 'contour', 'colorvec', 'cosine_contours', 'quiver']
_log = logging.getLogger(__name__)
DIVERGING_CMAPS = ['PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn',
'Spectral', 'coolwarm', 'bwr', 'seismic', # all divergent maps from matplotlib!
'balance', 'delta', 'curl', 'diff', 'tarn'] # all divergent maps from cmocean!
# TODO: add seaborn and more?
def imshow(field, axis=None, cmap=None, **kwargs):
"""Display an image on a 2D regular raster. Wrapper for `matplotlib.pyplot.imshow`.
Parameters
----------
field : `Field` or ndarray
The image data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last use axis via `gca`.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the display, either as a string or object, by default None, which will pick
`cmocean.cm.balance` if available. `imshow` will automatically detect if a divergent colormap is used and will
make sure that zero is pinned to the symmetry point of the colormap (this is done by creating a new colormap
with custom range under the hood).
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to :meth:`~matplotlib.pyplot.imshow`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
"""
_log.debug('Calling imshow')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!'
# Determine colormap and related important properties and flags:
if cmap is None:
try:
import cmocean
cmap = cmocean.cm.balance
except ImportError:
_log.info('cmocean.balance not found, fallback to rRdBu!')
cmap = plt.get_cmap('RdBu_r') # '_r' for reverse!
elif isinstance(cmap, str): # make sure we have a Colormap object (and not a string):
cmap = plt.get_cmap(cmap)
if cmap.name.replace('_r', '') in DIVERGING_CMAPS: # 'replace' also matches reverted cmaps!
kwargs.setdefault('norm', TwoSlopeNorm(0)) # Diverging colormap should have zero at the symmetry point!
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely):
dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
return axis.imshow(squeezed_field, cmap=cmap, **kwargs)
def contour(field, axis=None, **kwargs):
"""Plot contours. Wrapper for `matplotlib.pyplot.contour`.
Parameters
----------
field : `Field` or ndarray
The contour data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the contour should be added, by default None, which will pick the last use axis via `gca`.
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.contour`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
"""
_log.debug('Calling contour')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!'
# Create coordinates (respecting the field scale, +0.5: pixel center!):
vv, uu = (np.indices(squeezed_field.dim) + 0.5) * np.asarray(squeezed_field.scale)[:, None, None]
# Set kwargs defaults without overriding possible user input:
kwargs.setdefault('levels', [0.5])
kwargs.setdefault('colors', 'k')
kwargs.setdefault('linestyles', 'dotted')
kwargs.setdefault('linewidths', 2)
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
axis.set_aspect('equal')
return axis.contour(uu, vv, squeezed_field.data, **kwargs)
def colorvec(field, axis=None, **kwargs):
"""Plot an image of a 2D vector field with up to 3 components by color encoding the vector direction.
In-plane directions are encoded via hue ("color wheel"), making sure that all in-plane directions are isoluminant
(i.e. a greyscale image would result a homogeneously medium grey image). Out-of-plane directions are encoded via
brightness with upwards pointing vectors being white and downward pointing vectors being black. The length of the
vectors are encoded via saturation, with full saturation being fully chromatic (in-plane) or fully white/black
(up/down). The center of the "color sphere" desaturated in a medium gray and encodes vectors with length zero.
Parameters
----------
field : `Field` or ndarray
The image data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last use axis via `gca`.
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.imshow`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
Even though squeezing takes place, `colorvec` "remembers" the original orientation of the slice! This is important
if you want to plot a slice that should not represent the xy-plane. The colors chosen will respect the original
orientation of your slice, e.g. a vortex in the xz-plane will include black and white colors (up/down) if the
`Field` object given as the `field` parameter has `dim=(128, 1, 128)`. If you want to plot a slice of a 3D vector
with 3 components and make use of this functionality, make sure to not use an integer as an index, as that will
drop the dimension BEFORE it is passed to `colorvec`, which will have no way of knowing which dimension was dropped.
Instead, make sure to use a slice of length one (example with `dim=(128, 128, 128)`):
>>> colorvec(field[:, 15, :]) # Wrong! Shape: (128, 128), interpreted as xy-plane!
>>> colorvec(field[:, 15:16, :]) # Right! Shape: (128, 1, 128), passed as 3D to `colorvec`, squeezed internally!
"""
_log.debug('Calling colorvec')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=True)
assert field.vector, 'Can only plot vector fields!'
assert len(field.dim) <= 3, 'Unusable for vector fields with dimension higher than 3!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!'
# Extract vector components (fill 3rd component with zeros if field.comp is only 2):
comp = squeezed_field.comp
x_comp = comp[0]
y_comp = comp[1]
z_comp = comp[2] if (squeezed_field.ncomp == 3) else np.zeros(squeezed_field.dim)
# Calculate image with color encoded directions:
cmap = kwargs.pop('cmap', None)
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(np.stack((x_comp, y_comp, z_comp), axis=0))
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely):
dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
return axis.imshow(Image.fromarray(rgb), **kwargs)
def cosine_contours(field, axis=None, gain='auto', cmap=None, **kwargs):
"""Plots the cosine of the (amplified) field. Wrapper for `matplotlib.pyplot.imshow`.
Parameters
----------
field : `Field` or ndarray
The contour data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the contour should be added, by default None, which will pick the last use axis via `gca`.
gain : float or 'auto', optional
Gain factor with which the `Field` is amplified before taking the cosine, by default 'auto', which calculates a
gain factor that would produce roughly 4 cosine contours.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the display, either as a string or object, by default None, which will pick
`colors.cmaps['transparent_black']` that will alternate between regions with alpha=0, showing layers below and
black contours.
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.imshow`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
"""
_log.debug('Calling cosine_contours')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions (Squeezing did not help)!'
# Determine colormap and related important properties and flags:
if cmap is None:
cmap = colors.cmaps['transparent_black']
# Calculate gain if 'auto' is selected:
if gain == 'auto':
gain = 4 * 2*np.pi / (squeezed_field.amp.data.max() + 1E-30) # 4: roughly 4 contours!
gain = round(gain, -int(np.floor(np.log10(abs(gain))))) # Round to last significant digit!
_log.info(f'Automatically calculated a gain of: {gain}')
# Calculate the contours:
contours = np.cos(gain * squeezed_field) # Range: [-1, 1]
contours += 1 # Shift to positive values
contours /= 2 # Rescale to [0, 1]
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely):
dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
return axis.imshow(contours, cmap=cmap, **kwargs)
def quiver(field, axis=None, color_angles=False, cmap=None, n_bin='auto', bin_with_mask=True, **kwargs):
"""Plot a 2D field of arrows. Wrapper for `matplotlib.pyplot.imshow`.
Parameters
----------
field : `Field` or ndarray
The vector data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last use axis via `gca`.
color_angles : bool, optional
Switch that turns on color encoding of the arrows, by default False. Encoding works the same as for the
`colorvec` function (see for details). If False, arrows are uniformly colored white with black border. In both
cases, the amplitude is encoded via the transparency of the arrow.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the arrows, either as a string or object, by default None. Will only be
used if `color_angles=True`.
n_bin : float or 'auto', optional
Number of entries along each axis over which the average is taken, by default 'auto', which automatically
determines a bin size resulting in roughly 16 arrows along the largest dimension. Usually sensible to leave
this on to not clutter the image with too many arrows (also due to performance). Can be turned off by setting
`n_bin=1`. Uses the `..fields.field.Field.bin` method.
bin_with_mask : bool, optional
If True (default) and if `n_bin>1`, entries of the constructed binned `Field` that averaged over regions that
were outside the `..fields.field.Field.mask` will not be assigned an arrow and stay empty instead. This prevents
errouneous "fade-out" effects of the arrows that would occur even for homogeneous objects.
Returns
-------
quiv : Quiver instance
The quiver instance that was created.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.quiver`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
Even though squeezing takes place, `quiver` "remembers" the original orientation of the slice and which dimensions
were squeezed! See `colorvec` for more information and an example (the same principles apply here, too).
The transparency of the arrows denotes the 3D(!) amplitude, if you see dots in the plot, that means the amplitude
is not zero, but simply out of the current plane!
"""
_log.debug('Calling quiver')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=True)
assert field.vector, 'Can only plot vector fields!'
assert len(field.dim) <= 3, 'Unusable for vector fields with dimension higher than 3!'
if len(field.dim) < field.ncomp:
warnings.warn('Assignment of vector components to dimensions is ambiguous!'
f'`ncomp` ({field.ncomp}) should match `len(dim)` ({len(field.dim)})!'
'If you want to plot a slice of a 3D volume, make sure to use `from:to` notation!')
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions (Squeezing did not help)!'
# Determine binning size if necessary:
if n_bin == 'auto':
n_bin = int(np.max((1, np.max(squeezed_field.dim) / 16)))
# Save old limits in case binning has to use padding:
u_lim = squeezed_field.dim[1] * squeezed_field.scale[1]
v_lim = squeezed_field.dim[0] * squeezed_field.scale[0]
# Bin if necessary:
if n_bin > 1:
field_mask = squeezed_field.mask # Get mask BEFORE binning!
squeezed_field = squeezed_field.bin(n_bin)
if bin_with_mask: # Excludes regions where in and outside are binned together!
mask = (field_mask.bin(n_bin) == 1)
squeezed_field *= mask
# Extract normalized vector components (fill 3rd component with zeros if field.comp is only 2):
normalised_comp = (squeezed_field / squeezed_field.amp.data.max()).comp
amplitude = squeezed_field.amp.data / squeezed_field.amp.data.max()
x_comp = normalised_comp[0].data
y_comp = normalised_comp[1].data
z_comp = normalised_comp[2].data if (field.ncomp == 3) else np.zeros(squeezed_field.dim)
# Create coordinates (respecting the field scale, +0.5: pixel center!):
vv, uu = (np.indices(squeezed_field.dim) + 0.5) * np.asarray(squeezed_field.scale)[:, None, None]
# Calculate the arrow colors:
if color_angles: # Color angles according to calculated RGB values (only with circular colormaps):
_log.debug('Encoding angles')
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(np.asarray((x_comp, y_comp, z_comp))) / 255
rgba = np.concatenate((rgb, amplitude[..., None]), axis=-1)
kwargs.setdefault('color', rgba.reshape(-1, 4))
else: # Color amplitude with numeric values, according to cmap, overrides 'color':
_log.debug('Encoding amplitudes')
if cmap is None:
cmap = colors.cmaps['transparent_white']
C = amplitude # Numeric values, used with cmap!
# Check which (if any) indices were squeezed to find out which components are passed to quiver: # TODO: TEST!!!
squeezed_indices = np.flatnonzero(np.asarray(field.dim) == 1)
if not squeezed_indices: # Separate check, because in this case squeezed_indices == []:
u_comp = x_comp
v_comp = y_comp
elif squeezed_indices[0] == 0: # Slice of the xy-plane with z squeezed:
u_comp = x_comp
v_comp = y_comp
elif squeezed_indices[0] == 1: # Slice of the xz-plane with y squeezed:
u_comp = x_comp
v_comp = z_comp
elif squeezed_indices[0] == 2: # Slice of the zy-plane with x squeezed:
u_comp = y_comp
v_comp = z_comp
# Set specific defaults for quiver kwargs:
kwargs.setdefault('edgecolor', colors.cmaps['transparent_black'](amplitude).reshape(-1, 4))
kwargs.setdefault('scale', 1/np.max(squeezed_field.scale))
kwargs.setdefault('width', np.max(squeezed_field.scale))
kwargs.setdefault('clim', (0, 1))
kwargs.setdefault('pivot', 'middle')
kwargs.setdefault('units', 'xy')
kwargs.setdefault('scale_units', 'xy')
kwargs.setdefault('minlength', 0.05)
kwargs.setdefault('headlength', 2)
kwargs.setdefault('headaxislength', 2)
kwargs.setdefault('headwidth', 2)
kwargs.setdefault('minshaft', 2)
kwargs.setdefault('linewidths', 1)
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
axis.set_xlim(0, u_lim)
axis.set_ylim(0, v_lim)
axis.set_aspect('equal')
if color_angles:
return axis.quiver(uu, vv, np.asarray(u_comp), np.asarray(v_comp), cmap=cmap, **kwargs)
else:
return axis.quiver(uu, vv, np.asarray(u_comp), np.asarray(v_comp), C, cmap=cmap, **kwargs)
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides functions for 3D plots based on the `mayavi` library."""
import logging
import numpy as np
from . import colors
__all__ = ['contour3d', 'mask3d', 'quiver3d']
_log = logging.getLogger(__name__)
# TODO: Docstrings and signature!
def contour3d(field, title='Field Distribution', contours=10, opacity=0.25, size=None, new_fig=True, **kwargs):
"""Plot a field as a 3D-contour plot.
Parameters
----------
title: string, optional
The title for the plot.
contours: int, optional
Number of contours which should be plotted.
opacity: float, optional
Defines the opacity of the contours. Default is 0.25.
Returns
-------
plot : :class:`mayavi.modules.vectors.Vectors`
The plot object.
"""
_log.debug('Calling contour3d')
try:
from mayavi import mlab
except ImportError:
_log.error('This extension recquires the mayavi package!')
return
if size is None:
size = (750, 700)
if new_fig:
mlab.figure(size=size, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.))
zzz, yyy, xxx = np.indices(field.dim) + np.reshape(field.scale, (3, 1, 1, 1)) / 2 # shifted by half of scale!
zzz, yyy, xxx = zzz.T, yyy.T, xxx.T # Transpose because of VTK order!
field_amp = field.amp.data.T # Transpose because of VTK order!
if not isinstance(contours, (list, tuple, np.ndarray)): # Calculate the contours:
contours = list(np.linspace(field_amp.min(), field_amp.max(), contours))
extent = np.ravel(list(zip((0, 0, 0), field_amp.shape)))
cont = mlab.contour3d(xxx, yyy, zzz, field_amp, contours=contours, opacity=opacity, **kwargs)
mlab.outline(cont, extent=extent)
mlab.axes(cont, extent=extent)
mlab.title(title, height=0.95, size=0.35)
mlab.orientation_axes()
cont.scene.isometric_view()
return cont
def mask3d(field, title='Mask', threshold=0, grid=True, labels=True,
orientation=True, size=None, new_fig=True, **kwargs):
"""Plot the mask as a 3D-contour plot.
Parameters
----------
title: string, optional
The title for the plot.
threshold : float, optional
A pixel only gets masked, if it lies above this threshold . The default is 0.
Returns
-------
plot : :class:`mayavi.modules.vectors.Vectors`
The plot object.
"""
_log.debug('Calling mask3d')
try:
from mayavi import mlab
except ImportError:
_log.error('This extension recquires the mayavi package!')
return
if size is None:
size = (750, 700)
if new_fig:
mlab.figure(size=size, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.))
zzz, yyy, xxx = np.indices(field.dim) + np.reshape(field.scale, (3, 1, 1, 1)) / 2 # shifted by half of scale!
zzz, yyy, xxx = zzz.T, yyy.T, xxx.T # Transpose because of VTK order!
mask = field.mask.data.T.astype(int) # Transpose because of VTK order!
extent = np.ravel(list(zip((0, 0, 0), mask.shape)))
cont = mlab.contour3d(xxx, yyy, zzz, mask, contours=[1], **kwargs)
if grid:
mlab.outline(cont, extent=extent)
if labels:
mlab.axes(cont, extent=extent)
mlab.title(title, height=0.95, size=0.35)
if orientation:
oa = mlab.orientation_axes()
oa.marker.set_viewport(0, 0, 0.4, 0.4)
mlab.draw()
engine = mlab.get_engine()
scene = engine.scenes[0]
scene.scene.isometric_view()
return cont
def quiver3d(field, title='Vector Field', limit=None, cmap=None, mode='2darrow',
coloring='angle', ar_dens=1, opacity=1.0, grid=True, labels=True,
orientation=True, size=(700, 750), new_fig=True, view='isometric',
position=None, bgcolor=(0.5, 0.5, 0.5)):
"""Plot the vector field as 3D-vectors in a quiverplot.
Parameters
----------
title : string, optional
The title for the plot.
limit : float, optional
Plotlimit for the vector field arrow length used to scale the colormap.
cmap : string, optional
String describing the colormap which is used for color encoding (uses `~.colors.cmaps.cyclic_cubehelix` if
left on the `None` default) or amplitude encoding (uses 'jet' if left on the `None` default).
ar_dens: int, optional
Number defining the arrow density which is plotted. A higher ar_dens number skips more
arrows (a number of 2 plots every second arrow). Default is 1.
mode: string, optional
Mode, determining the glyphs used in the 3D plot. Default is '2darrow', which
corresponds to 2D arrows. For smaller amounts of arrows, 'arrow' (3D) is prettier.
coloring : {'angle', 'amplitude'}, optional
Color coding mode of the arrows. Use 'angle' (default) or 'amplitude'.
opacity: float, optional
Defines the opacity of the arrows. Default is 1.0 (completely opaque).
Returns
-------
plot : :class:`mayavi.modules.vectors.Vectors`
The plot object.
"""
_log.debug('Calling quiver_plot3D')
try:
from mayavi import mlab
except ImportError:
_log.error('This extension recquires the mayavi package!')
return
if limit is None:
limit = np.max(np.nan_to_num(field.amp))
ad = ar_dens
# Create points and vector components as lists:
zzz, yyy, xxx = (np.indices(field.dim) + 1 / 2)
zzz = zzz[::ad, ::ad, ::ad].ravel()
yyy = yyy[::ad, ::ad, ::ad].ravel()
xxx = xxx[::ad, ::ad, ::ad].ravel()
x_mag = field.data[::ad, ::ad, ::ad, 0].ravel()
y_mag = field.data[::ad, ::ad, ::ad, 1].ravel()
z_mag = field.data[::ad, ::ad, ::ad, 2].ravel()
# Plot them as vectors:
if new_fig:
mlab.figure(size=size, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.))
if coloring == 'angle': # Encodes the full angle via colorwheel and saturation:
_log.debug('Encoding full 3D angles')
vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag, mode=mode, opacity=opacity,
scalars=np.arange(len(xxx)), line_width=2)
vector = np.asarray((x_mag.ravel(), y_mag.ravel(), z_mag.ravel()))
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(vector)
rgba = np.hstack((rgb, 255 * np.ones((len(xxx), 1), dtype=np.uint8)))
vecs.glyph.color_mode = 'color_by_scalar'
vecs.module_manager.scalar_lut_manager.lut.table = rgba
mlab.draw()
elif coloring == 'amplitude': # Encodes the amplitude of the arrows with the jet colormap:
_log.debug('Encoding amplitude')
if cmap is None:
cmap = 'jet'
vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag,
mode=mode, colormap=cmap, opacity=opacity, line_width=2)
mlab.colorbar(label_fmt='%.2f')
mlab.colorbar(orientation='vertical')
else:
raise AttributeError('Coloring mode not supported!')
vecs.glyph.glyph_source.glyph_position = 'center'
vecs.module_manager.vector_lut_manager.data_range = np.array([0, limit])
extent = np.ravel(list(zip((0, 0, 0), (field.dim[2], field.dim[1], field.dim[0]))))
if grid:
mlab.outline(vecs, extent=extent)
if labels:
mlab.axes(vecs, extent=extent)
mlab.title(title, height=0.95, size=0.35)
if orientation:
oa = mlab.orientation_axes()
oa.marker.set_viewport(0, 0, 0.4, 0.4)
mlab.draw()
engine = mlab.get_engine()
scene = engine.scenes[0]
if view == 'isometric':
scene.scene.isometric_view()
elif view == 'x_plus_view':
scene.scene.x_plus_view()
elif view == 'y_plus_view':
scene.scene.y_plus_view()
if position:
scene.scene.camera.position = position
return vecs
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides helper functions to the vis module."""
import os
import glob
import shutil
import logging
from numbers import Number
from contextlib import contextmanager
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from ..fields.field import Field
__all__ = ['new', 'savefig', 'calc_figsize', 'use_style', 'copy_mpl_stylesheets']
_log = logging.getLogger(__name__)
def new(nrows=1, ncols=1, mode='image', figsize=None, textwidth=None, width_scale=1.0, aspect=None, **kwargs):
R"""Convenience function for the creation of a new subplot grid (wraps `~matplotlib.pyplot.subplots`).
If you use the `textwidth` parameter, plot sizes are fitting into publications with LaTeX. Requires two stylesheets
`empyre-image` and `empyre-plot` corresponding to its two `mode` settings. Those stylesheets use
`constrained_layout=True` to achieve well behaving plots without much whitespace around. This function should work
fine for a small number of images (e.g. 1, 2x2, etc.), for more fine grained control, the contexts can be used
directly if they are installed corretly, or use width_scale to build the images separately (e.g. 2 adjacent with
width=0.5). For images, it is assumed that most images are square (and therefore `aspect=1`).
Parameters
----------
nrows : int, optional
Number of rows of the subplot grid, by default 1
ncols : int, optional
Number of columns of the subplot grid, by default 1
mode : {'image', 'plot'}, optional
Mode of the new subplot grid, by default 'image'. Both modes have dedicated matplotlib styles which are used
and which are installed together with EMPyRe. The 'image' mode disables axis labels and ticks, mainly intended
to be used with `~matplotlib.pyplot.imshow` with `~empyre.vis.decorators.scalebar`, while the 'plot'
mode should be used for traditional plots like with `~matplotlib.pyplot.plot` or `~matplotlib.pyplot.scatter`.
figsize : (float, float), optional
Width and height of the figure in inches, defaults to rcParams["figure.figsize"], which depends on the chosen
stylesheet. If set, this will overwrite all other following parameters.
textwidth : float, optional
The textwidth of your LaTeX document in points, which you can get by using :math:`\the\textwidth`. If this is
not None (the default), this will be used to define the figure size if it is not set explicitely.
width_scale : float, optional
Only meaningful if `textwidth` is set. If it is, `width_scale` will be a scaling factor for the figure width.
Example: if you set this to 0.5, your figure will span half of the textwidth. Default is 1.
aspect : float, optional
Aspect ratio of the figure height relative to the figure width. If None (default), the aspect is set to be 1
for `mode=image` and to 'golden' for `mode=plot`, which adjusts the aspect to represent the golden ratio of
0.6180... If `ncols!=nrows`, it often makes sense to use `aspect=nrows/ncols` here.
Returns
-------
fig : :class:`~matplotlib.figure.Figure`
The constructed figure.
axes : axes.Axes object or array of Axes objects.
axes can be either a single Axes object or an array of Axes objects if more than one subplot was created.
The dimensions of the resulting array can be controlled with the squeeze keyword argument.
Notes
-----
additional kwargs are passed to `~matplotlib.pyplot.subplots`.
"""
_log.debug('Calling new')
assert mode in ('image', 'plot'), "mode has to be 'image', or 'plot'!"
with use_style(f'empyre-{mode}'):
if figsize is None:
if aspect is None:
aspect = 'golden' if mode == 'plot' else 1 # Both image modes have 'same' as default'!
elif isinstance(aspect, Field):
dim_uv = [d for d in aspect.dim if d != 1]
assert len(dim_uv) == 2, f"Couldn't find field aspect ({len(dim_uv)} squeezed dimensions, has to be 2)!"
aspect = dim_uv[0]/dim_uv[1] # height/width
else:
assert isinstance(aspect, Number), 'aspect has to be None, a number or field instance squeezable to 2D!'
figsize = calc_figsize(textwidth=textwidth, width_scale=width_scale, aspect=aspect)
return plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, **kwargs)
def savefig(fname, **kwargs):
"""Utility wrapper around :func:`~matplotlib.pyplot.savefig` to save the current figure.
Parameters
----------
fname : str or PathLike or file-like object
Path to the file wherein the figure should be saved.
Notes
-----
Uses the 'empyre-save' stylesheet (installed together with EMPyRe to control the saving behaviour. Any kwargs are
passed to :func:`~matplotlib.pyplot.savefig`.
"""
_log.debug('Calling savefig')
with use_style('empyre-save'):
plt.savefig(fname, **kwargs)
def calc_figsize(textwidth=None, width_scale=1.0, aspect=1):
R"""Helper function to calculate the figure size from various parameters. Useful for publications via LaTeX.
Parameters
----------
textwidth : float, optional
The textwidth of your LaTeX document in points, which you can get by using :math:`\the\textwidth`. If this is
None (default), the standard width in inches from the current stylesheet is used.
width_scale : float, optional
Scaling factor for the figure width. Example: if you set this to 0.5, your figure will span half of the
textwidth. Default is 1.
aspect : float, optional
Aspect ratio of the figure height relative to the figure width. If None (default), the aspect is set to be 1
for `mode=image` and to 'golden' for `mode=plot`, which adjusts the aspect to represent the golden ratio of
0.6180...
Returns
-------
figsize: (float, float)
The determined figure size
Notes
-----
Based on snippet by Florian Winkler.
"""
_log.debug('Calling calc_figsize')
GOLDEN_RATIO = (1 + np.sqrt(5)) / 2 # Aesthetic ratio!
INCHES_PER_POINT = 1.0 / 72.27 # Convert points to inch, LaTeX constant, apparently...
if textwidth is not None:
textwidth_in = textwidth * INCHES_PER_POINT # Width of the text in inches
else: # If textwidth is not given, use the default from rcParams:
textwidth_in = mpl.rcParams["figure.figsize"][0]
fig_width = textwidth_in * width_scale # Width in inches
if aspect == 'golden':
fig_height = fig_width / GOLDEN_RATIO
elif isinstance(aspect, Number):
fig_height = textwidth_in * aspect
else:
raise ValueError(f"aspect has to be either a number, or 'golden'! Was {aspect}!")
fig_size = [fig_width, fig_height] # Both in inches
return fig_size
@contextmanager
def use_style(stylename):
"""Context that uses a matplotlib stylesheet. Can fall back to local mpl stylesheets if necessary!
Parameters
----------
stylename : str
A style specification.
Yields
-------
context
Context manager for using style settings temporarily.
"""
try: # Try to load the style directly (works if it is installed somewhere mpl looks for it):
with plt.style.context(stylename) as context:
yield context
except OSError: # Stylesheet not found, use local ones:
mplstyle_path = os.path.join(os.path.dirname(__file__), 'mplstyles', f'{stylename}.mplstyle')
with plt.style.context(mplstyle_path) as context:
yield context
def copy_mpl_stylesheets():
"""Copy matplotlib styles to the users matplotlib config directory. Useful if you want to utilize them elsewhere.
Notes
-----
You might need to restart your Python session for the stylesheets to be recognized/found!
"""
# Find matplotlib styles:
user_stylelib_path = os.path.join(mpl.get_configdir(), 'stylelib')
vis_dir = os.path.dirname(__file__)
style_files = glob.glob(os.path.join(vis_dir, 'mplstyles', '*.mplstyle'))
# Copy them to the local matplotlib styles folder:
if not os.path.exists(user_stylelib_path):
os.makedirs(user_stylelib_path)
for style_path in style_files:
_, fname = os.path.split(style_path)
dest = os.path.join(user_stylelib_path, fname)
shutil.copy(style_path, dest)
import os
import pytest
import numpy as np
from empyre.fields import Field
@pytest.fixture
def fielddata_path():
return os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_fielddata')
@pytest.fixture
def vector_data():
magnitude = np.zeros((4, 4, 4, 3))
magnitude[1:-1, 1:-1, 1:-1] = 1
return Field(magnitude, 10.0, vector=True)
@pytest.fixture
def vector_data_asymm():
shape = (5, 7, 11, 3)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=True)
@pytest.fixture
def vector_data_asymm_2d():
shape = (5, 7, 2)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=True)
@pytest.fixture
def vector_data_asymmcube():
shape = (3, 3, 3, 3)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=True)
@pytest.fixture
def scalar_data():
magnitude = np.zeros((4, 4, 4))
magnitude[1:-1, 1:-1, 1:-1] = 1
return Field(magnitude, 10.0, vector=False)
@pytest.fixture
def scalar_data_asymm():
shape = (5, 7, 2)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=False)
# -*- coding: utf-8 -*-
"""Testcase for the magdata module."""
import pytest
from numbers import Number
import numpy as np
import numpy.testing
from empyre.fields import Field
from utils import assert_allclose
def test_copy(vector_data):
vector_data = vector_data.copy()
# Make sure it is a new object
assert vector_data != vector_data, 'Unexpected behaviour in copy()!'
assert np.allclose(vector_data, vector_data)
def test_bin(vector_data):
binned_data = vector_data.bin(2)
reference = 1 / 8. * np.ones((2, 2, 2, 3))
assert_allclose(binned_data, reference,
err_msg='Unexpected behavior in scale_down()!')
assert_allclose(binned_data.scale, (20, 20, 20),
err_msg='Unexpected behavior in scale_down()!')
def test_zoom(vector_data):
zoomed_test = vector_data.zoom(2, order=0)
reference = np.zeros((8, 8, 8, 3))
reference[2:6, 2:6, 2:6] = 1
assert_allclose(zoomed_test, reference,
err_msg='Unexpected behavior in zoom()!')
assert_allclose(zoomed_test.scale, (5, 5, 5),
err_msg='Unexpected behavior in zoom()!')
@pytest.mark.parametrize(
'mode', [
'constant',
'edge',
'wrap'
]
)
@pytest.mark.parametrize(
'pad_width,np_pad', [
(1, ((1, 1), (1, 1), (1, 1), (0, 0))),
((1, 2, 3), ((1, 1), (2, 2), (3, 3), (0, 0))),
(((1, 2), (3, 4), (5, 6)), ((1, 2), (3, 4), (5, 6), (0, 0)))
]
)
def test_pad(vector_data, mode, pad_width, np_pad):
magdata_test = vector_data.pad(pad_width, mode=mode)
reference = np.pad(vector_data, np_pad, mode=mode)
assert_allclose(magdata_test, reference,
err_msg='Unexpected behavior in pad()!')
@pytest.mark.parametrize(
'axis', [-1, 3]
)
def test_component_reduction(vector_data, axis):
# axis=-1 is supposed to reduce over the component dimension, if it exists. axis=3 should do the same here!
res = np.sum(vector_data, axis=axis)
ref = np.zeros((4, 4, 4))
ref[1:-1, 1:-1, 1:-1] = 3
assert res.shape == ref.shape, 'Shape mismatch!'
assert_allclose(res, ref, err_msg="Unexpected behavior of axis keyword")
assert isinstance(res, Field), 'Result is not a Field object!'
assert not res.vector, 'Result is a vector field, but should be reduced to a scalar!'
@pytest.mark.parametrize(
'axis', [(0, 1, 2), (2, 1, 0), None, (-4, -3, -2)]
)
def test_full_reduction(vector_data, axis):
res = np.sum(vector_data, axis=axis)
ref = np.zeros((3,))
ref[:] = 8
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of full or default reduction")
assert isinstance(res, np.ndarray)
@pytest.mark.parametrize(
'axis', [-1, 2]
)
def test_last_reduction_scalar(scalar_data, axis):
# axis=-1 is supposed to reduce over the component dimension if it exists.
# In this case it doesn't!
res = np.sum(scalar_data, axis=axis)
ref = np.zeros((4, 4))
ref[1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of axis keyword")
assert isinstance(res, Field)
assert not res.vector
@pytest.mark.parametrize(
'axis', [(0, 1, 2), (2, 1, 0), None, (-1, -2, -3)]
)
def test_full_reduction_scalar(scalar_data, axis):
res = np.sum(scalar_data, axis=axis)
ref = 8
assert res.shape == ()
assert_allclose(res, ref, err_msg="Unexpected behavior of full or default reduction")
assert isinstance(res, Number)
def test_binary_operator_vector_number(vector_data):
res = vector_data + 1
ref = np.ones((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
def test_binary_operator_vector_scalar(vector_data, scalar_data):
res = vector_data + scalar_data
ref = np.zeros((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
def test_binary_operator_vector_vector(vector_data):
res = vector_data + vector_data
ref = np.zeros((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
@pytest.mark.xfail
def test_binary_operator_vector_broadcast(vector_data):
# Broadcasting between vector fields is currently not implemented
second = np.zeros((4, 4, 3))
second[1:-1, 1:-1] = 1
second = Field(second, 10.0, vector=True)
res = vector_data + second
ref = np.zeros((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 1
ref[:, 1:-1, 1:-1] += 1
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
def test_mask(vector_data):
mask = vector_data.mask
reference = np.zeros((4, 4, 4))
reference[1:-1, 1:-1, 1:-1] = True
assert_allclose(mask, reference,
err_msg='Unexpected behavior in mask attribute!')
def test_get_vector(vector_data):
mask = vector_data.mask
vector = vector_data.get_vector(mask)
reference = np.ones(np.sum(mask) * 3)
assert_allclose(vector, reference,
err_msg='Unexpected behavior in get_vector()!')
def test_set_vector(vector_data):
mask = vector_data.mask
vector = 2 * np.ones(np.sum(mask) * 3)
vector_data.set_vector(vector, mask)
reference = np.zeros((4, 4, 4, 3))
reference[1:-1, 1:-1, 1:-1] = 2
assert_allclose(vector_data, reference,
err_msg='Unexpected behavior in set_vector()!')
def test_flip(vector_data_asymm):
field_flipz = vector_data_asymm.flip(0)
field_flipy = vector_data_asymm.flip(1)
field_flipx = vector_data_asymm.flip(2)
field_flipxy = vector_data_asymm.flip((1, 2))
field_flipdefault = vector_data_asymm.flip()
field_flipcomp = vector_data_asymm.flip(-1)
assert_allclose(np.flip(vector_data_asymm.data, axis=0) * [1, 1, -1], field_flipz.data,
err_msg='Unexpected behavior in flip()! (z)')
assert_allclose(np.flip(vector_data_asymm.data, axis=1) * [1, -1, 1], field_flipy.data,
err_msg='Unexpected behavior in flip()! (y)')
assert_allclose(np.flip(vector_data_asymm.data, axis=2) * [-1, 1, 1], field_flipx.data,
err_msg='Unexpected behavior in flip()! (x)')
assert_allclose(np.flip(vector_data_asymm.data, axis=(1, 2)) * [-1, -1, 1], field_flipxy.data,
err_msg='Unexpected behavior in flip()! (xy)')
assert_allclose(np.flip(vector_data_asymm.data, axis=(0, 1, 2)) * [-1, -1, -1], field_flipdefault.data,
err_msg='Unexpected behavior in flip()! (default)')
assert_allclose(np.flip(vector_data_asymm.data, axis=-1) * [1, 1, 1], field_flipcomp.data,
err_msg='Unexpected behavior in flip()! (components)')
def test_unknown_num_of_components():
shape = (5, 7, 7)
data = np.linspace(0, 1, np.prod(shape))
with pytest.raises(AssertionError):
Field(data.reshape(shape), 10.0, vector=True)
def test_repr(vector_data_asymm):
string_repr = repr(vector_data_asymm)
data_str = str(vector_data_asymm.data)
string_ref = f'Field(data={data_str}, scale=(10.0, 10.0, 10.0), vector=True)'
print(f'reference: {string_ref}')
print(f'repr output: {string_repr}')
assert string_repr == string_ref, 'Unexpected behavior in __repr__()!'
def test_str(vector_data_asymm):
string_str = str(vector_data_asymm)
string_ref = 'Field(dim=(5, 7, 11), scale=(10.0, 10.0, 10.0), vector=True, ncomp=3)'
print(f'reference: {string_str}')
print(f'str output: {string_str}')
assert string_str == string_ref, 'Unexpected behavior in __str__()!'
@pytest.mark.parametrize(
"index,t,scale", [
((0, 1, 2), tuple, None),
((0, ), Field, (2., 3.)),
(0, Field, (2., 3.)),
((0, 1, 2, 0), float, None),
((0, 1, 2, 0), float, None),
((..., 0), Field, (1., 2., 3.)),
((0, slice(1, 3), 2), Field, (2.,)),
]
)
def test_getitem(vector_data, index, t, scale):
vector_data.scale = (1., 2., 3.)
data_index = index
res = vector_data[index]
assert_allclose(res, vector_data.data[data_index])
assert isinstance(res, t)
if t is Field:
assert res.scale == scale
def test_from_scalar_field(scalar_data):
sca_x, sca_y, sca_z = [i * scalar_data for i in range(1, 4)]
field_comb = Field.from_scalar_fields([sca_x, sca_y, sca_z])
assert field_comb.vector
assert field_comb.scale == scalar_data.scale
assert_allclose(sca_x, field_comb.comp[0])
assert_allclose(sca_y, field_comb.comp[1])
assert_allclose(sca_z, field_comb.comp[2])
def test_squeeze():
magnitude = np.zeros((4, 1, 4, 3))
field = Field(magnitude, (1., 2., 3.), vector=True)
sq = field.squeeze()
assert sq.shape == (4, 4, 3)
assert sq.dim == (4, 4)
assert sq.scale == (1., 3.)
def test_gradient():
pass
def test_gradient_1d():
pass
def test_curl():
pass
def test_curl_2d():
pass
def test_clip_scalar_noop():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=False)
assert_allclose(field, field.clip())
def test_clip_scalar_minmax():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=False)
assert_allclose(np.clip(data, -1, 0.1), field.clip(vmin=-1, vmax=0.1))
def test_clip_scalar_sigma():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
data[0, 0, 0] = 1e6
field = Field(data, (1., 2., 3.), vector=False)
# We clip off the one outlier
assert_allclose(np.clip(data, -2, 1), field.clip(sigma=5))
assert field.clip(sigma=5)[0, 0, 0] == 1
def test_clip_scalar_mask():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
mask = np.zeros(shape, dtype=bool)
mask[0, 0, 0] = True
mask[0, 0, 1] = True
field = Field(data, (1., 2., 3.), vector=False)
assert_allclose(np.clip(data, data[0, 0, 0], data[0, 0, 1]), field.clip(mask=mask))
def test_clip_vector_noop():
shape = (3, 3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=True)
assert_allclose(field, field.clip())
def test_clip_vector_max():
shape = (3, 3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=True)
res = field.clip(vmax=0.1)
assert_allclose(np.max(res.amp), 0.1)
def test_clip_vector_sigma():
shape = (3, 3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
data[0, 0, 0] = (1e6, 1e6, 1e6)
field = Field(data, (1., 2., 3.), vector=True)
# We clip off the one outlier
res = field.clip(sigma=5)
assert np.max(res.amp) < 1e3
# TODO: HyperSpy would need to be installed for the following tests (slow...):
# def test_from_signal()
# raise NotImplementedError()
#
# def test_to_signal()
# raise NotImplementedError()
import pytest
from utils import assert_allclose
from empyre.fields import Field
import numpy as np
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z']
)
def test_rot90_360(vector_data_asymm, axis):
assert_allclose(vector_data_asymm.rot90(axis=axis).rot90(axis=axis).rot90(axis=axis).rot90(axis=axis),
vector_data_asymm,
err_msg=f'Unexpected behavior in rot90()! {axis}')
@pytest.mark.parametrize(
'rot_axis,flip_axes', [
('x', (0, 1)),
('y', (0, 2)),
('z', (1, 2))
]
)
def test_rot90_180(vector_data_asymm, rot_axis, flip_axes):
res = vector_data_asymm.rot90(axis=rot_axis).rot90(axis=rot_axis)
ref = vector_data_asymm.flip(axis=flip_axes)
assert_allclose(res, ref, err_msg=f'Unexpected behavior in rot90()! {rot_axis}')
@pytest.mark.parametrize(
'rot_axis', [
'x',
'y',
'z',
]
)
def test_rotate_compare_rot90_1(vector_data_asymmcube, rot_axis):
res = vector_data_asymmcube.rotate(angle=90, axis=rot_axis)
ref = vector_data_asymmcube.rot90(axis=rot_axis)
print("input", vector_data_asymmcube.data)
print("ref", res.data)
print("res", ref.data)
assert_allclose(res, ref, err_msg=f'Unexpected behavior in rotate()! {rot_axis}')
def test_rot90_manual():
data = np.zeros((3, 3, 3, 3))
diag = np.array((1, 1, 1))
diag_unity = diag / np.sqrt(np.sum(diag**2))
data[0, 0, 0] = diag_unity
data = Field(data, 10, vector=True)
print("data", data.data)
rot90_x = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot90_x[0, 2, 0] = diag_unity * (1, -1, 1)
rot90_x = Field(rot90_x, 10, vector=True)
print("rot90_x", rot90_x.data)
print("data rot90 x", data.rot90(axis='x').data)
rot90_y = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot90_y[2, 0, 0] = diag_unity * (1, 1, -1)
rot90_y = Field(rot90_y, 10, vector=True)
print("rot90_y", rot90_y.data)
print("data rot90 y", data.rot90(axis='y').data)
rot90_z = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot90_z[0, 0, 2] = diag_unity * (-1, 1, 1)
rot90_z = Field(rot90_z, 10, vector=True)
print("rot90_z", rot90_z.data)
print("data rot90 z", data.rot90(axis='z').data)
assert_allclose(rot90_x, data.rot90(axis='x'), err_msg='Unexpected behavior in rot90("x")!')
assert_allclose(rot90_y, data.rot90(axis='y'), err_msg='Unexpected behavior in rot90("y")!')
assert_allclose(rot90_z, data.rot90(axis='z'), err_msg='Unexpected behavior in rot90("z")!')
def test_rot45_manual():
data = np.zeros((3, 3, 3, 3))
data[0, 0, 0] = (1, 1, 1)
data = Field(data, 10, vector=True)
print("data", data.data)
rot45_x = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot45_x[0, 1, 0] = (1, 0, np.sqrt(2))
rot45_x = Field(rot45_x, 10, vector=True)
print("rot45_x", rot45_x.data)
# Disable spline interpolation, use nearest instead
res_rot45_x = data.rotate(45, axis='x', order=0)
print("data rot45 x", res_rot45_x.data)
rot45_y = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot45_y[1, 0, 0] = (np.sqrt(2), 1, 0)
rot45_y = Field(rot45_y, 10, vector=True)
print("rot45_y", rot45_y.data)
# Disable spline interpolation, use nearest instead
res_rot45_y = data.rotate(45, axis='y', order=0)
print("data rot45 y", res_rot45_y.data)
rot45_z = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot45_z[0, 0, 1] = (0, np.sqrt(2), 1)
rot45_z = Field(rot45_z, 10, vector=True)
print("rot45_z", rot45_z.data)
# Disable spline interpolation, use nearest instead
res_rot45_z = data.rotate(45, axis='z', order=0)
print("data rot45 z", res_rot45_z.data)
assert_allclose(rot45_x, res_rot45_x, err_msg='Unexpected behavior in rotate(45, "x")!')
assert_allclose(rot45_y, res_rot45_y, err_msg='Unexpected behavior in rotate(45, "y")!')
assert_allclose(rot45_z, res_rot45_z, err_msg='Unexpected behavior in rotate(45, "z")!')
def test_rot90_2d_360(vector_data_asymm_2d):
assert_allclose(vector_data_asymm_2d.rot90().rot90().rot90().rot90(), vector_data_asymm_2d,
err_msg='Unexpected behavior in 2D rot90()!')
def test_rot90_2d_180(vector_data_asymm_2d):
res = vector_data_asymm_2d.rot90().rot90()
ref = vector_data_asymm_2d.flip()
assert_allclose(res, ref, err_msg='Unexpected behavior in 2D rot90()!')
@pytest.mark.parametrize(
'k', [0, 1, 2, 3, 4]
)
def test_rot90_comp_2d_with_3d(vector_data_asymm_2d, k):
data_x, data_y = [comp.data[np.newaxis, :, :] for comp in vector_data_asymm_2d.comp]
data_z = np.zeros_like(data_x)
data_3d = np.stack([data_x, data_y, data_z], axis=-1)
vector_data_asymm_3d = Field(data_3d, scale=10, vector=True)
print(f'2D shape, scale: {vector_data_asymm_2d.shape, vector_data_asymm_2d.scale}')
print(f'3D shape, scale: {vector_data_asymm_3d.shape, vector_data_asymm_3d.scale}')
vector_data_rot_2d = vector_data_asymm_2d.rot90(k=k)
vector_data_rot_3d = vector_data_asymm_3d.rot90(k=k, axis='z')
print(f'2D shape after rot: {vector_data_rot_2d.shape}')
print(f'3D shape after rot: {vector_data_rot_3d.shape}')
assert_allclose(vector_data_rot_2d, vector_data_rot_3d[0, :, :, :2], err_msg='Unexpected behavior in 2D rot90()!')
@pytest.mark.parametrize(
'angle', [90, 45, 23, 11.5]
)
def test_rotate_comp_2d_with_3d(vector_data_asymm_2d, angle):
data_x, data_y = [comp.data[np.newaxis, :, :] for comp in vector_data_asymm_2d.comp]
data_z = np.zeros_like(data_x)
data_3d = np.stack([data_x, data_y, data_z], axis=-1)
vector_data_asymm_3d = Field(data_3d, scale=10, vector=True)
print(f'2D shape, scale: {vector_data_asymm_2d.shape, vector_data_asymm_2d.scale}')
print(f'3D shape, scale: {vector_data_asymm_3d.shape, vector_data_asymm_3d.scale}')
r2d = vector_data_asymm_2d.rotate(angle)
r3d = vector_data_asymm_3d.rotate(angle, axis='z')
print(f'2D shape after rot: {r2d.shape}')
print(f'3D shape after rot: {r3d.shape}')
assert_allclose(r2d, r3d[0, :, :, :2], err_msg='Unexpected behavior in 2D rotate()!')
@pytest.mark.parametrize(
'angle', [180, 360, 90, 45, 23, 11.5],
)
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z'],
)
def test_rotate_scalar(vector_data_asymm, angle, axis):
data = np.zeros((1, 2, 2, 3))
data[0, 0, 0] = 1
field = Field(data, scale=10., vector=True)
print(field)
print(field.amp)
assert_allclose(
field.rotate(angle, axis=axis).amp,
field.amp.rotate(angle, axis=axis)
)
@pytest.mark.parametrize(
'angle,order', [(180, 3), (360, 3), (90, 3), (45, 0), (23, 0), (11.5, 0)],
)
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z'],
)
@pytest.mark.parametrize(
'reshape', [True, False],
)
def test_rotate_scalar_asymm(vector_data_asymm, angle, axis, order, reshape):
assert_allclose(
vector_data_asymm.rotate(angle, axis=axis, reshape=reshape, order=order).amp,
vector_data_asymm.amp.rotate(angle, axis=axis, reshape=reshape, order=order)
)
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z'],
)
@pytest.mark.parametrize(
'k', [0, 1, 2, 3, 4],
)
def test_rot90_scalar(vector_data_asymm, axis, k):
assert_allclose(
vector_data_asymm.amp.rot90(k=k, axis=axis),
vector_data_asymm.rot90(k=k, axis=axis).amp
)