Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • empyre/empyre
  • weber/empyre
  • wessels/empyre
  • bryan/empyre
4 results
Show changes
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Subpackage containing functionality for visualisation of multidimensional fields."""
from . import colors
from .plot2d import *
from .decorators import *
from .tools import *
from .plot3d import *
__all__ = ['colors']
__all__.extend(plot2d.__all__)
__all__.extend(decorators.__all__)
__all__.extend(tools.__all__)
__all__.extend(plot3d.__all__)
del plot2d
del decorators
del tools
del plot3d
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides a number of custom colormaps, which also have capabilities for 3D plotting.
If this is the case, the :class:`~.Colormap3D` colormap class is a parent class. In `cmaps`, a
number of specialised colormaps is available for convenience.
For general questions about colors see:
http://www.poynton.com/PDFs/GammaFAQ.pdf
http://www.poynton.com/PDFs/ColorFAQ.pdf
"""
import logging
import colorsys
import abc
import numpy as np
from PIL import Image
from matplotlib import colors
import matplotlib.pyplot as plt
from matplotlib.ticker import FixedLocator
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
from matplotlib.patches import Circle
from .tools import use_style
__all__ = ['Colormap3D', 'ColormapCubehelix', 'ColormapPerception', 'ColormapHLS', 'ColormapClassic',
'ColormapTransparent', 'cmaps', 'interpolate_color']
_log = logging.getLogger(__name__)
class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta):
"""Colormap subclass for encoding directions with colors.
This abstract class is used as a superclass/interface for 3D vector plotting capabilities.
In general, a circular colormap should be used to encode the in-plane angle (hue). The
perpendicular angle is encoded via luminance variation (up: white, down: black). Finally,
the length of a vector is encoded via saturation. Decreasing vector length causes a desaturated
color. Subclassing colormaps get access to routines to plot a colorwheel (which should
ideally be located in the 50% luminance plane, which depends strongly on the underlying map),
a convenience function to interpolate color tuples and a function to return rgb triples for a
given vector. The :class:`~.Colormap3D` class itself subclasses the matplotlib base colormap.
"""
_log = logging.getLogger(__name__ + '.Colormap3D')
def rgb_from_vector(self, vector, vmax=None):
"""Construct a hls tuple from three coordinates representing a 3D direction.
Parameters
----------
vector: tuple (N=3) or :class:`~numpy.ndarray`
Vector containing the x, y and z component, or a numpy array encompassing the
components as three lists.
Returns
-------
rgb: :class:`~numpy.ndarray`
Numpy array containing the calculated color tuples.
"""
self._log.debug('Calling rgb_from_vector')
x, y, z = vector
R = np.sqrt(x ** 2 + y ** 2 + z ** 2)
R_max = vmax if vmax is not None else R.max() + 1E-30
# FIRST color dimension: HUE (1D ring/angular direction)
phi = np.asarray(np.arctan2(y, x))
phi[phi < 0] += 2 * np.pi
hue = phi / (2 * np.pi)
rgba = np.asarray(self(hue))
r, g, b = rgba[..., 0], rgba[..., 1], rgba[..., 2]
# SECOND color dimension: SATURATION (2D, in-plane)
rho = np.sqrt(x ** 2 + y ** 2)
sat = rho / R_max
r, g, b = interpolate_color(sat, (0.5, 0.5, 0.5), np.stack((r, g, b), axis=-1))
# THIRD color dimension: LUMINANCE (3D, color sphere)
theta = np.arccos(z / R_max)
lum = 1 - theta / np.pi # goes from 0 (black) over 0.5 (grey) to 1 (white)!
lum_target = np.where(lum < 0.5, 0, 1) # Separate upper(white)/lower(black) hemispheres!
lum_target = np.stack([lum_target] * 3, axis=-1) # [0, 0, 0] -> black / [1, 1, 1] -> white!
fraction = 2 * np.abs(lum - 0.5) # 0.5: difference from grey, 2: scale to range (0, 1)!
r, g, b = interpolate_color(fraction, np.stack((r, g, b), axis=-1), lum_target)
# Return RGB:
return np.asarray(255 * np.stack((r, g, b), axis=-1), dtype=np.uint8)
def make_colorwheel(self, size=64):
"""Construct a color wheel as an :class:`~PIL.Image` object.
Parameters
----------
size : int, optional
Diameter of the color wheel along both axes in pixels, by default 64.
Returns
-------
img: :class:`~PIL.Image`
The resulting image.
"""
self._log.debug('Calling make_colorwheel')
# Construct the colorwheel:
yy, xx = (np.indices((size, size)) - size/2 + 0.5)
rr = np.hypot(xx, yy)
xx = np.where(rr <= size/2-3, xx, 0)
yy = np.where(rr <= size/2-3, yy, 0)
zz = np.zeros((size, size))
aa = np.where(rr >= size/2-3, 0, 255).astype(dtype=np.uint8)
rgba = np.dstack((self.rgb_from_vector(np.asarray((xx, yy, zz))), aa))
# Create color wheel:
return Image.fromarray(rgba)
def plot_colorwheel(self, axis=None, size=64, arrows=True, grayscale=False, **kwargs):
"""Display a color wheel to illustrate the color coding of vector gradient directions.
Parameters
----------
figsize : tuple of floats (N=2)
Size of the plot figure.
Returns
-------
img: :class:`matplotlib.image.AxesImage`
The resulting colorwheel.
"""
self._log.debug('Calling plot_colorwheel')
# Construct the colorwheel:
color_wheel = self.make_colorwheel(size=size)
if grayscale:
color_wheel = color_wheel.convert('LA')
# Plot the color wheel:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
fig = plt.figure()
axis = fig.add_subplot(1, 1, 1, aspect='equal')
# Plot:
im = axis.imshow(color_wheel, **kwargs)
xy = size/2 - 0.5, size/2 - 0.5
axis.add_patch(Circle(xy=xy, radius=size/2-2.5, linewidth=2, edgecolor='k', facecolor='none'))
if arrows:
axis.arrow(size/2, size/2+5, 0, 0.1*size, fc='k', ec='k', lw=1, width=2, alpha=0.15)
axis.arrow(size/2, size/2-5, 0, -0.1*size, fc='k', ec='k', lw=1, width=2, alpha=0.15)
axis.arrow(size/2+5, size/2, 0.1*size, 0, fc='k', ec='k', lw=1, width=2, alpha=0.15)
axis.arrow(size/2-5, size/2, -0.1*size, 0, fc='k', ec='k', lw=1, width=2, alpha=0.15)
return im
class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D):
"""A full implementation of Dave Green's "cubehelix" for Matplotlib.
Based on the FORTRAN 77 code provided in D.A. Green, 2011, BASI, 39, 289.
http://adsabs.harvard.edu/abs/2011arXiv1108.5083G
Also see:
http://www.mrao.cam.ac.uk/~dag/CUBEHELIX/
http://davidjohnstone.net/pages/cubehelix-gradient-picker
User can adjust all parameters of the cubehelix algorithm. This enables much greater
flexibility in choosing color maps. Default color map settings produce the standard cubehelix.
Create color map in only blues by setting rot=0 and start=0. Create reverse (white to black)
backwards through the rainbow once by setting rot=1 and reverse=True, etc. Furthermore, the
algorithm was tuned, so that constant luminance values can be used (e.g. to create a truly
isoluminant colorwheel). The `rot` parameter is also tuned to hold true for these cases.
Of the here presented colorwheels, only this one manages to solely navigate through the L*=50
plane, which can be seen here:
https://upload.wikimedia.org/wikipedia/commons/2/21/Lab_color_space.png
Parameters
----------
start : scalar, optional
Sets the starting position in the color space. 0=blue, 1=red,
2=green. Defaults to 0.5.
rot : scalar, optional
The number of rotations through the rainbow. Can be positive
or negative, indicating direction of rainbow. Negative values
correspond to Blue->Red direction. Defaults to -1.5.
gamma : scalar, optional
The gamma correction for intensity. Defaults to 1.0.
reverse : boolean, optional
Set to True to reverse the color map. Will go from black to
white. Good for density plots where shade~density. Defaults to False.
nlev : scalar, optional
Defines the number of discrete levels to render colors at.
Defaults to 256.
sat : scalar, optional
The saturation intensity factor. Defaults to 1.2
NOTE: this was formerly known as `hue` parameter
minSat : scalar, optional
Sets the minimum-level saturation. Defaults to 1.2.
maxSat : scalar, optional
Sets the maximum-level saturation. Defaults to 1.2.
startHue : scalar, optional
Sets the starting color, ranging from [0, 360], as in
D3 version by @mbostock.
NOTE: overrides values in start parameter.
endHue : scalar, optional
Sets the ending color, ranging from [0, 360], as in
D3 version by @mbostock
NOTE: overrides values in rot parameter.
minLight : scalar, optional
Sets the minimum lightness value. Defaults to 0.
maxLight : scalar, optional
Sets the maximum lightness value. Defaults to 1.
Returns
-------
matplotlib.colors.LinearSegmentedColormap object
Revisions
---------
2014-04 (@jradavenport) Ported from IDL version
2014-04 (@jradavenport) Added kwargs to enable similar to D3 version,
changed name of `hue` parameter to `sat`.
2016-11 (@jan.caron) Added support for isoluminant cubehelices while making sure
`rot` works as intended. Decoded the plane-vectors a bit.
"""
_log = logging.getLogger(__name__ + '.ColormapCubehelix')
def __init__(self, start=0.5, rot=-1.5, gamma=1.0, reverse=False, nlev=256,
minSat=1.2, maxSat=1.2, minLight=0., maxLight=1., **kwargs):
self._log.debug('Calling __init__')
# Override start and rot if startHue and endHue are set:
if kwargs is not None:
if 'startHue' in kwargs:
start = (kwargs.get('startHue') / 360. - 1.) * 3.
if 'endHue' in kwargs:
rot = kwargs.get('endHue') / 360. - start / 3. - 1.
if 'sat' in kwargs:
minSat = kwargs.get('sat')
maxSat = kwargs.get('sat')
self.nlev = nlev
# Set up the parameters:
self.fract = np.linspace(minLight, maxLight, nlev)
angle = 2.0 * np.pi * (start / 3.0 + rot * np.linspace(0, 1, nlev))
self.fract = self.fract**gamma
satar = np.linspace(minSat, maxSat, nlev)
amp = np.asarray(satar * self.fract * (1. - self.fract) / 2)
# Set RGB color coefficients (Luma is calculated in RGB Rec.601, so choose those), the original version of
# Dave Green used (0.30, 0.59, 0.11) and Rec.709 is c709 = (0.2126, 0.7152, 0.0722) but with eihter of those,
# this function would not produce the correct YPbPr Luma.
c601 = (0.299, 0.587, 0.114)
cr, cg, cb = c601
cw = -0.90649 # Chosen to comply with Dave Greens implementation.
k = -1.6158 / cr / cw # k has to balance out cw so nothing gets out of RGB gamut (> 1).
# Calculate the vectors v and w spanning the plane of constant perceived intensity. v and w have to solve
# v x w = k(cr, cg, cb) (normal vector of the described plane) and
# v * w = 0 (scalar product, v and w have to be perpendicular).
# 6 unknown and 4 equations --> Chose wb = 0 and wg = cw (constant).
v = np.array((k * cr ** 2 * cb / (cw * (cr ** 2 + cg ** 2)),
k * cr * cg * cb / (cw * (cr ** 2 + cg ** 2)), -k * cr / cw))
w = np.array((-cw * cg / cr, cw, 0))
# Calculate components:
self.red = self.fract + amp * (v[0] * np.cos(angle) + w[0] * np.sin(angle))
self.grn = self.fract + amp * (v[1] * np.cos(angle) + w[1] * np.sin(angle))
self.blu = self.fract + amp * (v[2] * np.cos(angle) + w[2] * np.sin(angle))
# Original formulas with original v and w:
# self.red = self.fract + amp * (-0.14861 * np.cos(angle) + 1.78277 * np.sin(angle))
# self.grn = self.fract + amp * (-0.29227 * np.cos(angle) - 0.90649 * np.sin(angle))
# self.blu = self.fract + amp * (1.97294 * np.cos(angle))
# Find where RBG are outside the range [0,1], clip:
self.red = np.clip(self.red, 0, 1)
self.grn = np.clip(self.grn, 0, 1)
self.blu = np.clip(self.blu, 0, 1)
# Optional color reverse:
if reverse is True:
self.red = self.red[::-1]
self.blu = self.blu[::-1]
self.grn = self.grn[::-1]
# Put in to tuple & dictionary structures needed:
rr, bb, gg = [], [], []
for k in range(0, int(nlev)):
rr.append((float(k) / (nlev - 1), self.red[k], self.red[k]))
bb.append((float(k) / (nlev - 1), self.blu[k], self.blu[k]))
gg.append((float(k) / (nlev - 1), self.grn[k], self.grn[k]))
cdict = {'red': rr, 'blue': bb, 'green': gg}
super().__init__('cubehelix', cdict, N=256)
self._log.debug('Created ' + str(self))
def plot_helix(self, figsize=None, **kwargs):
"""Display the RGB and luminance plots for the chosen cubehelix.
Parameters
----------
figsize : tuple of floats (N=2)
Size of the plot figure.
Returns
-------
None
"""
self._log.debug('Calling plot_helix')
with use_style('empyre-plot'):
fig = plt.figure(figsize=figsize, constrained_layout=True)
gs = fig.add_gridspec(2, 1, height_ratios=[8, 1])
# Main plot:
axis = plt.subplot(gs[0])
axis.plot(self.fract, 'k')
axis.plot(self.red, 'r')
axis.plot(self.grn, 'g')
axis.plot(self.blu, 'b')
axis.set_xlim(0, self.nlev)
axis.set_ylim(0, 1)
axis.set_title('Cubehelix')
axis.set_xlabel('Color index')
axis.set_ylabel('Brightness / RGB')
axis.xaxis.set_major_locator(FixedLocator(locs=np.linspace(0, self.nlev, 5)))
axis.yaxis.set_major_locator(FixedLocator(locs=[0, 0.5, 1]))
# Colorbar horizontal:
caxis = plt.subplot(gs[1])
rgb = self(np.linspace(0, 1, 256))[None, ...]
rgb = np.asarray(255.9999 * rgb, dtype=np.uint8)
rgb = np.repeat(rgb, 30, axis=0)
im = Image.fromarray(rgb)
caxis.imshow(im, aspect='auto')
caxis.tick_params(axis='both', which='both', labelleft=False, labelbottom=False,
left=False, right=False, top=False, bottom=False)
class ColormapPerception(colors.LinearSegmentedColormap, Colormap3D):
"""A perceptual colormap based on face-based luminance matching.
Based on a publication by Kindlmann et. al.
http://www.cs.utah.edu/~gk/papers/vis02/FaceLumin.pdf
This colormap tries to achieve an isoluminant perception by using a list of colors acquired
through face recognition studies. It is a lot better than the HLS colormap, but still not
completely isoluminant (despite its name). Also it appears a bit dark.
"""
_log = logging.getLogger(__name__ + '.ColormapPerception')
CDICT = {'red': [(0/6, 0.847, 0.847),
(1/6, 0.527, 0.527),
(2/6, 0.000, 0.000),
(3/6, 0.000, 0.000),
(4/6, 0.316, 0.316),
(5/6, 0.718, 0.718),
(6/6, 0.847, 0.847)],
'green': [(0/6, 0.057, 0.057),
(1/6, 0.527, 0.527),
(2/6, 0.592, 0.592),
(3/6, 0.559, 0.559),
(4/6, 0.316, 0.316),
(5/6, 0.000, 0.000),
(6/6, 0.057, 0.057)],
'blue': [(0/6, 0.057, 0.057),
(1/6, 0.000, 0.000),
(2/6, 0.000, 0.000),
(3/6, 0.559, 0.559),
(4/6, 0.991, 0.991),
(5/6, 0.718, 0.718),
(6/6, 0.057, 0.057)]}
def __init__(self):
self._log.debug('Calling __init__')
super().__init__('perception', self.CDICT, N=256)
self._log.debug('Created ' + str(self))
class ColormapHLS(colors.ListedColormap, Colormap3D):
"""Colormap subclass for encoding directions with colors.
This class is a subclass of the :class:`~matplotlib.pyplot.colors.ListedColormap`
class. The class follows the HSL ('hue', 'saturation', 'lightness') 'Double Hexcone' Model
with the saturation always set to 1 (moving on the surface of the color
cylinder) with a lightness of 0.5 (full color). The three prime colors (`rgb`) are spaced
equidistant with 120° space in between, according to a triadic arrangement.
Even though the lightness is constant in the plane, the luminance (which is a weighted sum
of the RGB components which encompasses human perception) is not, which can lead to
artifacts like reliefs. Converting the map to a grayscale show spokes at the secondary colors.
For more information see:
https://vis4.net/blog/posts/avoid-equidistant-hsv-colors/
http://www.workwithcolor.com/color-luminance-2233.htm
http://blog.asmartbear.com/color-wheels.html
"""
_log = logging.getLogger(__name__ + '.ColormapHLS')
def __init__(self):
self._log.debug('Calling __init__')
h = np.linspace(0, 1, 256)
l = 0.5 * np.ones_like(h)
s = np.ones_like(h)
r, g, b = np.vectorize(colorsys.hls_to_rgb)(h, l, s)
colors = [(r[i], g[i], b[i]) for i in range(len(r))]
super().__init__(colors, 'hls', N=256)
self._log.debug('Created ' + str(self))
class ColormapClassic(colors.LinearSegmentedColormap, Colormap3D):
"""Colormap subclass for encoding directions with colors.
This class is a subclass of the :class:`~matplotlib.pyplot.colors.LinearSegmentedColormap`
class. The class follows the HSL ('hue', 'saturation', 'lightness') 'Double
Hexcone' Model with the saturation always set to 1 (moving on the surface of the color
cylinder) with a luminance of 0.5 (full color). The colors follow a tetradic arrangement with
four colors (red, green, blue and yellow) arranged with 90° spacing in between.
"""
_log = logging.getLogger(__name__ + '.ColormapClassic')
CDICT = {'red': [(0/4, 1.0, 1.0),
(1/4, 0.0, 0.0),
(2/4, 0.0, 0.0),
(3/4, 1.0, 1.0),
(4/4, 1.0, 1.0)],
'green': [(0/4, 0.0, 0.0),
(1/4, 0.0, 0.0),
(2/4, 1.0, 1.0),
(3/4, 1.0, 1.0),
(4/4, 0.0, 0.0)],
'blue': [(0/4, 0.0, 0.0),
(1/4, 1.0, 1.0),
(2/4, 0.0, 0.0),
(3/4, 0.0, 0.0),
(4/4, 0.0, 0.0)]}
def __init__(self):
self._log.debug('Calling __init__')
super().__init__('classic', self.CDICT, N=256)
self._log.debug('Created ' + str(self))
class ColormapTransparent(colors.LinearSegmentedColormap):
"""Colormap subclass for including transparency.
This class is a subclass of the :class:`~matplotlib.pyplot.colors.LinearSegmentedColormap`
class with integrated support for transparency. The colormap is unicolor and varies only in
transparency.
Attributes
----------
r: float, optional
Intensity of red in the colormap. Has to be between 0. and 1.
g: float, optional
Intensity of green in the colormap. Has to be between 0. and 1.
b: float, optional
Intensity of blue in the colormap. Has to be between 0. and 1.
alpha_range : list (N=2) of float, optional
Start and end alpha value. Has to be between 0. and 1.
"""
_log = logging.getLogger(__name__ + '.ColormapTransparent')
def __init__(self, r=0., g=0., b=0., alpha_range=None):
self._log.debug('Calling __init__')
if alpha_range is None:
alpha_range = [0., 1.]
red = [(0., 0., r), (1., r, 1.)]
green = [(0., 0., g), (1., g, 1.)]
blue = [(0., 0., b), (1., b, 1.)]
alpha = [(0., 0., alpha_range[0]), (1., alpha_range[1], 1.)]
cdict = {'red': red, 'green': green, 'blue': blue, 'alpha': alpha}
super().__init__('transparent', cdict, N=256)
self._log.debug('Created ' + str(self))
def interpolate_color(fraction, start, end):
"""Interpolate linearly between two color tuples (e.g. RGB).
Parameters
----------
fraction: float or :class:`~numpy.ndarray`
Interpolation fraction between 0 and 1, which determines the position of the
interpolation between `start` and `end`.
start: tuple (N=3) or :class:`~numpy.ndarray`
Start of the interpolation as a tuple of three numbers or a numpy array, where the last
dimension should have length 3 and contain the color tuples.
end: tuple (N=3) or :class:`~numpy.ndarray`
End of the interpolation as a tuple of three numbers or a numpy array, where the last
dimension should have length 3 and contain the color tuples.
Returns
-------
result: tuple (N=3) or :class:`~numpy.ndarray`
Result of the interpolation as a tuple of three numbers or a numpy array, where the
last dimension should has length 3 and contains the color tuples.
"""
_log.debug('Calling interpolate_color')
start, end = np.asarray(start), np.asarray(end)
r1 = start[..., 0] + (end[..., 0] - start[..., 0]) * fraction
r2 = start[..., 1] + (end[..., 1] - start[..., 1]) * fraction
r3 = start[..., 2] + (end[..., 2] - start[..., 2]) * fraction
return r1, r2, r3
class CMapNamespace(object):
def __init__(self):
self.cubehelix = ColormapCubehelix()
self.cubehelix_r = ColormapCubehelix(reverse=True)
self.cyclic_cubehelix = ColormapCubehelix(start=1, rot=1, minLight=0.5, maxLight=0.5, sat=2)
self.cyclic_perception = ColormapPerception()
self.cyclic_hls = ColormapHLS()
self.cyclic_classic = ColormapClassic()
self.transparent_black = ColormapTransparent(0, 0, 0, [0, 1.])
self.transparent_white = ColormapTransparent(1, 1, 1, [0, 1.])
self.transparent_confidence = ColormapTransparent(0.2, 0.3, 0.2, [0.75, 0.])
def __getitem__(self, key):
return self.__dict__[key]
def add_cmap_dict(self, **kwargs):
self.__dict__.update(kwargs)
cmaps = CMapNamespace()
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides functions that decorate exisiting `matplotlib` plots."""
import logging
from collections.abc import Iterable
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patheffects
from matplotlib.patches import Circle
from matplotlib.offsetbox import TextArea, AnchoredOffsetbox
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from . import colors
from .tools import use_style
from ..fields.field import Field
__all__ = ['scalebar', 'colorwheel', 'annotate', 'quiverkey', 'coords', 'colorbar']
_log = logging.getLogger(__name__)
def scalebar(axis=None, unit='nm', loc='lower left', **kwargs):
"""Add a scalebar to the axis.
Parameters
----------
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the scalebar is added, by default None, which will pick the last used axis via `gca`.
unit: str, optional
String that determines the unit of the scalebar, defaults to 'nm'.
loc : str or pair of floats, optional
The location of the scalebar, defaults to 'lower left'. See `matplotlib.legend` for possible settings.
Returns
-------
aoffbox : :class:`~matplotlib.offsetbox.AnchoredOffsetbox`
The box containing the scalebar.
Notes
-----
Additional kwargs are passed to `mpl_toolkits.axes_grid1.anchored_artists.AnchoredSizeBar`.
"""
_log.debug('Calling scalebar')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
# Transform axis borders (1, 1) to data borders to get number of pixels in y and x:
transform = axis.transData
bb0 = axis.transLimits.inverted().transform((0, 0))
bb1 = axis.transLimits.inverted().transform((1, 1))
data_extent = (int(abs(bb1[1] - bb0[1])), int(abs(bb1[0] - bb0[0])))
# Calculate bar length:
bar_length = data_extent[1] / 4 # 25% of the data width!
thresholds = [1, 5, 10, 50, 100, 500, 1000]
for t in thresholds: # For larger grids (real images), multiples of threshold look better!
if bar_length > t: # Round down to the next lowest multiple of t:
bar_length = (bar_length // t) * t
# Set parameters for scale bar:
label = f'{bar_length:g} {unit}'
# Set defaults:
kwargs.setdefault('borderpad', 0.2)
kwargs.setdefault('pad', 0.2)
kwargs.setdefault('sep', 5)
kwargs.setdefault('color', 'w')
kwargs.setdefault('size_vertical', data_extent[0]*0.01)
kwargs.setdefault('frameon', False)
kwargs.setdefault('label_top', True)
kwargs.setdefault('fill_bar', True)
# Create scale bar:
scalebar = AnchoredSizeBar(transform, bar_length, label, loc, **kwargs)
scalebar.txt_label._text._color = 'w' # Overwrite AnchoredSizeBar color!
# Set stroke patheffect:
effect_txt = [patheffects.withStroke(linewidth=2, foreground='k')]
scalebar.txt_label._text.set_path_effects(effect_txt)
effect_bar = [patheffects.withStroke(linewidth=3, foreground='k')]
scalebar.size_bar._children[0].set_path_effects(effect_bar)
# Add scale bar to axis and return:
axis.add_artist(scalebar)
return scalebar
def colorwheel(axis=None, cmap=None, ax_size='20%', loc='upper right', **kwargs):
"""Add a colorwheel to the axis on the upper right corner.
Parameters
----------
axis : :class:`~matplotlib.axes.Axes`, optional
Axis to which the colorwheel is added, by default None, which will pick the last used axis via `gca`.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the colorwheel, defaults to `None`, which chooses the
`.colors.cmaps.cyclic_cubehelix` colormap. Needs to be a :class:`~.colors.Colormap3D` to work correctly.
ax_size : str or float, optional
String or float determining the size of the inset axis used, defaults to `20%`.
loc : str or pair of floats, optional
The location of the colorwheel, defaults to 'upper right'. See `matplotlib.legend` for possible settings.
Returns
-------
axis : :class:`~matplotlib.image.AxesImage`
The colorwheel image that was created.
Notes
-----
Additional kwargs are passed to :class:`~.colors.Colormap3D.plot_colorwheel` of the :class:`~.colors.Colormap3D`.
"""
_log.debug('Calling colorwheel')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
ins_axes = inset_axes(axis, width=ax_size, height=ax_size, loc=loc)
ins_axes.axis('off')
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
plt.sca(axis) # Set focus back to parent axis!
return cmap.plot_colorwheel(axis=ins_axes, **kwargs)
def annotate(label, axis=None, loc='upper left'):
"""Add an annotation to the axis on the upper left corner.
Parameters
----------
label : string
The text of the annotation.
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the annotation is added, by default None, which will pick the last used axis via `gca`.
loc : str or pair of floats, optional
The location of the annotation, defaults to 'upper left'. See `matplotlib.legend` for possible settings.
Returns
-------
aoffbox : :class:`~matplotlib.offsetbox.AnchoredOffsetbox`
The box containing the annotation.
"""
_log.debug('Calling annotate')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
# Create text:
txt = TextArea(label, textprops={'color': 'w'})
txt.set_clip_box(axis.bbox)
txt._text.set_path_effects([patheffects.withStroke(linewidth=2, foreground='k')])
# Pack into and add AnchoredOffsetBox:
aoffbox = AnchoredOffsetbox(loc=loc, pad=0.5, borderpad=0.1, child=txt, frameon=False)
axis.add_artist(aoffbox)
return aoffbox
def quiverkey(quiv, field, axis=None, unit='', loc='lower right', **kwargs):
"""Add a quiver key to an axis.
Parameters
----------
quiv : Quiver instance
The quiver instance returned by a call to quiver.
field : `Field` or ndarray
The vector data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the quiverkey is added, by default None, which will pick the last used axis via `gca`.
unit: str, optional
String that determines the unit of the quiverkey, defaults to ''.
loc : str or pair of floats, optional
The location of the quiverkey, defaults to 'lower right'. See `matplotlib.legend` for possible settings.
Returns
-------
qk: Quiverkey
The generated quiverkey.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.quiverkey`.
"""
_log.debug('Calling quiverkey')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=True)
length = field.amp.data.max()
shift = 1 / field.squeeze().dim[1] # equivalent to one pixel distance in axis coords!
label = f'{length:.3g} {unit}'
if loc in ('upper right', 1):
X, Y, labelpos = 0.95-shift, 0.95-shift/4, 'W'
elif loc in ('upper left', 2):
X, Y, labelpos = 0.05+shift, 0.95-shift/4, 'E'
elif loc in ('lower left', 3):
X, Y, labelpos = 0.05+shift, 0.05+shift/4, 'E'
elif loc in ('lower right', 4):
X, Y, labelpos = 0.95-shift, 0.05+shift/4, 'W'
else:
raise ValueError('Quiverkey can only be placed in one of the corners (number 1 - 4 or associated strings)!')
# Set defaults:
kwargs.setdefault('coordinates', 'axes')
kwargs.setdefault('facecolor', 'w')
kwargs.setdefault('edgecolor', 'k')
kwargs.setdefault('labelcolor', 'w')
kwargs.setdefault('linewidth', 1)
kwargs.setdefault('clip_box', axis.bbox)
kwargs.setdefault('clip_on', True)
# Plot:
qk = axis.quiverkey(quiv, X, Y, U=1, label=label, labelpos=labelpos, **kwargs)
qk.text.set_path_effects([patheffects.withStroke(linewidth=2, foreground='k')])
return qk
def coords(axis=None, coords=('x', 'y'), loc='lower left', **kwargs):
"""Add coordinate arrows to an axis.
Parameters
----------
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the coordinates are added, by default None, which will pick the last used axis via `gca`.
coords : tuple or int, optional
Tuple of strings determining the labels, by default ('x', 'y'). Can also be `2` or `3` which expands to
('x', 'y') or ('x', 'y', 'z'). The length of `coords` determines the number of arrows (2 or 3).
loc : str, optional
[description], by default 'lower left'
Returns
-------
ins_axes : :class:`~matplotlib.axes.Axes`
The created inset axes containing the coordinates.
"""
_log.debug('Calling coords')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
ins_ax = inset_axes(axis, width='5%', height='5%', loc=loc, borderpad=2.2)
if coords == 3:
coords = ('x', 'y', 'z')
elif coords == 2:
coords = ('x', 'y')
effects = [patheffects.withStroke(linewidth=2, foreground='k')]
kwargs.setdefault('fc', 'w')
kwargs.setdefault('ec', 'k')
kwargs.setdefault('head_width', 0.6)
kwargs.setdefault('head_length', 0.7)
kwargs.setdefault('linewidth', 1)
kwargs.setdefault('width', 0.2)
if len(coords) == 3:
ins_ax.arrow(x=0.5, y=0.5, dx=-1.05, dy=-0.75, clip_on=False, **kwargs)
ins_ax.arrow(x=0.5, y=0.5, dx=0.96, dy=-0.75, clip_on=False, **kwargs)
ins_ax.arrow(x=0.5, y=0.5, dx=0, dy=1.35, clip_on=False, **kwargs)
ins_ax.annotate(coords[0], xy=(0, 0), xytext=(-1.0, 0.3), path_effects=effects, color='w')
ins_ax.annotate(coords[1], xy=(0, 0), xytext=(1.7, 0.3), path_effects=effects, color='w')
ins_ax.annotate(coords[2], xy=(0, 0), xytext=(0.8, 1.5), path_effects=effects, color='w')
ins_ax.add_artist(Circle((0.5, 0.5), 0.12, fc='w', ec='k', linewidth=1, clip_on=False))
elif len(coords) == 2:
ins_ax.arrow(x=-0.5, y=-0.5, dx=1.5, dy=0, clip_on=False, **kwargs)
ins_ax.arrow(x=-0.5, y=-0.5, dx=0, dy=1.5, clip_on=False, **kwargs)
ins_ax.annotate(coords[0], xy=(0, 0), xytext=(1.3, -0.05), path_effects=effects, color='w')
ins_ax.annotate(coords[1], xy=(0, 0), xytext=(-0.1, 1.1), path_effects=effects, color='w')
ins_ax.add_artist(Circle((-0.5, -0.5), 0.12, fc='w', ec='k', linewidth=1, clip_on=False))
ins_ax.axis('off')
plt.sca(axis)
return coords
def colorbar(im, fig=None, cbar_axis=None, axes=None, position='right', pad=0.02, thickness=0.03, label=None,
constrain_ticklabels=True, ticks=None, ticklabels=None):
"""Creates a colorbar, aligned with figure axes.
Parameters
----------
im : matplotlib object, mappable
Mappable matplotlib object.
fig : matplotlib.figure object, optional
The figure object that contains the matplotlib axes and artists, by default None, which will pick the last used
figure via `gcf`.
axes : matplotlib.axes or list of matplotlib.axes
The axes object(s), where the colorbar is drawn, by default None, which will pick the last used axis via `gca`.
Only provide those axes, which the colorbar should span over.
position : str, optional
The position defines the location of the colorbar. One of 'top', 'bottom', 'left' or 'right' (default).
pad : float, optional
Defines the spacing between the axes and colorbar axis. Is given in figure fraction.
thickness : float, optional
Thickness of the colorbar given in figure fraction.
label : string, optional
Colorbar label, defaults to None.
constrain_ticklabels : bool, optional
Allows to slightly shift the outermost ticklabels, such that they do not exceed the cbar axis, defaults to True.
ticks : list, np.ndarray, optional
List of cbar ticks, defaults to None.
ticklabels : list, np.ndarray, optional
List of cbar ticklabels, defaults to None.
Returns
-------
cbar : :class:`~matplotlib.Colorbar`
The created colorbar.
Notes
-----
Based on a modified snippet by Florian Winkler. Note that this function TURNS OFF constrained layout, therefore it
should be the final command before finishing or saving a figure. The colorbar will be outside the original bounds
of your constructed figure. If you set the size, e.g. with `~empyre.vis.tools.new`, make sure to account for the
additional space by setting the `width_scale` to something smaller than 1 (e.g. 0.9).
"""
_log.debug('Calling colorbar')
assert position in ('left', 'right', 'top', 'bottom'), "position has to be 'left', 'right', 'top' or 'bottom'!"
if fig is None: # If no figure is set, find the current or create a new one:
fig = plt.gcf()
fig.canvas.draw() # Trigger a draw so that a potential constrained_layout is executed once!
fig.set_constrained_layout(False) # we don't want the layout to change after this point!
if axes is None: # If no axis is set, find the current or create a new one:
axes = plt.gca()
if not isinstance(axes, Iterable):
axes = (axes,) # Make sure it is an iterable (e.g. tuple)!
# Save previously active axis for later:
previous_axis = plt.gca()
if cbar_axis is None: # Construct a new cbar_axis:
x_coords, y_coords = [], []
# Find bounds of all individual axes:
for ax in np.ravel(axes): # ravel needed for arrays of axes:
points = ax.get_position().get_points()
x_coords.extend(points[:, 0])
y_coords.extend(points[:, 1])
# Find outer bounds of plotting area:
left, right = min(x_coords), max(x_coords)
bottom, top = min(y_coords), max(y_coords)
# Determine where the colorbar will be placed:
if position == 'right':
bounds = [right+pad, bottom, thickness, top-bottom]
elif position == 'left':
bounds = [left-pad-thickness, bottom, thickness, top-bottom]
if position == 'top':
bounds = [left, top+pad, right-left, thickness]
elif position == 'bottom':
bounds = [left, bottom-pad-thickness, right-left, thickness]
cbar_axis = fig.add_axes(bounds)
# Create the colorbar:
with use_style('empyre-image'):
if position in ('left', 'right'):
cb = plt.colorbar(im, cax=cbar_axis, orientation='vertical')
cb.ax.yaxis.set_ticks_position(position)
cb.ax.yaxis.set_label_position(position)
elif position in ('top', 'bottom'):
cb = plt.colorbar(im, cax=cbar_axis, orientation='horizontal')
cb.ax.xaxis.set_ticks_position(position)
cb.ax.xaxis.set_label_position(position)
# Colorbar label
if label is not None:
cb.set_label(f'{label}')
# Set ticks and ticklabels (if specified):
if ticks:
cb.set_ticks(ticks)
if ticklabels:
cb.set_ticklabels(ticklabels)
# Constrain tick labels (if wanted):
if constrain_ticklabels:
if position == 'top' or position == 'bottom':
t = cb.ax.get_xticklabels()
t[0].set_horizontalalignment('left')
t[-1].set_horizontalalignment('right')
elif position == 'left' or position == 'right':
t = cb.ax.get_yticklabels()
t[0].set_verticalalignment('bottom')
t[-1].set_verticalalignment('top')
# Set focus back from colorbar to previous axis and return colorbar:
plt.sca(previous_axis)
return cb
### MATPLOTLIB STYLESHEET FOR EMPYRE IMAGES
font.family : serif ## default font family (use serifs)
font.serif : cm ## Computer Modern (LaTeX font)
xtick.top : False ## draw ticks on the top side
xtick.bottom : False ## draw ticks on the bottom side
xtick.labeltop : False ## draw label on the top
xtick.labelbottom : False ## draw label on the bottom
ytick.left : False ## draw ticks on the left side
ytick.right : False ## draw ticks on the right side
ytick.labelleft : False ## draw tick labels on the left side
ytick.labelright : False ## draw tick labels on the right side
figure.figsize : 3.6, 3.6 ## figure size in inches
figure.dpi : 200 ## figure dots per inch
figure.constrained_layout.use : True ## use constrained layout
image.origin : lower ## lower | upper
### MATPLOTLIB STYLESHEET FOR EMPYRE PLOTS
font.family : serif ## default font family (use serifs)
font.serif : cm ## Computer Modern (LaTeX font)
figure.figsize : 3.6, 2.2 ## figure size in inches
figure.dpi : 200 ## figure dots per inch
figure.constrained_layout.use : True ## use constrained layout
### MATPLOTLIB STYLESHEET FOR SAVING EMPYRE IMAGES AND PLOTS
font.family : serif ## default font family (use serifs)
font.serif : cm ## Computer Modern (LaTeX font)
savefig.dpi : 200 ## figure dots per inch or 'figure'
savefig.facecolor : white ## figure facecolor when saving
savefig.edgecolor : white ## figure edgecolor when saving
savefig.format : pdf ## png, ps, pdf, svg
savefig.bbox : tight ## 'tight' or 'standard'.
savefig.pad_inches : 0.01 ## Padding to be used when bbox is set to 'tight'
savefig.jpeg_quality : 95 ## when a jpeg is saved, the default quality parameter.
savefig.directory : ~ ## default directory in savefig dialog box, leave empty to always use cwd
savefig.transparent : False ## controls whether figures are saved with transp. background by default
savefig.orientation : portrait ## Orientation of saved figure
# -*- coding: utf-8 -*-
# Copyright 2019 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides functions for 2D plots that often wrap functions from `maptlotlib.pyplot`."""
import logging
import warnings
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from PIL import Image
from . import colors
from .tools import use_style
from ..fields.field import Field
__all__ = ['imshow', 'contour', 'colorvec', 'cosine_contours', 'quiver']
_log = logging.getLogger(__name__)
DIVERGING_CMAPS = ['PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn',
'Spectral', 'coolwarm', 'bwr', 'seismic', # all divergent maps from matplotlib!
'balance', 'delta', 'curl', 'diff', 'tarn'] # all divergent maps from cmocean!
# TODO: add seaborn and more?
def imshow(field, axis=None, cmap=None, **kwargs):
"""Display an image on a 2D regular raster. Wrapper for `matplotlib.pyplot.imshow`.
Parameters
----------
field : `Field` or ndarray
The image data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last use axis via `gca`.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the display, either as a string or object, by default None, which will pick
`cmocean.cm.balance` if available. `imshow` will automatically detect if a divergent colormap is used and will
make sure that zero is pinned to the symmetry point of the colormap (this is done by creating a new colormap
with custom range under the hood).
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to :meth:`~matplotlib.pyplot.imshow`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
"""
_log.debug('Calling imshow')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!'
# Determine colormap and related important properties and flags:
if cmap is None:
try:
import cmocean
cmap = cmocean.cm.balance
except ImportError:
_log.info('cmocean.balance not found, fallback to rRdBu!')
cmap = plt.get_cmap('RdBu_r') # '_r' for reverse!
elif isinstance(cmap, str): # make sure we have a Colormap object (and not a string):
cmap = plt.get_cmap(cmap)
if cmap.name.replace('_r', '') in DIVERGING_CMAPS: # 'replace' also matches reverted cmaps!
kwargs.setdefault('norm', TwoSlopeNorm(0)) # Diverging colormap should have zero at the symmetry point!
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely):
dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
return axis.imshow(squeezed_field, cmap=cmap, **kwargs)
def contour(field, axis=None, **kwargs):
"""Plot contours. Wrapper for `matplotlib.pyplot.contour`.
Parameters
----------
field : `Field` or ndarray
The contour data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the contour should be added, by default None, which will pick the last use axis via `gca`.
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.contour`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
"""
_log.debug('Calling contour')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!'
# Create coordinates (respecting the field scale, +0.5: pixel center!):
vv, uu = (np.indices(squeezed_field.dim) + 0.5) * np.asarray(squeezed_field.scale)[:, None, None]
# Set kwargs defaults without overriding possible user input:
kwargs.setdefault('levels', [0.5])
kwargs.setdefault('colors', 'k')
kwargs.setdefault('linestyles', 'dotted')
kwargs.setdefault('linewidths', 2)
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
axis.set_aspect('equal')
return axis.contour(uu, vv, squeezed_field.data, **kwargs)
def colorvec(field, axis=None, **kwargs):
"""Plot an image of a 2D vector field with up to 3 components by color encoding the vector direction.
In-plane directions are encoded via hue ("color wheel"), making sure that all in-plane directions are isoluminant
(i.e. a greyscale image would result a homogeneously medium grey image). Out-of-plane directions are encoded via
brightness with upwards pointing vectors being white and downward pointing vectors being black. The length of the
vectors are encoded via saturation, with full saturation being fully chromatic (in-plane) or fully white/black
(up/down). The center of the "color sphere" desaturated in a medium gray and encodes vectors with length zero.
Parameters
----------
field : `Field` or ndarray
The image data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last use axis via `gca`.
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.imshow`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
Even though squeezing takes place, `colorvec` "remembers" the original orientation of the slice! This is important
if you want to plot a slice that should not represent the xy-plane. The colors chosen will respect the original
orientation of your slice, e.g. a vortex in the xz-plane will include black and white colors (up/down) if the
`Field` object given as the `field` parameter has `dim=(128, 1, 128)`. If you want to plot a slice of a 3D vector
with 3 components and make use of this functionality, make sure to not use an integer as an index, as that will
drop the dimension BEFORE it is passed to `colorvec`, which will have no way of knowing which dimension was dropped.
Instead, make sure to use a slice of length one (example with `dim=(128, 128, 128)`):
>>> colorvec(field[:, 15, :]) # Wrong! Shape: (128, 128), interpreted as xy-plane!
>>> colorvec(field[:, 15:16, :]) # Right! Shape: (128, 1, 128), passed as 3D to `colorvec`, squeezed internally!
"""
_log.debug('Calling colorvec')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=True)
assert field.vector, 'Can only plot vector fields!'
assert len(field.dim) <= 3, 'Unusable for vector fields with dimension higher than 3!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!'
# Extract vector components (fill 3rd component with zeros if field.comp is only 2):
comp = squeezed_field.comp
x_comp = comp[0]
y_comp = comp[1]
z_comp = comp[2] if (squeezed_field.ncomp == 3) else np.zeros(squeezed_field.dim)
# Calculate image with color encoded directions:
cmap = kwargs.pop('cmap', None)
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(np.stack((x_comp, y_comp, z_comp), axis=0))
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely):
dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
return axis.imshow(Image.fromarray(rgb), **kwargs)
def cosine_contours(field, axis=None, gain='auto', cmap=None, **kwargs):
"""Plots the cosine of the (amplified) field. Wrapper for `matplotlib.pyplot.imshow`.
Parameters
----------
field : `Field` or ndarray
The contour data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the contour should be added, by default None, which will pick the last use axis via `gca`.
gain : float or 'auto', optional
Gain factor with which the `Field` is amplified before taking the cosine, by default 'auto', which calculates a
gain factor that would produce roughly 4 cosine contours.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the display, either as a string or object, by default None, which will pick
`colors.cmaps['transparent_black']` that will alternate between regions with alpha=0, showing layers below and
black contours.
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.imshow`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
"""
_log.debug('Calling cosine_contours')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions (Squeezing did not help)!'
# Determine colormap and related important properties and flags:
if cmap is None:
cmap = colors.cmaps['transparent_black']
# Calculate gain if 'auto' is selected:
if gain == 'auto':
gain = 4 * 2*np.pi / (squeezed_field.amp.data.max() + 1E-30) # 4: roughly 4 contours!
gain = round(gain, -int(np.floor(np.log10(abs(gain))))) # Round to last significant digit!
_log.info(f'Automatically calculated a gain of: {gain}')
# Calculate the contours:
contours = np.cos(gain * squeezed_field) # Range: [-1, 1]
contours += 1 # Shift to positive values
contours /= 2 # Rescale to [0, 1]
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely):
dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
return axis.imshow(contours, cmap=cmap, **kwargs)
def quiver(field, axis=None, color_angles=False, cmap=None, n_bin='auto', bin_with_mask=True, **kwargs):
"""Plot a 2D field of arrows. Wrapper for `matplotlib.pyplot.imshow`.
Parameters
----------
field : `Field` or ndarray
The vector data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last use axis via `gca`.
color_angles : bool, optional
Switch that turns on color encoding of the arrows, by default False. Encoding works the same as for the
`colorvec` function (see for details). If False, arrows are uniformly colored white with black border. In both
cases, the amplitude is encoded via the transparency of the arrow.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the arrows, either as a string or object, by default None. Will only be
used if `color_angles=True`.
n_bin : float or 'auto', optional
Number of entries along each axis over which the average is taken, by default 'auto', which automatically
determines a bin size resulting in roughly 16 arrows along the largest dimension. Usually sensible to leave
this on to not clutter the image with too many arrows (also due to performance). Can be turned off by setting
`n_bin=1`. Uses the `..fields.field.Field.bin` method.
bin_with_mask : bool, optional
If True (default) and if `n_bin>1`, entries of the constructed binned `Field` that averaged over regions that
were outside the `..fields.field.Field.mask` will not be assigned an arrow and stay empty instead. This prevents
errouneous "fade-out" effects of the arrows that would occur even for homogeneous objects.
Returns
-------
quiv : Quiver instance
The quiver instance that was created.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.quiver`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
Even though squeezing takes place, `quiver` "remembers" the original orientation of the slice and which dimensions
were squeezed! See `colorvec` for more information and an example (the same principles apply here, too).
The transparency of the arrows denotes the 3D(!) amplitude, if you see dots in the plot, that means the amplitude
is not zero, but simply out of the current plane!
"""
_log.debug('Calling quiver')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=True)
assert field.vector, 'Can only plot vector fields!'
assert len(field.dim) <= 3, 'Unusable for vector fields with dimension higher than 3!'
if len(field.dim) < field.ncomp:
warnings.warn('Assignment of vector components to dimensions is ambiguous!'
f'`ncomp` ({field.ncomp}) should match `len(dim)` ({len(field.dim)})!'
'If you want to plot a slice of a 3D volume, make sure to use `from:to` notation!')
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions (Squeezing did not help)!'
# Determine binning size if necessary:
if n_bin == 'auto':
n_bin = int(np.max((1, np.max(squeezed_field.dim) / 16)))
# Save old limits in case binning has to use padding:
u_lim = squeezed_field.dim[1] * squeezed_field.scale[1]
v_lim = squeezed_field.dim[0] * squeezed_field.scale[0]
# Bin if necessary:
if n_bin > 1:
field_mask = squeezed_field.mask # Get mask BEFORE binning!
squeezed_field = squeezed_field.bin(n_bin)
if bin_with_mask: # Excludes regions where in and outside are binned together!
mask = (field_mask.bin(n_bin) == 1)
squeezed_field *= mask
# Extract normalized vector components (fill 3rd component with zeros if field.comp is only 2):
normalised_comp = (squeezed_field / squeezed_field.amp.data.max()).comp
amplitude = squeezed_field.amp.data / squeezed_field.amp.data.max()
x_comp = normalised_comp[0].data
y_comp = normalised_comp[1].data
z_comp = normalised_comp[2].data if (field.ncomp == 3) else np.zeros(squeezed_field.dim)
# Create coordinates (respecting the field scale, +0.5: pixel center!):
vv, uu = (np.indices(squeezed_field.dim) + 0.5) * np.asarray(squeezed_field.scale)[:, None, None]
# Calculate the arrow colors:
if color_angles: # Color angles according to calculated RGB values (only with circular colormaps):
_log.debug('Encoding angles')
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(np.asarray((x_comp, y_comp, z_comp))) / 255
rgba = np.concatenate((rgb, amplitude[..., None]), axis=-1)
kwargs.setdefault('color', rgba.reshape(-1, 4))
else: # Color amplitude with numeric values, according to cmap, overrides 'color':
_log.debug('Encoding amplitudes')
if cmap is None:
cmap = colors.cmaps['transparent_white']
C = amplitude # Numeric values, used with cmap!
# Check which (if any) indices were squeezed to find out which components are passed to quiver: # TODO: TEST!!!
squeezed_indices = np.flatnonzero(np.asarray(field.dim) == 1)
if not squeezed_indices: # Separate check, because in this case squeezed_indices == []:
u_comp = x_comp
v_comp = y_comp
elif squeezed_indices[0] == 0: # Slice of the xy-plane with z squeezed:
u_comp = x_comp
v_comp = y_comp
elif squeezed_indices[0] == 1: # Slice of the xz-plane with y squeezed:
u_comp = x_comp
v_comp = z_comp
elif squeezed_indices[0] == 2: # Slice of the zy-plane with x squeezed:
u_comp = y_comp
v_comp = z_comp
# Set specific defaults for quiver kwargs:
kwargs.setdefault('edgecolor', colors.cmaps['transparent_black'](amplitude).reshape(-1, 4))
kwargs.setdefault('scale', 1/np.max(squeezed_field.scale))
kwargs.setdefault('width', np.max(squeezed_field.scale))
kwargs.setdefault('clim', (0, 1))
kwargs.setdefault('pivot', 'middle')
kwargs.setdefault('units', 'xy')
kwargs.setdefault('scale_units', 'xy')
kwargs.setdefault('minlength', 0.05)
kwargs.setdefault('headlength', 2)
kwargs.setdefault('headaxislength', 2)
kwargs.setdefault('headwidth', 2)
kwargs.setdefault('minshaft', 2)
kwargs.setdefault('linewidths', 1)
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
axis.set_xlim(0, u_lim)
axis.set_ylim(0, v_lim)
axis.set_aspect('equal')
if color_angles:
return axis.quiver(uu, vv, np.asarray(u_comp), np.asarray(v_comp), cmap=cmap, **kwargs)
else:
return axis.quiver(uu, vv, np.asarray(u_comp), np.asarray(v_comp), C, cmap=cmap, **kwargs)
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides functions for 3D plots based on the `mayavi` library."""
import logging
import numpy as np
from . import colors
__all__ = ['contour3d', 'mask3d', 'quiver3d']
_log = logging.getLogger(__name__)
# TODO: Docstrings and signature!
def contour3d(field, title='Field Distribution', contours=10, opacity=0.25, size=None, new_fig=True, **kwargs):
"""Plot a field as a 3D-contour plot.
Parameters
----------
title: string, optional
The title for the plot.
contours: int, optional
Number of contours which should be plotted.
opacity: float, optional
Defines the opacity of the contours. Default is 0.25.
Returns
-------
plot : :class:`mayavi.modules.vectors.Vectors`
The plot object.
"""
_log.debug('Calling contour3d')
try:
from mayavi import mlab
except ImportError:
_log.error('This extension recquires the mayavi package!')
return
if size is None:
size = (750, 700)
if new_fig:
mlab.figure(size=size, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.))
zzz, yyy, xxx = np.indices(field.dim) + np.reshape(field.scale, (3, 1, 1, 1)) / 2 # shifted by half of scale!
zzz, yyy, xxx = zzz.T, yyy.T, xxx.T # Transpose because of VTK order!
field_amp = field.amp.data.T # Transpose because of VTK order!
if not isinstance(contours, (list, tuple, np.ndarray)): # Calculate the contours:
contours = list(np.linspace(field_amp.min(), field_amp.max(), contours))
extent = np.ravel(list(zip((0, 0, 0), field_amp.shape)))
cont = mlab.contour3d(xxx, yyy, zzz, field_amp, contours=contours, opacity=opacity, **kwargs)
mlab.outline(cont, extent=extent)
mlab.axes(cont, extent=extent)
mlab.title(title, height=0.95, size=0.35)
mlab.orientation_axes()
cont.scene.isometric_view()
return cont
def mask3d(field, title='Mask', threshold=0, grid=True, labels=True,
orientation=True, size=None, new_fig=True, **kwargs):
"""Plot the mask as a 3D-contour plot.
Parameters
----------
title: string, optional
The title for the plot.
threshold : float, optional
A pixel only gets masked, if it lies above this threshold . The default is 0.
Returns
-------
plot : :class:`mayavi.modules.vectors.Vectors`
The plot object.
"""
_log.debug('Calling mask3d')
try:
from mayavi import mlab
except ImportError:
_log.error('This extension recquires the mayavi package!')
return
if size is None:
size = (750, 700)
if new_fig:
mlab.figure(size=size, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.))
zzz, yyy, xxx = np.indices(field.dim) + np.reshape(field.scale, (3, 1, 1, 1)) / 2 # shifted by half of scale!
zzz, yyy, xxx = zzz.T, yyy.T, xxx.T # Transpose because of VTK order!
mask = field.mask.data.T.astype(int) # Transpose because of VTK order!
extent = np.ravel(list(zip((0, 0, 0), mask.shape)))
cont = mlab.contour3d(xxx, yyy, zzz, mask, contours=[1], **kwargs)
if grid:
mlab.outline(cont, extent=extent)
if labels:
mlab.axes(cont, extent=extent)
mlab.title(title, height=0.95, size=0.35)
if orientation:
oa = mlab.orientation_axes()
oa.marker.set_viewport(0, 0, 0.4, 0.4)
mlab.draw()
engine = mlab.get_engine()
scene = engine.scenes[0]
scene.scene.isometric_view()
return cont
def quiver3d(field, title='Vector Field', limit=None, cmap=None, mode='2darrow',
coloring='angle', ar_dens=1, opacity=1.0, grid=True, labels=True,
orientation=True, size=(700, 750), new_fig=True, view='isometric',
position=None, bgcolor=(0.5, 0.5, 0.5)):
"""Plot the vector field as 3D-vectors in a quiverplot.
Parameters
----------
title : string, optional
The title for the plot.
limit : float, optional
Plotlimit for the vector field arrow length used to scale the colormap.
cmap : string, optional
String describing the colormap which is used for color encoding (uses `~.colors.cmaps.cyclic_cubehelix` if
left on the `None` default) or amplitude encoding (uses 'jet' if left on the `None` default).
ar_dens: int, optional
Number defining the arrow density which is plotted. A higher ar_dens number skips more
arrows (a number of 2 plots every second arrow). Default is 1.
mode: string, optional
Mode, determining the glyphs used in the 3D plot. Default is '2darrow', which
corresponds to 2D arrows. For smaller amounts of arrows, 'arrow' (3D) is prettier.
coloring : {'angle', 'amplitude'}, optional
Color coding mode of the arrows. Use 'angle' (default) or 'amplitude'.
opacity: float, optional
Defines the opacity of the arrows. Default is 1.0 (completely opaque).
Returns
-------
plot : :class:`mayavi.modules.vectors.Vectors`
The plot object.
"""
_log.debug('Calling quiver_plot3D')
try:
from mayavi import mlab
except ImportError:
_log.error('This extension recquires the mayavi package!')
return
if limit is None:
limit = np.max(np.nan_to_num(field.amp))
ad = ar_dens
# Create points and vector components as lists:
zzz, yyy, xxx = (np.indices(field.dim) + 1 / 2)
zzz = zzz[::ad, ::ad, ::ad].ravel()
yyy = yyy[::ad, ::ad, ::ad].ravel()
xxx = xxx[::ad, ::ad, ::ad].ravel()
x_mag = field.data[::ad, ::ad, ::ad, 0].ravel()
y_mag = field.data[::ad, ::ad, ::ad, 1].ravel()
z_mag = field.data[::ad, ::ad, ::ad, 2].ravel()
# Plot them as vectors:
if new_fig:
mlab.figure(size=size, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.))
if coloring == 'angle': # Encodes the full angle via colorwheel and saturation:
_log.debug('Encoding full 3D angles')
vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag, mode=mode, opacity=opacity,
scalars=np.arange(len(xxx)), line_width=2)
vector = np.asarray((x_mag.ravel(), y_mag.ravel(), z_mag.ravel()))
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(vector)
rgba = np.hstack((rgb, 255 * np.ones((len(xxx), 1), dtype=np.uint8)))
vecs.glyph.color_mode = 'color_by_scalar'
vecs.module_manager.scalar_lut_manager.lut.table = rgba
mlab.draw()
elif coloring == 'amplitude': # Encodes the amplitude of the arrows with the jet colormap:
_log.debug('Encoding amplitude')
if cmap is None:
cmap = 'jet'
vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag,
mode=mode, colormap=cmap, opacity=opacity, line_width=2)
mlab.colorbar(label_fmt='%.2f')
mlab.colorbar(orientation='vertical')
else:
raise AttributeError('Coloring mode not supported!')
vecs.glyph.glyph_source.glyph_position = 'center'
vecs.module_manager.vector_lut_manager.data_range = np.array([0, limit])
extent = np.ravel(list(zip((0, 0, 0), (field.dim[2], field.dim[1], field.dim[0]))))
if grid:
mlab.outline(vecs, extent=extent)
if labels:
mlab.axes(vecs, extent=extent)
mlab.title(title, height=0.95, size=0.35)
if orientation:
oa = mlab.orientation_axes()
oa.marker.set_viewport(0, 0, 0.4, 0.4)
mlab.draw()
engine = mlab.get_engine()
scene = engine.scenes[0]
if view == 'isometric':
scene.scene.isometric_view()
elif view == 'x_plus_view':
scene.scene.x_plus_view()
elif view == 'y_plus_view':
scene.scene.y_plus_view()
if position:
scene.scene.camera.position = position
return vecs
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides helper functions to the vis module."""
import os
import glob
import shutil
import logging
from numbers import Number
from contextlib import contextmanager
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from ..fields.field import Field
__all__ = ['new', 'savefig', 'calc_figsize', 'use_style', 'copy_mpl_stylesheets']
_log = logging.getLogger(__name__)
def new(nrows=1, ncols=1, mode='image', figsize=None, textwidth=None, width_scale=1.0, aspect=None, **kwargs):
R"""Convenience function for the creation of a new subplot grid (wraps `~matplotlib.pyplot.subplots`).
If you use the `textwidth` parameter, plot sizes are fitting into publications with LaTeX. Requires two stylesheets
`empyre-image` and `empyre-plot` corresponding to its two `mode` settings. Those stylesheets use
`constrained_layout=True` to achieve well behaving plots without much whitespace around. This function should work
fine for a small number of images (e.g. 1, 2x2, etc.), for more fine grained control, the contexts can be used
directly if they are installed corretly, or use width_scale to build the images separately (e.g. 2 adjacent with
width=0.5). For images, it is assumed that most images are square (and therefore `aspect=1`).
Parameters
----------
nrows : int, optional
Number of rows of the subplot grid, by default 1
ncols : int, optional
Number of columns of the subplot grid, by default 1
mode : {'image', 'plot'}, optional
Mode of the new subplot grid, by default 'image'. Both modes have dedicated matplotlib styles which are used
and which are installed together with EMPyRe. The 'image' mode disables axis labels and ticks, mainly intended
to be used with `~matplotlib.pyplot.imshow` with `~empyre.vis.decorators.scalebar`, while the 'plot'
mode should be used for traditional plots like with `~matplotlib.pyplot.plot` or `~matplotlib.pyplot.scatter`.
figsize : (float, float), optional
Width and height of the figure in inches, defaults to rcParams["figure.figsize"], which depends on the chosen
stylesheet. If set, this will overwrite all other following parameters.
textwidth : float, optional
The textwidth of your LaTeX document in points, which you can get by using :math:`\the\textwidth`. If this is
not None (the default), this will be used to define the figure size if it is not set explicitely.
width_scale : float, optional
Only meaningful if `textwidth` is set. If it is, `width_scale` will be a scaling factor for the figure width.
Example: if you set this to 0.5, your figure will span half of the textwidth. Default is 1.
aspect : float, optional
Aspect ratio of the figure height relative to the figure width. If None (default), the aspect is set to be 1
for `mode=image` and to 'golden' for `mode=plot`, which adjusts the aspect to represent the golden ratio of
0.6180... If `ncols!=nrows`, it often makes sense to use `aspect=nrows/ncols` here.
Returns
-------
fig : :class:`~matplotlib.figure.Figure`
The constructed figure.
axes : axes.Axes object or array of Axes objects.
axes can be either a single Axes object or an array of Axes objects if more than one subplot was created.
The dimensions of the resulting array can be controlled with the squeeze keyword argument.
Notes
-----
additional kwargs are passed to `~matplotlib.pyplot.subplots`.
"""
_log.debug('Calling new')
assert mode in ('image', 'plot'), "mode has to be 'image', or 'plot'!"
with use_style(f'empyre-{mode}'):
if figsize is None:
if aspect is None:
aspect = 'golden' if mode == 'plot' else 1 # Both image modes have 'same' as default'!
elif isinstance(aspect, Field):
dim_uv = [d for d in aspect.dim if d != 1]
assert len(dim_uv) == 2, f"Couldn't find field aspect ({len(dim_uv)} squeezed dimensions, has to be 2)!"
aspect = dim_uv[0]/dim_uv[1] # height/width
else:
assert isinstance(aspect, Number), 'aspect has to be None, a number or field instance squeezable to 2D!'
figsize = calc_figsize(textwidth=textwidth, width_scale=width_scale, aspect=aspect)
return plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, **kwargs)
def savefig(fname, **kwargs):
"""Utility wrapper around :func:`~matplotlib.pyplot.savefig` to save the current figure.
Parameters
----------
fname : str or PathLike or file-like object
Path to the file wherein the figure should be saved.
Notes
-----
Uses the 'empyre-save' stylesheet (installed together with EMPyRe to control the saving behaviour. Any kwargs are
passed to :func:`~matplotlib.pyplot.savefig`.
"""
_log.debug('Calling savefig')
with use_style('empyre-save'):
plt.savefig(fname, **kwargs)
def calc_figsize(textwidth=None, width_scale=1.0, aspect=1):
R"""Helper function to calculate the figure size from various parameters. Useful for publications via LaTeX.
Parameters
----------
textwidth : float, optional
The textwidth of your LaTeX document in points, which you can get by using :math:`\the\textwidth`. If this is
None (default), the standard width in inches from the current stylesheet is used.
width_scale : float, optional
Scaling factor for the figure width. Example: if you set this to 0.5, your figure will span half of the
textwidth. Default is 1.
aspect : float, optional
Aspect ratio of the figure height relative to the figure width. If None (default), the aspect is set to be 1
for `mode=image` and to 'golden' for `mode=plot`, which adjusts the aspect to represent the golden ratio of
0.6180...
Returns
-------
figsize: (float, float)
The determined figure size
Notes
-----
Based on snippet by Florian Winkler.
"""
_log.debug('Calling calc_figsize')
GOLDEN_RATIO = (1 + np.sqrt(5)) / 2 # Aesthetic ratio!
INCHES_PER_POINT = 1.0 / 72.27 # Convert points to inch, LaTeX constant, apparently...
if textwidth is not None:
textwidth_in = textwidth * INCHES_PER_POINT # Width of the text in inches
else: # If textwidth is not given, use the default from rcParams:
textwidth_in = mpl.rcParams["figure.figsize"][0]
fig_width = textwidth_in * width_scale # Width in inches
if aspect == 'golden':
fig_height = fig_width / GOLDEN_RATIO
elif isinstance(aspect, Number):
fig_height = textwidth_in * aspect
else:
raise ValueError(f"aspect has to be either a number, or 'golden'! Was {aspect}!")
fig_size = [fig_width, fig_height] # Both in inches
return fig_size
@contextmanager
def use_style(stylename):
"""Context that uses a matplotlib stylesheet. Can fall back to local mpl stylesheets if necessary!
Parameters
----------
stylename : str
A style specification.
Yields
-------
context
Context manager for using style settings temporarily.
"""
try: # Try to load the style directly (works if it is installed somewhere mpl looks for it):
with plt.style.context(stylename) as context:
yield context
except OSError: # Stylesheet not found, use local ones:
mplstyle_path = os.path.join(os.path.dirname(__file__), 'mplstyles', f'{stylename}.mplstyle')
with plt.style.context(mplstyle_path) as context:
yield context
def copy_mpl_stylesheets():
"""Copy matplotlib styles to the users matplotlib config directory. Useful if you want to utilize them elsewhere.
Notes
-----
You might need to restart your Python session for the stylesheets to be recognized/found!
"""
# Find matplotlib styles:
user_stylelib_path = os.path.join(mpl.get_configdir(), 'stylelib')
vis_dir = os.path.dirname(__file__)
style_files = glob.glob(os.path.join(vis_dir, 'mplstyles', '*.mplstyle'))
# Copy them to the local matplotlib styles folder:
if not os.path.exists(user_stylelib_path):
os.makedirs(user_stylelib_path)
for style_path in style_files:
_, fname = os.path.split(style_path)
dest = os.path.join(user_stylelib_path, fname)
shutil.copy(style_path, dest)
import os
import pytest
import numpy as np
from empyre.fields import Field
@pytest.fixture
def fielddata_path():
return os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_fielddata')
@pytest.fixture
def vector_data():
magnitude = np.zeros((4, 4, 4, 3))
magnitude[1:-1, 1:-1, 1:-1] = 1
return Field(magnitude, 10.0, vector=True)
@pytest.fixture
def vector_data_asymm():
shape = (5, 7, 11, 3)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=True)
@pytest.fixture
def vector_data_asymm_2d():
shape = (5, 7, 2)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=True)
@pytest.fixture
def vector_data_asymmcube():
shape = (3, 3, 3, 3)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=True)
@pytest.fixture
def scalar_data():
magnitude = np.zeros((4, 4, 4))
magnitude[1:-1, 1:-1, 1:-1] = 1
return Field(magnitude, 10.0, vector=False)
@pytest.fixture
def scalar_data_asymm():
shape = (5, 7, 2)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=False)
# -*- coding: utf-8 -*-
"""Testcase for the magdata module."""
import pytest
from numbers import Number
import numpy as np
import numpy.testing
from empyre.fields import Field
from utils import assert_allclose
def test_copy(vector_data):
vector_data = vector_data.copy()
# Make sure it is a new object
assert vector_data != vector_data, 'Unexpected behaviour in copy()!'
assert np.allclose(vector_data, vector_data)
def test_bin(vector_data):
binned_data = vector_data.bin(2)
reference = 1 / 8. * np.ones((2, 2, 2, 3))
assert_allclose(binned_data, reference,
err_msg='Unexpected behavior in scale_down()!')
assert_allclose(binned_data.scale, (20, 20, 20),
err_msg='Unexpected behavior in scale_down()!')
def test_zoom(vector_data):
zoomed_test = vector_data.zoom(2, order=0)
reference = np.zeros((8, 8, 8, 3))
reference[2:6, 2:6, 2:6] = 1
assert_allclose(zoomed_test, reference,
err_msg='Unexpected behavior in zoom()!')
assert_allclose(zoomed_test.scale, (5, 5, 5),
err_msg='Unexpected behavior in zoom()!')
@pytest.mark.parametrize(
'mode', [
'constant',
'edge',
'wrap'
]
)
@pytest.mark.parametrize(
'pad_width,np_pad', [
(1, ((1, 1), (1, 1), (1, 1), (0, 0))),
((1, 2, 3), ((1, 1), (2, 2), (3, 3), (0, 0))),
(((1, 2), (3, 4), (5, 6)), ((1, 2), (3, 4), (5, 6), (0, 0)))
]
)
def test_pad(vector_data, mode, pad_width, np_pad):
magdata_test = vector_data.pad(pad_width, mode=mode)
reference = np.pad(vector_data, np_pad, mode=mode)
assert_allclose(magdata_test, reference,
err_msg='Unexpected behavior in pad()!')
@pytest.mark.parametrize(
'axis', [-1, 3]
)
def test_component_reduction(vector_data, axis):
# axis=-1 is supposed to reduce over the component dimension, if it exists. axis=3 should do the same here!
res = np.sum(vector_data, axis=axis)
ref = np.zeros((4, 4, 4))
ref[1:-1, 1:-1, 1:-1] = 3
assert res.shape == ref.shape, 'Shape mismatch!'
assert_allclose(res, ref, err_msg="Unexpected behavior of axis keyword")
assert isinstance(res, Field), 'Result is not a Field object!'
assert not res.vector, 'Result is a vector field, but should be reduced to a scalar!'
@pytest.mark.parametrize(
'axis', [(0, 1, 2), (2, 1, 0), None, (-4, -3, -2)]
)
def test_full_reduction(vector_data, axis):
res = np.sum(vector_data, axis=axis)
ref = np.zeros((3,))
ref[:] = 8
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of full or default reduction")
assert isinstance(res, np.ndarray)
@pytest.mark.parametrize(
'axis', [-1, 2]
)
def test_last_reduction_scalar(scalar_data, axis):
# axis=-1 is supposed to reduce over the component dimension if it exists.
# In this case it doesn't!
res = np.sum(scalar_data, axis=axis)
ref = np.zeros((4, 4))
ref[1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of axis keyword")
assert isinstance(res, Field)
assert not res.vector
@pytest.mark.parametrize(
'axis', [(0, 1, 2), (2, 1, 0), None, (-1, -2, -3)]
)
def test_full_reduction_scalar(scalar_data, axis):
res = np.sum(scalar_data, axis=axis)
ref = 8
assert res.shape == ()
assert_allclose(res, ref, err_msg="Unexpected behavior of full or default reduction")
assert isinstance(res, Number)
def test_binary_operator_vector_number(vector_data):
res = vector_data + 1
ref = np.ones((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
def test_binary_operator_vector_scalar(vector_data, scalar_data):
res = vector_data + scalar_data
ref = np.zeros((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
def test_binary_operator_vector_vector(vector_data):
res = vector_data + vector_data
ref = np.zeros((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
@pytest.mark.xfail
def test_binary_operator_vector_broadcast(vector_data):
# Broadcasting between vector fields is currently not implemented
second = np.zeros((4, 4, 3))
second[1:-1, 1:-1] = 1
second = Field(second, 10.0, vector=True)
res = vector_data + second
ref = np.zeros((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 1
ref[:, 1:-1, 1:-1] += 1
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
def test_mask(vector_data):
mask = vector_data.mask
reference = np.zeros((4, 4, 4))
reference[1:-1, 1:-1, 1:-1] = True
assert_allclose(mask, reference,
err_msg='Unexpected behavior in mask attribute!')
def test_get_vector(vector_data):
mask = vector_data.mask
vector = vector_data.get_vector(mask)
reference = np.ones(np.sum(mask) * 3)
assert_allclose(vector, reference,
err_msg='Unexpected behavior in get_vector()!')
def test_set_vector(vector_data):
mask = vector_data.mask
vector = 2 * np.ones(np.sum(mask) * 3)
vector_data.set_vector(vector, mask)
reference = np.zeros((4, 4, 4, 3))
reference[1:-1, 1:-1, 1:-1] = 2
assert_allclose(vector_data, reference,
err_msg='Unexpected behavior in set_vector()!')
def test_flip(vector_data_asymm):
field_flipz = vector_data_asymm.flip(0)
field_flipy = vector_data_asymm.flip(1)
field_flipx = vector_data_asymm.flip(2)
field_flipxy = vector_data_asymm.flip((1, 2))
field_flipdefault = vector_data_asymm.flip()
field_flipcomp = vector_data_asymm.flip(-1)
assert_allclose(np.flip(vector_data_asymm.data, axis=0) * [1, 1, -1], field_flipz.data,
err_msg='Unexpected behavior in flip()! (z)')
assert_allclose(np.flip(vector_data_asymm.data, axis=1) * [1, -1, 1], field_flipy.data,
err_msg='Unexpected behavior in flip()! (y)')
assert_allclose(np.flip(vector_data_asymm.data, axis=2) * [-1, 1, 1], field_flipx.data,
err_msg='Unexpected behavior in flip()! (x)')
assert_allclose(np.flip(vector_data_asymm.data, axis=(1, 2)) * [-1, -1, 1], field_flipxy.data,
err_msg='Unexpected behavior in flip()! (xy)')
assert_allclose(np.flip(vector_data_asymm.data, axis=(0, 1, 2)) * [-1, -1, -1], field_flipdefault.data,
err_msg='Unexpected behavior in flip()! (default)')
assert_allclose(np.flip(vector_data_asymm.data, axis=-1) * [1, 1, 1], field_flipcomp.data,
err_msg='Unexpected behavior in flip()! (components)')
def test_unknown_num_of_components():
shape = (5, 7, 7)
data = np.linspace(0, 1, np.prod(shape))
with pytest.raises(AssertionError):
Field(data.reshape(shape), 10.0, vector=True)
def test_repr(vector_data_asymm):
string_repr = repr(vector_data_asymm)
data_str = str(vector_data_asymm.data)
string_ref = f'Field(data={data_str}, scale=(10.0, 10.0, 10.0), vector=True)'
print(f'reference: {string_ref}')
print(f'repr output: {string_repr}')
assert string_repr == string_ref, 'Unexpected behavior in __repr__()!'
def test_str(vector_data_asymm):
string_str = str(vector_data_asymm)
string_ref = 'Field(dim=(5, 7, 11), scale=(10.0, 10.0, 10.0), vector=True, ncomp=3)'
print(f'reference: {string_str}')
print(f'str output: {string_str}')
assert string_str == string_ref, 'Unexpected behavior in __str__()!'
@pytest.mark.parametrize(
"index,t,scale", [
((0, 1, 2), tuple, None),
((0, ), Field, (2., 3.)),
(0, Field, (2., 3.)),
((0, 1, 2, 0), float, None),
((0, 1, 2, 0), float, None),
((..., 0), Field, (1., 2., 3.)),
((0, slice(1, 3), 2), Field, (2.,)),
]
)
def test_getitem(vector_data, index, t, scale):
vector_data.scale = (1., 2., 3.)
data_index = index
res = vector_data[index]
assert_allclose(res, vector_data.data[data_index])
assert isinstance(res, t)
if t is Field:
assert res.scale == scale
def test_from_scalar_field(scalar_data):
sca_x, sca_y, sca_z = [i * scalar_data for i in range(1, 4)]
field_comb = Field.from_scalar_fields([sca_x, sca_y, sca_z])
assert field_comb.vector
assert field_comb.scale == scalar_data.scale
assert_allclose(sca_x, field_comb.comp[0])
assert_allclose(sca_y, field_comb.comp[1])
assert_allclose(sca_z, field_comb.comp[2])
def test_squeeze():
magnitude = np.zeros((4, 1, 4, 3))
field = Field(magnitude, (1., 2., 3.), vector=True)
sq = field.squeeze()
assert sq.shape == (4, 4, 3)
assert sq.dim == (4, 4)
assert sq.scale == (1., 3.)
def test_gradient():
pass
def test_gradient_1d():
pass
def test_curl():
pass
def test_curl_2d():
pass
def test_clip_scalar_noop():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=False)
assert_allclose(field, field.clip())
def test_clip_scalar_minmax():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=False)
assert_allclose(np.clip(data, -1, 0.1), field.clip(vmin=-1, vmax=0.1))
def test_clip_scalar_sigma():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
data[0, 0, 0] = 1e6
field = Field(data, (1., 2., 3.), vector=False)
# We clip off the one outlier
assert_allclose(np.clip(data, -2, 1), field.clip(sigma=5))
assert field.clip(sigma=5)[0, 0, 0] == 1
def test_clip_scalar_mask():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
mask = np.zeros(shape, dtype=bool)
mask[0, 0, 0] = True
mask[0, 0, 1] = True
field = Field(data, (1., 2., 3.), vector=False)
assert_allclose(np.clip(data, data[0, 0, 0], data[0, 0, 1]), field.clip(mask=mask))
def test_clip_vector_noop():
shape = (3, 3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=True)
assert_allclose(field, field.clip())
def test_clip_vector_max():
shape = (3, 3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=True)
res = field.clip(vmax=0.1)
assert_allclose(np.max(res.amp), 0.1)
def test_clip_vector_sigma():
shape = (3, 3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
data[0, 0, 0] = (1e6, 1e6, 1e6)
field = Field(data, (1., 2., 3.), vector=True)
# We clip off the one outlier
res = field.clip(sigma=5)
assert np.max(res.amp) < 1e3
# TODO: HyperSpy would need to be installed for the following tests (slow...):
# def test_from_signal()
# raise NotImplementedError()
#
# def test_to_signal()
# raise NotImplementedError()
import pytest
from utils import assert_allclose
from empyre.fields import Field
import numpy as np
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z']
)
def test_rot90_360(vector_data_asymm, axis):
assert_allclose(vector_data_asymm.rot90(axis=axis).rot90(axis=axis).rot90(axis=axis).rot90(axis=axis),
vector_data_asymm,
err_msg=f'Unexpected behavior in rot90()! {axis}')
@pytest.mark.parametrize(
'rot_axis,flip_axes', [
('x', (0, 1)),
('y', (0, 2)),
('z', (1, 2))
]
)
def test_rot90_180(vector_data_asymm, rot_axis, flip_axes):
res = vector_data_asymm.rot90(axis=rot_axis).rot90(axis=rot_axis)
ref = vector_data_asymm.flip(axis=flip_axes)
assert_allclose(res, ref, err_msg=f'Unexpected behavior in rot90()! {rot_axis}')
@pytest.mark.parametrize(
'rot_axis', [
'x',
'y',
'z',
]
)
def test_rotate_compare_rot90_1(vector_data_asymmcube, rot_axis):
res = vector_data_asymmcube.rotate(angle=90, axis=rot_axis)
ref = vector_data_asymmcube.rot90(axis=rot_axis)
print("input", vector_data_asymmcube.data)
print("ref", res.data)
print("res", ref.data)
assert_allclose(res, ref, err_msg=f'Unexpected behavior in rotate()! {rot_axis}')
def test_rot90_manual():
data = np.zeros((3, 3, 3, 3))
diag = np.array((1, 1, 1))
diag_unity = diag / np.sqrt(np.sum(diag**2))
data[0, 0, 0] = diag_unity
data = Field(data, 10, vector=True)
print("data", data.data)
rot90_x = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot90_x[0, 2, 0] = diag_unity * (1, -1, 1)
rot90_x = Field(rot90_x, 10, vector=True)
print("rot90_x", rot90_x.data)
print("data rot90 x", data.rot90(axis='x').data)
rot90_y = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot90_y[2, 0, 0] = diag_unity * (1, 1, -1)
rot90_y = Field(rot90_y, 10, vector=True)
print("rot90_y", rot90_y.data)
print("data rot90 y", data.rot90(axis='y').data)
rot90_z = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot90_z[0, 0, 2] = diag_unity * (-1, 1, 1)
rot90_z = Field(rot90_z, 10, vector=True)
print("rot90_z", rot90_z.data)
print("data rot90 z", data.rot90(axis='z').data)
assert_allclose(rot90_x, data.rot90(axis='x'), err_msg='Unexpected behavior in rot90("x")!')
assert_allclose(rot90_y, data.rot90(axis='y'), err_msg='Unexpected behavior in rot90("y")!')
assert_allclose(rot90_z, data.rot90(axis='z'), err_msg='Unexpected behavior in rot90("z")!')
def test_rot45_manual():
data = np.zeros((3, 3, 3, 3))
data[0, 0, 0] = (1, 1, 1)
data = Field(data, 10, vector=True)
print("data", data.data)
rot45_x = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot45_x[0, 1, 0] = (1, 0, np.sqrt(2))
rot45_x = Field(rot45_x, 10, vector=True)
print("rot45_x", rot45_x.data)
# Disable spline interpolation, use nearest instead
res_rot45_x = data.rotate(45, axis='x', order=0)
print("data rot45 x", res_rot45_x.data)
rot45_y = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot45_y[1, 0, 0] = (np.sqrt(2), 1, 0)
rot45_y = Field(rot45_y, 10, vector=True)
print("rot45_y", rot45_y.data)
# Disable spline interpolation, use nearest instead
res_rot45_y = data.rotate(45, axis='y', order=0)
print("data rot45 y", res_rot45_y.data)
rot45_z = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot45_z[0, 0, 1] = (0, np.sqrt(2), 1)
rot45_z = Field(rot45_z, 10, vector=True)
print("rot45_z", rot45_z.data)
# Disable spline interpolation, use nearest instead
res_rot45_z = data.rotate(45, axis='z', order=0)
print("data rot45 z", res_rot45_z.data)
assert_allclose(rot45_x, res_rot45_x, err_msg='Unexpected behavior in rotate(45, "x")!')
assert_allclose(rot45_y, res_rot45_y, err_msg='Unexpected behavior in rotate(45, "y")!')
assert_allclose(rot45_z, res_rot45_z, err_msg='Unexpected behavior in rotate(45, "z")!')
def test_rot90_2d_360(vector_data_asymm_2d):
assert_allclose(vector_data_asymm_2d.rot90().rot90().rot90().rot90(), vector_data_asymm_2d,
err_msg='Unexpected behavior in 2D rot90()!')
def test_rot90_2d_180(vector_data_asymm_2d):
res = vector_data_asymm_2d.rot90().rot90()
ref = vector_data_asymm_2d.flip()
assert_allclose(res, ref, err_msg='Unexpected behavior in 2D rot90()!')
@pytest.mark.parametrize(
'k', [0, 1, 2, 3, 4]
)
def test_rot90_comp_2d_with_3d(vector_data_asymm_2d, k):
data_x, data_y = [comp.data[np.newaxis, :, :] for comp in vector_data_asymm_2d.comp]
data_z = np.zeros_like(data_x)
data_3d = np.stack([data_x, data_y, data_z], axis=-1)
vector_data_asymm_3d = Field(data_3d, scale=10, vector=True)
print(f'2D shape, scale: {vector_data_asymm_2d.shape, vector_data_asymm_2d.scale}')
print(f'3D shape, scale: {vector_data_asymm_3d.shape, vector_data_asymm_3d.scale}')
vector_data_rot_2d = vector_data_asymm_2d.rot90(k=k)
vector_data_rot_3d = vector_data_asymm_3d.rot90(k=k, axis='z')
print(f'2D shape after rot: {vector_data_rot_2d.shape}')
print(f'3D shape after rot: {vector_data_rot_3d.shape}')
assert_allclose(vector_data_rot_2d, vector_data_rot_3d[0, :, :, :2], err_msg='Unexpected behavior in 2D rot90()!')
@pytest.mark.parametrize(
'angle', [90, 45, 23, 11.5]
)
def test_rotate_comp_2d_with_3d(vector_data_asymm_2d, angle):
data_x, data_y = [comp.data[np.newaxis, :, :] for comp in vector_data_asymm_2d.comp]
data_z = np.zeros_like(data_x)
data_3d = np.stack([data_x, data_y, data_z], axis=-1)
vector_data_asymm_3d = Field(data_3d, scale=10, vector=True)
print(f'2D shape, scale: {vector_data_asymm_2d.shape, vector_data_asymm_2d.scale}')
print(f'3D shape, scale: {vector_data_asymm_3d.shape, vector_data_asymm_3d.scale}')
r2d = vector_data_asymm_2d.rotate(angle)
r3d = vector_data_asymm_3d.rotate(angle, axis='z')
print(f'2D shape after rot: {r2d.shape}')
print(f'3D shape after rot: {r3d.shape}')
assert_allclose(r2d, r3d[0, :, :, :2], err_msg='Unexpected behavior in 2D rotate()!')
@pytest.mark.parametrize(
'angle', [180, 360, 90, 45, 23, 11.5],
)
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z'],
)
def test_rotate_scalar(vector_data_asymm, angle, axis):
data = np.zeros((1, 2, 2, 3))
data[0, 0, 0] = 1
field = Field(data, scale=10., vector=True)
print(field)
print(field.amp)
assert_allclose(
field.rotate(angle, axis=axis).amp,
field.amp.rotate(angle, axis=axis)
)
@pytest.mark.parametrize(
'angle,order', [(180, 3), (360, 3), (90, 3), (45, 0), (23, 0), (11.5, 0)],
)
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z'],
)
@pytest.mark.parametrize(
'reshape', [True, False],
)
def test_rotate_scalar_asymm(vector_data_asymm, angle, axis, order, reshape):
assert_allclose(
vector_data_asymm.rotate(angle, axis=axis, reshape=reshape, order=order).amp,
vector_data_asymm.amp.rotate(angle, axis=axis, reshape=reshape, order=order)
)
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z'],
)
@pytest.mark.parametrize(
'k', [0, 1, 2, 3, 4],
)
def test_rot90_scalar(vector_data_asymm, axis, k):
assert_allclose(
vector_data_asymm.amp.rot90(k=k, axis=axis),
vector_data_asymm.rot90(k=k, axis=axis).amp
)
import numpy
import numpy.testing
def assert_allclose(actual, desired, rtol=1e-07, atol=1e-08, equal_nan=True, err_msg='', verbose=True):
return numpy.testing.assert_allclose(
actual=actual,
desired=desired,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
err_msg=err_msg,
verbose=verbose,
)