Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • empyre/empyre
  • weber/empyre
  • wessels/empyre
  • bryan/empyre
4 results
Show changes
### MATPLOTLIB STYLESHEET FOR SAVING EMPYRE IMAGES AND PLOTS ### MATPLOTLIB STYLESHEET FOR SAVING EMPYRE IMAGES AND PLOTS
text.usetex : True ## use TeX to render text
font.family : serif ## default font family (use serifs) font.family : serif ## default font family (use serifs)
font.serif : cm ## Computer Modern (LaTeX font) font.serif : cm ## Computer Modern (LaTeX font)
......
...@@ -6,10 +6,11 @@ ...@@ -6,10 +6,11 @@
import logging import logging
import warnings
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import TwoSlopeNorm
from PIL import Image from PIL import Image
from . import colors from . import colors
...@@ -34,7 +35,7 @@ def imshow(field, axis=None, cmap=None, **kwargs): ...@@ -34,7 +35,7 @@ def imshow(field, axis=None, cmap=None, **kwargs):
Parameters Parameters
---------- ----------
field : `Field` or ndarray field : `Field` or ndarray
The image data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1` are assumed). The image data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last use axis via `gca`. The axis to which the image should be added, by default None, which will pick the last use axis via `gca`.
cmap : str or `matplotlib.colors.Colormap`, optional cmap : str or `matplotlib.colors.Colormap`, optional
...@@ -50,14 +51,16 @@ def imshow(field, axis=None, cmap=None, **kwargs): ...@@ -50,14 +51,16 @@ def imshow(field, axis=None, cmap=None, **kwargs):
Notes Notes
----- -----
Additional kwargs are passed to `matplotlib.pyplot.imshow`. Additional kwargs are passed to :meth:`~matplotlib.pyplot.imshow`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet). Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1. Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
""" """
_log.debug('Calling imshow') _log.debug('Calling imshow')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one: if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1, vector=False) field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!' assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar: # Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze() squeezed_field = field.squeeze()
...@@ -73,21 +76,10 @@ def imshow(field, axis=None, cmap=None, **kwargs): ...@@ -73,21 +76,10 @@ def imshow(field, axis=None, cmap=None, **kwargs):
elif isinstance(cmap, str): # make sure we have a Colormap object (and not a string): elif isinstance(cmap, str): # make sure we have a Colormap object (and not a string):
cmap = plt.get_cmap(cmap) cmap = plt.get_cmap(cmap)
if cmap.name.replace('_r', '') in DIVERGING_CMAPS: # 'replace' also matches reverted cmaps! if cmap.name.replace('_r', '') in DIVERGING_CMAPS: # 'replace' also matches reverted cmaps!
# Symmetric colormap only has zero at symmetry point if mappable has symmetric bounds (from - to + limit)! kwargs.setdefault('norm', TwoSlopeNorm(0)) # Diverging colormap should have zero at the symmetry point!
vmin = kwargs.get('vmin', squeezed_field.data.min()-1E-30)
vmax = kwargs.get('vmax', squeezed_field.data.max()+1E-30)
vmin, vmax = np.min([vmin, 0]), np.max([0, vmax]) # Ensure zero is present!
kwargs.setdefault('vmin', vmin)
kwargs.setdefault('vmax', vmax)
limit = np.max(np.abs([vmin, vmax]))
# Calculate the subset of colors for the range vmin to vmax (often not full range of 2*limit):
start = 0.5 + vmin/(2*limit) # 0 for symmetric bounds, >0: unused colors at lower end!
end = 0.5 + vmax/(2*limit) # 1 for symmetric bounds, <1: unused colors at upper end!
cmap_colors = cmap(np.linspace(start, end, 256)) # New color indices with symmetry color at 0.5!
# Use calculated colors to create custom (asymmetric) colormap with symmetry color (often white) at zero:
cmap = LinearSegmentedColormap.from_list('custom', cmap_colors)
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely): # Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely):
dim_v, dim_u, s_v, s_u = *squeezed_field.dim, *squeezed_field.scale dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v)) kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context: # Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context! with use_style('empyre-image'): # Only works on axes created WITHIN context!
...@@ -102,7 +94,7 @@ def contour(field, axis=None, **kwargs): ...@@ -102,7 +94,7 @@ def contour(field, axis=None, **kwargs):
Parameters Parameters
---------- ----------
field : `Field` or ndarray field : `Field` or ndarray
The contour data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1` are assumed). The contour data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional axis : `matplotlib.axes.Axes` object, optional
The axis to which the contour should be added, by default None, which will pick the last use axis via `gca`. The axis to which the contour should be added, by default None, which will pick the last use axis via `gca`.
...@@ -114,13 +106,15 @@ def contour(field, axis=None, **kwargs): ...@@ -114,13 +106,15 @@ def contour(field, axis=None, **kwargs):
Notes Notes
----- -----
Additional kwargs are passed to `matplotlib.pyplot.contour`. Additional kwargs are passed to `matplotlib.pyplot.contour`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet). Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1. Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
""" """
_log.debug('Calling contour') _log.debug('Calling contour')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one: if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1, vector=False) field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!' assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar: # Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze() squeezed_field = field.squeeze()
...@@ -152,7 +146,7 @@ def colorvec(field, axis=None, **kwargs): ...@@ -152,7 +146,7 @@ def colorvec(field, axis=None, **kwargs):
Parameters Parameters
---------- ----------
field : `Field` or ndarray field : `Field` or ndarray
The image data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1` are assumed). The image data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last use axis via `gca`. The axis to which the image should be added, by default None, which will pick the last use axis via `gca`.
...@@ -164,6 +158,8 @@ def colorvec(field, axis=None, **kwargs): ...@@ -164,6 +158,8 @@ def colorvec(field, axis=None, **kwargs):
Notes Notes
----- -----
Additional kwargs are passed to `matplotlib.pyplot.imshow`. Additional kwargs are passed to `matplotlib.pyplot.imshow`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet). Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1. Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
Even though squeezing takes place, `colorvec` "remembers" the original orientation of the slice! This is important Even though squeezing takes place, `colorvec` "remembers" the original orientation of the slice! This is important
...@@ -179,11 +175,9 @@ def colorvec(field, axis=None, **kwargs): ...@@ -179,11 +175,9 @@ def colorvec(field, axis=None, **kwargs):
""" """
_log.debug('Calling colorvec') _log.debug('Calling colorvec')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one: if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1, vector=True) field = Field(data=np.asarray(field), scale=1.0, vector=True)
assert field.vector, 'Can only plot vector fields!' assert field.vector, 'Can only plot vector fields!'
assert len(field.dim) <= 3, 'Unusable for vector fields with dimension higher than 3!' assert len(field.dim) <= 3, 'Unusable for vector fields with dimension higher than 3!'
assert len(field.dim) == field.ncomp, ('Assignment of vector components to dimensions is ambiguous!'
f'`ncomp` ({field.ncomp}) must match `len(dim)` ({len(field.dim)})')
# Get squeezed data and make sure it's 2D scalar: # Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze() squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!' assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!'
...@@ -193,9 +187,13 @@ def colorvec(field, axis=None, **kwargs): ...@@ -193,9 +187,13 @@ def colorvec(field, axis=None, **kwargs):
y_comp = comp[1] y_comp = comp[1]
z_comp = comp[2] if (squeezed_field.ncomp == 3) else np.zeros(squeezed_field.dim) z_comp = comp[2] if (squeezed_field.ncomp == 3) else np.zeros(squeezed_field.dim)
# Calculate image with color encoded directions: # Calculate image with color encoded directions:
rgb = colors.CMAP_CIRCULAR_DEFAULT.rgb_from_vector(np.stack((x_comp, y_comp, z_comp), axis=0)) cmap = kwargs.pop('cmap', None)
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(np.stack((x_comp, y_comp, z_comp), axis=0))
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely): # Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely):
dim_v, dim_u, s_v, s_u = *squeezed_field.dim, *squeezed_field.scale dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v)) kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context: # Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context! with use_style('empyre-image'): # Only works on axes created WITHIN context!
...@@ -210,7 +208,7 @@ def cosine_contours(field, axis=None, gain='auto', cmap=None, **kwargs): ...@@ -210,7 +208,7 @@ def cosine_contours(field, axis=None, gain='auto', cmap=None, **kwargs):
Parameters Parameters
---------- ----------
field : `Field` or ndarray field : `Field` or ndarray
The contour data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1` are assumed). The contour data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional axis : `matplotlib.axes.Axes` object, optional
The axis to which the contour should be added, by default None, which will pick the last use axis via `gca`. The axis to which the contour should be added, by default None, which will pick the last use axis via `gca`.
gain : float or 'auto', optional gain : float or 'auto', optional
...@@ -235,7 +233,7 @@ def cosine_contours(field, axis=None, gain='auto', cmap=None, **kwargs): ...@@ -235,7 +233,7 @@ def cosine_contours(field, axis=None, gain='auto', cmap=None, **kwargs):
""" """
_log.debug('Calling cosine_contours') _log.debug('Calling cosine_contours')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one: if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1, vector=False) field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!' assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar: # Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze() squeezed_field = field.squeeze()
...@@ -253,7 +251,8 @@ def cosine_contours(field, axis=None, gain='auto', cmap=None, **kwargs): ...@@ -253,7 +251,8 @@ def cosine_contours(field, axis=None, gain='auto', cmap=None, **kwargs):
contours += 1 # Shift to positive values contours += 1 # Shift to positive values
contours /= 2 # Rescale to [0, 1] contours /= 2 # Rescale to [0, 1]
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely): # Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitely):
dim_v, dim_u, s_v, s_u = *squeezed_field.dim, *squeezed_field.scale dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v)) kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context: # Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context! with use_style('empyre-image'): # Only works on axes created WITHIN context!
...@@ -268,7 +267,7 @@ def quiver(field, axis=None, color_angles=False, cmap=None, n_bin='auto', bin_wi ...@@ -268,7 +267,7 @@ def quiver(field, axis=None, color_angles=False, cmap=None, n_bin='auto', bin_wi
Parameters Parameters
---------- ----------
field : `Field` or ndarray field : `Field` or ndarray
The vector data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1` are assumed). The vector data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last use axis via `gca`. The axis to which the image should be added, by default None, which will pick the last use axis via `gca`.
color_angles : bool, optional color_angles : bool, optional
...@@ -306,11 +305,13 @@ def quiver(field, axis=None, color_angles=False, cmap=None, n_bin='auto', bin_wi ...@@ -306,11 +305,13 @@ def quiver(field, axis=None, color_angles=False, cmap=None, n_bin='auto', bin_wi
""" """
_log.debug('Calling quiver') _log.debug('Calling quiver')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one: if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1, vector=True) field = Field(data=np.asarray(field), scale=1.0, vector=True)
assert field.vector, 'Can only plot vector fields!' assert field.vector, 'Can only plot vector fields!'
assert len(field.dim) <= 3, 'Unusable for vector fields with dimension higher than 3!' assert len(field.dim) <= 3, 'Unusable for vector fields with dimension higher than 3!'
assert len(field.dim) == field.ncomp, ('Assignment of vector components to dimensions is ambiguous!' if len(field.dim) < field.ncomp:
f'`ncomp` ({field.ncomp}) must match `len(dim)` ({len(field.dim)})') warnings.warn('Assignment of vector components to dimensions is ambiguous!'
f'`ncomp` ({field.ncomp}) should match `len(dim)` ({len(field.dim)})!'
'If you want to plot a slice of a 3D volume, make sure to use `from:to` notation!')
# Get squeezed data and make sure it's 2D scalar: # Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze() squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions (Squeezing did not help)!' assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions (Squeezing did not help)!'
...@@ -339,7 +340,7 @@ def quiver(field, axis=None, color_angles=False, cmap=None, n_bin='auto', bin_wi ...@@ -339,7 +340,7 @@ def quiver(field, axis=None, color_angles=False, cmap=None, n_bin='auto', bin_wi
if color_angles: # Color angles according to calculated RGB values (only with circular colormaps): if color_angles: # Color angles according to calculated RGB values (only with circular colormaps):
_log.debug('Encoding angles') _log.debug('Encoding angles')
if cmap is None: if cmap is None:
cmap = colors.CMAP_CIRCULAR_DEFAULT cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(np.asarray((x_comp, y_comp, z_comp))) / 255 rgb = cmap.rgb_from_vector(np.asarray((x_comp, y_comp, z_comp))) / 255
rgba = np.concatenate((rgb, amplitude[..., None]), axis=-1) rgba = np.concatenate((rgb, amplitude[..., None]), axis=-1)
kwargs.setdefault('color', rgba.reshape(-1, 4)) kwargs.setdefault('color', rgba.reshape(-1, 4))
......
...@@ -109,7 +109,7 @@ def mask3d(field, title='Mask', threshold=0, grid=True, labels=True, ...@@ -109,7 +109,7 @@ def mask3d(field, title='Mask', threshold=0, grid=True, labels=True,
return cont return cont
def quiver3d(field, title='Vector Field', limit=None, cmap='jet', mode='2darrow', def quiver3d(field, title='Vector Field', limit=None, cmap=None, mode='2darrow',
coloring='angle', ar_dens=1, opacity=1.0, grid=True, labels=True, coloring='angle', ar_dens=1, opacity=1.0, grid=True, labels=True,
orientation=True, size=(700, 750), new_fig=True, view='isometric', orientation=True, size=(700, 750), new_fig=True, view='isometric',
position=None, bgcolor=(0.5, 0.5, 0.5)): position=None, bgcolor=(0.5, 0.5, 0.5)):
...@@ -122,7 +122,8 @@ def quiver3d(field, title='Vector Field', limit=None, cmap='jet', mode='2darrow' ...@@ -122,7 +122,8 @@ def quiver3d(field, title='Vector Field', limit=None, cmap='jet', mode='2darrow'
limit : float, optional limit : float, optional
Plotlimit for the vector field arrow length used to scale the colormap. Plotlimit for the vector field arrow length used to scale the colormap.
cmap : string, optional cmap : string, optional
String describing the colormap which is used for amplitude encoding (default is 'jet'). String describing the colormap which is used for color encoding (uses `~.colors.cmaps.cyclic_cubehelix` if
left on the `None` default) or amplitude encoding (uses 'jet' if left on the `None` default).
ar_dens: int, optional ar_dens: int, optional
Number defining the arrow density which is plotted. A higher ar_dens number skips more Number defining the arrow density which is plotted. A higher ar_dens number skips more
arrows (a number of 2 plots every second arrow). Default is 1. arrows (a number of 2 plots every second arrow). Default is 1.
...@@ -165,13 +166,17 @@ def quiver3d(field, title='Vector Field', limit=None, cmap='jet', mode='2darrow' ...@@ -165,13 +166,17 @@ def quiver3d(field, title='Vector Field', limit=None, cmap='jet', mode='2darrow'
vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag, mode=mode, opacity=opacity, vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag, mode=mode, opacity=opacity,
scalars=np.arange(len(xxx)), line_width=2) scalars=np.arange(len(xxx)), line_width=2)
vector = np.asarray((x_mag.ravel(), y_mag.ravel(), z_mag.ravel())) vector = np.asarray((x_mag.ravel(), y_mag.ravel(), z_mag.ravel()))
rgb = colors.CMAP_CIRCULAR_DEFAULT.rgb_from_vector(vector) if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(vector)
rgba = np.hstack((rgb, 255 * np.ones((len(xxx), 1), dtype=np.uint8))) rgba = np.hstack((rgb, 255 * np.ones((len(xxx), 1), dtype=np.uint8)))
vecs.glyph.color_mode = 'color_by_scalar' vecs.glyph.color_mode = 'color_by_scalar'
vecs.module_manager.scalar_lut_manager.lut.table = rgba vecs.module_manager.scalar_lut_manager.lut.table = rgba
mlab.draw() mlab.draw()
elif coloring == 'amplitude': # Encodes the amplitude of the arrows with the jet colormap: elif coloring == 'amplitude': # Encodes the amplitude of the arrows with the jet colormap:
_log.debug('Encoding amplitude') _log.debug('Encoding amplitude')
if cmap is None:
cmap = 'jet'
vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag, vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag,
mode=mode, colormap=cmap, opacity=opacity, line_width=2) mode=mode, colormap=cmap, opacity=opacity, line_width=2)
mlab.colorbar(label_fmt='%.2f') mlab.colorbar(label_fmt='%.2f')
......
...@@ -23,7 +23,7 @@ __all__ = ['new', 'savefig', 'calc_figsize', 'use_style', 'copy_mpl_stylesheets' ...@@ -23,7 +23,7 @@ __all__ = ['new', 'savefig', 'calc_figsize', 'use_style', 'copy_mpl_stylesheets'
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
def new(nrows=1, ncols=1, mode='image', figsize=None, textwidth=None, width_scale=1, aspect=None, **kwargs): def new(nrows=1, ncols=1, mode='image', figsize=None, textwidth=None, width_scale=1.0, aspect=None, **kwargs):
R"""Convenience function for the creation of a new subplot grid (wraps `~matplotlib.pyplot.subplots`). R"""Convenience function for the creation of a new subplot grid (wraps `~matplotlib.pyplot.subplots`).
If you use the `textwidth` parameter, plot sizes are fitting into publications with LaTeX. Requires two stylesheets If you use the `textwidth` parameter, plot sizes are fitting into publications with LaTeX. Requires two stylesheets
...@@ -48,8 +48,8 @@ def new(nrows=1, ncols=1, mode='image', figsize=None, textwidth=None, width_scal ...@@ -48,8 +48,8 @@ def new(nrows=1, ncols=1, mode='image', figsize=None, textwidth=None, width_scal
Width and height of the figure in inches, defaults to rcParams["figure.figsize"], which depends on the chosen Width and height of the figure in inches, defaults to rcParams["figure.figsize"], which depends on the chosen
stylesheet. If set, this will overwrite all other following parameters. stylesheet. If set, this will overwrite all other following parameters.
textwidth : float, optional textwidth : float, optional
The textwidth of your LaTeX document in points, which you can get py using "\the\textwidth". If this is not None The textwidth of your LaTeX document in points, which you can get by using :math:`\the\textwidth`. If this is
(the default), this will be used to define the figure size if it is not set explicitely. not None (the default), this will be used to define the figure size if it is not set explicitely.
width_scale : float, optional width_scale : float, optional
Only meaningful if `textwidth` is set. If it is, `width_scale` will be a scaling factor for the figure width. Only meaningful if `textwidth` is set. If it is, `width_scale` will be a scaling factor for the figure width.
Example: if you set this to 0.5, your figure will span half of the textwidth. Default is 1. Example: if you set this to 0.5, your figure will span half of the textwidth. Default is 1.
...@@ -73,22 +73,18 @@ def new(nrows=1, ncols=1, mode='image', figsize=None, textwidth=None, width_scal ...@@ -73,22 +73,18 @@ def new(nrows=1, ncols=1, mode='image', figsize=None, textwidth=None, width_scal
""" """
_log.debug('Calling new') _log.debug('Calling new')
assert mode in ('image', 'plot'), "mode has to be 'image', or 'plot'!" assert mode in ('image', 'plot'), "mode has to be 'image', or 'plot'!"
if figsize is None and textwidth is not None: # Only then is all this necessary: with use_style(f'empyre-{mode}'):
if aspect is None: if figsize is None:
aspect = 'golden' if mode == 'plot' else 1 # Both image modes have 'same' as default'! if aspect is None:
elif isinstance(aspect, Field): aspect = 'golden' if mode == 'plot' else 1 # Both image modes have 'same' as default'!
dim_uv = [d for d in aspect.dim if d != 1] elif isinstance(aspect, Field):
assert len(dim_uv) == 2, f"Couldn't find fields aspect ({len(dim_uv)} squeezed dimensions, has to be of 2)!" dim_uv = [d for d in aspect.dim if d != 1]
aspect = dim_uv[0]/dim_uv[1] # height/width assert len(dim_uv) == 2, f"Couldn't find field aspect ({len(dim_uv)} squeezed dimensions, has to be 2)!"
else: aspect = dim_uv[0]/dim_uv[1] # height/width
assert isinstance(aspect, Number), 'aspect has to be None, a number or a field instance squeezable to 2D!' else:
figsize = calc_figsize(width_scale=width_scale, aspect=aspect, textwidth=textwidth) assert isinstance(aspect, Number), 'aspect has to be None, a number or field instance squeezable to 2D!'
if mode == 'image': figsize = calc_figsize(textwidth=textwidth, width_scale=width_scale, aspect=aspect)
with use_style('empyre-image'): return plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, **kwargs)
return plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, **kwargs)
else: # mode == 'plot':
with use_style('empyre-plot'):
return plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, **kwargs)
def savefig(fname, **kwargs): def savefig(fname, **kwargs):
...@@ -110,13 +106,14 @@ def savefig(fname, **kwargs): ...@@ -110,13 +106,14 @@ def savefig(fname, **kwargs):
plt.savefig(fname, **kwargs) plt.savefig(fname, **kwargs)
def calc_figsize(textwidth, width_scale=1, aspect=1): def calc_figsize(textwidth=None, width_scale=1.0, aspect=1):
R"""Helper function to calculate the figure size from various parameters. Useful for publications via LaTeX. R"""Helper function to calculate the figure size from various parameters. Useful for publications via LaTeX.
Parameters Parameters
---------- ----------
textwidth : float textwidth : float, optional
The textwidth of your LaTeX document in points, which you can get py using "\the\textwidth". The textwidth of your LaTeX document in points, which you can get by using :math:`\the\textwidth`. If this is
None (default), the standard width in inches from the current stylesheet is used.
width_scale : float, optional width_scale : float, optional
Scaling factor for the figure width. Example: if you set this to 0.5, your figure will span half of the Scaling factor for the figure width. Example: if you set this to 0.5, your figure will span half of the
textwidth. Default is 1. textwidth. Default is 1.
...@@ -132,20 +129,23 @@ def calc_figsize(textwidth, width_scale=1, aspect=1): ...@@ -132,20 +129,23 @@ def calc_figsize(textwidth, width_scale=1, aspect=1):
Notes Notes
----- -----
Based on snippet from Florian Winkler. Based on snippet by Florian Winkler.
""" """
_log.debug('Calling calc_figsize') _log.debug('Calling calc_figsize')
GOLDEN_RATIO = (1 + np.sqrt(5)) / 2 # Aesthetic ratio! GOLDEN_RATIO = (1 + np.sqrt(5)) / 2 # Aesthetic ratio!
INCHES_PER_POINT = 1.0 / 72.27 # Convert points to inch, LaTeX constant, apparently... INCHES_PER_POINT = 1.0 / 72.27 # Convert points to inch, LaTeX constant, apparently...
textwidth_in = textwidth * INCHES_PER_POINT # Width of the text in inches if textwidth is not None:
textwidth_in = textwidth * INCHES_PER_POINT # Width of the text in inches
else: # If textwidth is not given, use the default from rcParams:
textwidth_in = mpl.rcParams["figure.figsize"][0]
fig_width = textwidth_in * width_scale # Width in inches fig_width = textwidth_in * width_scale # Width in inches
if aspect == 'golden': if aspect == 'golden':
fig_height = fig_width / GOLDEN_RATIO fig_height = fig_width / GOLDEN_RATIO
elif isinstance(aspect, Number): elif isinstance(aspect, Number):
fig_height = textwidth_in * aspect fig_height = textwidth_in * aspect
else: else:
raise ValueError(f"aspect has to be either a number, 'same' or 'golden'! Was {aspect}!") raise ValueError(f"aspect has to be either a number, or 'golden'! Was {aspect}!")
fig_size = [fig_width, fig_height] # Both in inches fig_size = [fig_width, fig_height] # Both in inches
return fig_size return fig_size
......
import os
import pytest
import numpy as np
from empyre.fields import Field
@pytest.fixture
def fielddata_path():
return os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_fielddata')
@pytest.fixture
def vector_data():
magnitude = np.zeros((4, 4, 4, 3))
magnitude[1:-1, 1:-1, 1:-1] = 1
return Field(magnitude, 10.0, vector=True)
@pytest.fixture
def vector_data_asymm():
shape = (5, 7, 11, 3)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=True)
@pytest.fixture
def vector_data_asymm_2d():
shape = (5, 7, 2)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=True)
@pytest.fixture
def vector_data_asymmcube():
shape = (3, 3, 3, 3)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=True)
@pytest.fixture
def scalar_data():
magnitude = np.zeros((4, 4, 4))
magnitude[1:-1, 1:-1, 1:-1] = 1
return Field(magnitude, 10.0, vector=False)
@pytest.fixture
def scalar_data_asymm():
shape = (5, 7, 2)
data = np.linspace(0, 1, np.prod(shape))
return Field(data.reshape(shape), 10.0, vector=False)
# -*- coding: utf-8 -*-
"""Testcase for the magdata module."""
import pytest
from numbers import Number
import numpy as np
import numpy.testing
from empyre.fields import Field
from utils import assert_allclose
def test_copy(vector_data):
vector_data = vector_data.copy()
# Make sure it is a new object
assert vector_data != vector_data, 'Unexpected behaviour in copy()!'
assert np.allclose(vector_data, vector_data)
def test_bin(vector_data):
binned_data = vector_data.bin(2)
reference = 1 / 8. * np.ones((2, 2, 2, 3))
assert_allclose(binned_data, reference,
err_msg='Unexpected behavior in scale_down()!')
assert_allclose(binned_data.scale, (20, 20, 20),
err_msg='Unexpected behavior in scale_down()!')
def test_zoom(vector_data):
zoomed_test = vector_data.zoom(2, order=0)
reference = np.zeros((8, 8, 8, 3))
reference[2:6, 2:6, 2:6] = 1
assert_allclose(zoomed_test, reference,
err_msg='Unexpected behavior in zoom()!')
assert_allclose(zoomed_test.scale, (5, 5, 5),
err_msg='Unexpected behavior in zoom()!')
@pytest.mark.parametrize(
'mode', [
'constant',
'edge',
'wrap'
]
)
@pytest.mark.parametrize(
'pad_width,np_pad', [
(1, ((1, 1), (1, 1), (1, 1), (0, 0))),
((1, 2, 3), ((1, 1), (2, 2), (3, 3), (0, 0))),
(((1, 2), (3, 4), (5, 6)), ((1, 2), (3, 4), (5, 6), (0, 0)))
]
)
def test_pad(vector_data, mode, pad_width, np_pad):
magdata_test = vector_data.pad(pad_width, mode=mode)
reference = np.pad(vector_data, np_pad, mode=mode)
assert_allclose(magdata_test, reference,
err_msg='Unexpected behavior in pad()!')
@pytest.mark.parametrize(
'axis', [-1, 3]
)
def test_component_reduction(vector_data, axis):
# axis=-1 is supposed to reduce over the component dimension, if it exists. axis=3 should do the same here!
res = np.sum(vector_data, axis=axis)
ref = np.zeros((4, 4, 4))
ref[1:-1, 1:-1, 1:-1] = 3
assert res.shape == ref.shape, 'Shape mismatch!'
assert_allclose(res, ref, err_msg="Unexpected behavior of axis keyword")
assert isinstance(res, Field), 'Result is not a Field object!'
assert not res.vector, 'Result is a vector field, but should be reduced to a scalar!'
@pytest.mark.parametrize(
'axis', [(0, 1, 2), (2, 1, 0), None, (-4, -3, -2)]
)
def test_full_reduction(vector_data, axis):
res = np.sum(vector_data, axis=axis)
ref = np.zeros((3,))
ref[:] = 8
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of full or default reduction")
assert isinstance(res, np.ndarray)
@pytest.mark.parametrize(
'axis', [-1, 2]
)
def test_last_reduction_scalar(scalar_data, axis):
# axis=-1 is supposed to reduce over the component dimension if it exists.
# In this case it doesn't!
res = np.sum(scalar_data, axis=axis)
ref = np.zeros((4, 4))
ref[1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of axis keyword")
assert isinstance(res, Field)
assert not res.vector
@pytest.mark.parametrize(
'axis', [(0, 1, 2), (2, 1, 0), None, (-1, -2, -3)]
)
def test_full_reduction_scalar(scalar_data, axis):
res = np.sum(scalar_data, axis=axis)
ref = 8
assert res.shape == ()
assert_allclose(res, ref, err_msg="Unexpected behavior of full or default reduction")
assert isinstance(res, Number)
def test_binary_operator_vector_number(vector_data):
res = vector_data + 1
ref = np.ones((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
def test_binary_operator_vector_scalar(vector_data, scalar_data):
res = vector_data + scalar_data
ref = np.zeros((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
def test_binary_operator_vector_vector(vector_data):
res = vector_data + vector_data
ref = np.zeros((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 2
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
@pytest.mark.xfail
def test_binary_operator_vector_broadcast(vector_data):
# Broadcasting between vector fields is currently not implemented
second = np.zeros((4, 4, 3))
second[1:-1, 1:-1] = 1
second = Field(second, 10.0, vector=True)
res = vector_data + second
ref = np.zeros((4, 4, 4, 3))
ref[1:-1, 1:-1, 1:-1] = 1
ref[:, 1:-1, 1:-1] += 1
assert res.shape == ref.shape
assert_allclose(res, ref, err_msg="Unexpected behavior of addition")
assert isinstance(res, Field)
assert res.vector
def test_mask(vector_data):
mask = vector_data.mask
reference = np.zeros((4, 4, 4))
reference[1:-1, 1:-1, 1:-1] = True
assert_allclose(mask, reference,
err_msg='Unexpected behavior in mask attribute!')
def test_get_vector(vector_data):
mask = vector_data.mask
vector = vector_data.get_vector(mask)
reference = np.ones(np.sum(mask) * 3)
assert_allclose(vector, reference,
err_msg='Unexpected behavior in get_vector()!')
def test_set_vector(vector_data):
mask = vector_data.mask
vector = 2 * np.ones(np.sum(mask) * 3)
vector_data.set_vector(vector, mask)
reference = np.zeros((4, 4, 4, 3))
reference[1:-1, 1:-1, 1:-1] = 2
assert_allclose(vector_data, reference,
err_msg='Unexpected behavior in set_vector()!')
def test_flip(vector_data_asymm):
field_flipz = vector_data_asymm.flip(0)
field_flipy = vector_data_asymm.flip(1)
field_flipx = vector_data_asymm.flip(2)
field_flipxy = vector_data_asymm.flip((1, 2))
field_flipdefault = vector_data_asymm.flip()
field_flipcomp = vector_data_asymm.flip(-1)
assert_allclose(np.flip(vector_data_asymm.data, axis=0) * [1, 1, -1], field_flipz.data,
err_msg='Unexpected behavior in flip()! (z)')
assert_allclose(np.flip(vector_data_asymm.data, axis=1) * [1, -1, 1], field_flipy.data,
err_msg='Unexpected behavior in flip()! (y)')
assert_allclose(np.flip(vector_data_asymm.data, axis=2) * [-1, 1, 1], field_flipx.data,
err_msg='Unexpected behavior in flip()! (x)')
assert_allclose(np.flip(vector_data_asymm.data, axis=(1, 2)) * [-1, -1, 1], field_flipxy.data,
err_msg='Unexpected behavior in flip()! (xy)')
assert_allclose(np.flip(vector_data_asymm.data, axis=(0, 1, 2)) * [-1, -1, -1], field_flipdefault.data,
err_msg='Unexpected behavior in flip()! (default)')
assert_allclose(np.flip(vector_data_asymm.data, axis=-1) * [1, 1, 1], field_flipcomp.data,
err_msg='Unexpected behavior in flip()! (components)')
def test_unknown_num_of_components():
shape = (5, 7, 7)
data = np.linspace(0, 1, np.prod(shape))
with pytest.raises(AssertionError):
Field(data.reshape(shape), 10.0, vector=True)
def test_repr(vector_data_asymm):
string_repr = repr(vector_data_asymm)
data_str = str(vector_data_asymm.data)
string_ref = f'Field(data={data_str}, scale=(10.0, 10.0, 10.0), vector=True)'
print(f'reference: {string_ref}')
print(f'repr output: {string_repr}')
assert string_repr == string_ref, 'Unexpected behavior in __repr__()!'
def test_str(vector_data_asymm):
string_str = str(vector_data_asymm)
string_ref = 'Field(dim=(5, 7, 11), scale=(10.0, 10.0, 10.0), vector=True, ncomp=3)'
print(f'reference: {string_str}')
print(f'str output: {string_str}')
assert string_str == string_ref, 'Unexpected behavior in __str__()!'
@pytest.mark.parametrize(
"index,t,scale", [
((0, 1, 2), tuple, None),
((0, ), Field, (2., 3.)),
(0, Field, (2., 3.)),
((0, 1, 2, 0), float, None),
((0, 1, 2, 0), float, None),
((..., 0), Field, (1., 2., 3.)),
((0, slice(1, 3), 2), Field, (2.,)),
]
)
def test_getitem(vector_data, index, t, scale):
vector_data.scale = (1., 2., 3.)
data_index = index
res = vector_data[index]
assert_allclose(res, vector_data.data[data_index])
assert isinstance(res, t)
if t is Field:
assert res.scale == scale
def test_from_scalar_field(scalar_data):
sca_x, sca_y, sca_z = [i * scalar_data for i in range(1, 4)]
field_comb = Field.from_scalar_fields([sca_x, sca_y, sca_z])
assert field_comb.vector
assert field_comb.scale == scalar_data.scale
assert_allclose(sca_x, field_comb.comp[0])
assert_allclose(sca_y, field_comb.comp[1])
assert_allclose(sca_z, field_comb.comp[2])
def test_squeeze():
magnitude = np.zeros((4, 1, 4, 3))
field = Field(magnitude, (1., 2., 3.), vector=True)
sq = field.squeeze()
assert sq.shape == (4, 4, 3)
assert sq.dim == (4, 4)
assert sq.scale == (1., 3.)
def test_gradient():
pass
def test_gradient_1d():
pass
def test_curl():
pass
def test_curl_2d():
pass
def test_clip_scalar_noop():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=False)
assert_allclose(field, field.clip())
def test_clip_scalar_minmax():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=False)
assert_allclose(np.clip(data, -1, 0.1), field.clip(vmin=-1, vmax=0.1))
def test_clip_scalar_sigma():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
data[0, 0, 0] = 1e6
field = Field(data, (1., 2., 3.), vector=False)
# We clip off the one outlier
assert_allclose(np.clip(data, -2, 1), field.clip(sigma=5))
assert field.clip(sigma=5)[0, 0, 0] == 1
def test_clip_scalar_mask():
shape = (3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
mask = np.zeros(shape, dtype=bool)
mask[0, 0, 0] = True
mask[0, 0, 1] = True
field = Field(data, (1., 2., 3.), vector=False)
assert_allclose(np.clip(data, data[0, 0, 0], data[0, 0, 1]), field.clip(mask=mask))
def test_clip_vector_noop():
shape = (3, 3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=True)
assert_allclose(field, field.clip())
def test_clip_vector_max():
shape = (3, 3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
field = Field(data, (1., 2., 3.), vector=True)
res = field.clip(vmax=0.1)
assert_allclose(np.max(res.amp), 0.1)
def test_clip_vector_sigma():
shape = (3, 3, 3, 3)
data = np.linspace(-2, 1, np.prod(shape)).reshape(shape)
data[0, 0, 0] = (1e6, 1e6, 1e6)
field = Field(data, (1., 2., 3.), vector=True)
# We clip off the one outlier
res = field.clip(sigma=5)
assert np.max(res.amp) < 1e3
# TODO: HyperSpy would need to be installed for the following tests (slow...):
# def test_from_signal()
# raise NotImplementedError()
#
# def test_to_signal()
# raise NotImplementedError()
import pytest
from utils import assert_allclose
from empyre.fields import Field
import numpy as np
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z']
)
def test_rot90_360(vector_data_asymm, axis):
assert_allclose(vector_data_asymm.rot90(axis=axis).rot90(axis=axis).rot90(axis=axis).rot90(axis=axis),
vector_data_asymm,
err_msg=f'Unexpected behavior in rot90()! {axis}')
@pytest.mark.parametrize(
'rot_axis,flip_axes', [
('x', (0, 1)),
('y', (0, 2)),
('z', (1, 2))
]
)
def test_rot90_180(vector_data_asymm, rot_axis, flip_axes):
res = vector_data_asymm.rot90(axis=rot_axis).rot90(axis=rot_axis)
ref = vector_data_asymm.flip(axis=flip_axes)
assert_allclose(res, ref, err_msg=f'Unexpected behavior in rot90()! {rot_axis}')
@pytest.mark.parametrize(
'rot_axis', [
'x',
'y',
'z',
]
)
def test_rotate_compare_rot90_1(vector_data_asymmcube, rot_axis):
res = vector_data_asymmcube.rotate(angle=90, axis=rot_axis)
ref = vector_data_asymmcube.rot90(axis=rot_axis)
print("input", vector_data_asymmcube.data)
print("ref", res.data)
print("res", ref.data)
assert_allclose(res, ref, err_msg=f'Unexpected behavior in rotate()! {rot_axis}')
def test_rot90_manual():
data = np.zeros((3, 3, 3, 3))
diag = np.array((1, 1, 1))
diag_unity = diag / np.sqrt(np.sum(diag**2))
data[0, 0, 0] = diag_unity
data = Field(data, 10, vector=True)
print("data", data.data)
rot90_x = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot90_x[0, 2, 0] = diag_unity * (1, -1, 1)
rot90_x = Field(rot90_x, 10, vector=True)
print("rot90_x", rot90_x.data)
print("data rot90 x", data.rot90(axis='x').data)
rot90_y = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot90_y[2, 0, 0] = diag_unity * (1, 1, -1)
rot90_y = Field(rot90_y, 10, vector=True)
print("rot90_y", rot90_y.data)
print("data rot90 y", data.rot90(axis='y').data)
rot90_z = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot90_z[0, 0, 2] = diag_unity * (-1, 1, 1)
rot90_z = Field(rot90_z, 10, vector=True)
print("rot90_z", rot90_z.data)
print("data rot90 z", data.rot90(axis='z').data)
assert_allclose(rot90_x, data.rot90(axis='x'), err_msg='Unexpected behavior in rot90("x")!')
assert_allclose(rot90_y, data.rot90(axis='y'), err_msg='Unexpected behavior in rot90("y")!')
assert_allclose(rot90_z, data.rot90(axis='z'), err_msg='Unexpected behavior in rot90("z")!')
def test_rot45_manual():
data = np.zeros((3, 3, 3, 3))
data[0, 0, 0] = (1, 1, 1)
data = Field(data, 10, vector=True)
print("data", data.data)
rot45_x = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot45_x[0, 1, 0] = (1, 0, np.sqrt(2))
rot45_x = Field(rot45_x, 10, vector=True)
print("rot45_x", rot45_x.data)
# Disable spline interpolation, use nearest instead
res_rot45_x = data.rotate(45, axis='x', order=0)
print("data rot45 x", res_rot45_x.data)
rot45_y = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot45_y[1, 0, 0] = (np.sqrt(2), 1, 0)
rot45_y = Field(rot45_y, 10, vector=True)
print("rot45_y", rot45_y.data)
# Disable spline interpolation, use nearest instead
res_rot45_y = data.rotate(45, axis='y', order=0)
print("data rot45 y", res_rot45_y.data)
rot45_z = np.zeros((3, 3, 3, 3))
# Axis order z, y, x; vector components x, y, z
rot45_z[0, 0, 1] = (0, np.sqrt(2), 1)
rot45_z = Field(rot45_z, 10, vector=True)
print("rot45_z", rot45_z.data)
# Disable spline interpolation, use nearest instead
res_rot45_z = data.rotate(45, axis='z', order=0)
print("data rot45 z", res_rot45_z.data)
assert_allclose(rot45_x, res_rot45_x, err_msg='Unexpected behavior in rotate(45, "x")!')
assert_allclose(rot45_y, res_rot45_y, err_msg='Unexpected behavior in rotate(45, "y")!')
assert_allclose(rot45_z, res_rot45_z, err_msg='Unexpected behavior in rotate(45, "z")!')
def test_rot90_2d_360(vector_data_asymm_2d):
assert_allclose(vector_data_asymm_2d.rot90().rot90().rot90().rot90(), vector_data_asymm_2d,
err_msg='Unexpected behavior in 2D rot90()!')
def test_rot90_2d_180(vector_data_asymm_2d):
res = vector_data_asymm_2d.rot90().rot90()
ref = vector_data_asymm_2d.flip()
assert_allclose(res, ref, err_msg='Unexpected behavior in 2D rot90()!')
@pytest.mark.parametrize(
'k', [0, 1, 2, 3, 4]
)
def test_rot90_comp_2d_with_3d(vector_data_asymm_2d, k):
data_x, data_y = [comp.data[np.newaxis, :, :] for comp in vector_data_asymm_2d.comp]
data_z = np.zeros_like(data_x)
data_3d = np.stack([data_x, data_y, data_z], axis=-1)
vector_data_asymm_3d = Field(data_3d, scale=10, vector=True)
print(f'2D shape, scale: {vector_data_asymm_2d.shape, vector_data_asymm_2d.scale}')
print(f'3D shape, scale: {vector_data_asymm_3d.shape, vector_data_asymm_3d.scale}')
vector_data_rot_2d = vector_data_asymm_2d.rot90(k=k)
vector_data_rot_3d = vector_data_asymm_3d.rot90(k=k, axis='z')
print(f'2D shape after rot: {vector_data_rot_2d.shape}')
print(f'3D shape after rot: {vector_data_rot_3d.shape}')
assert_allclose(vector_data_rot_2d, vector_data_rot_3d[0, :, :, :2], err_msg='Unexpected behavior in 2D rot90()!')
@pytest.mark.parametrize(
'angle', [90, 45, 23, 11.5]
)
def test_rotate_comp_2d_with_3d(vector_data_asymm_2d, angle):
data_x, data_y = [comp.data[np.newaxis, :, :] for comp in vector_data_asymm_2d.comp]
data_z = np.zeros_like(data_x)
data_3d = np.stack([data_x, data_y, data_z], axis=-1)
vector_data_asymm_3d = Field(data_3d, scale=10, vector=True)
print(f'2D shape, scale: {vector_data_asymm_2d.shape, vector_data_asymm_2d.scale}')
print(f'3D shape, scale: {vector_data_asymm_3d.shape, vector_data_asymm_3d.scale}')
r2d = vector_data_asymm_2d.rotate(angle)
r3d = vector_data_asymm_3d.rotate(angle, axis='z')
print(f'2D shape after rot: {r2d.shape}')
print(f'3D shape after rot: {r3d.shape}')
assert_allclose(r2d, r3d[0, :, :, :2], err_msg='Unexpected behavior in 2D rotate()!')
@pytest.mark.parametrize(
'angle', [180, 360, 90, 45, 23, 11.5],
)
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z'],
)
def test_rotate_scalar(vector_data_asymm, angle, axis):
data = np.zeros((1, 2, 2, 3))
data[0, 0, 0] = 1
field = Field(data, scale=10., vector=True)
print(field)
print(field.amp)
assert_allclose(
field.rotate(angle, axis=axis).amp,
field.amp.rotate(angle, axis=axis)
)
@pytest.mark.parametrize(
'angle,order', [(180, 3), (360, 3), (90, 3), (45, 0), (23, 0), (11.5, 0)],
)
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z'],
)
@pytest.mark.parametrize(
'reshape', [True, False],
)
def test_rotate_scalar_asymm(vector_data_asymm, angle, axis, order, reshape):
assert_allclose(
vector_data_asymm.rotate(angle, axis=axis, reshape=reshape, order=order).amp,
vector_data_asymm.amp.rotate(angle, axis=axis, reshape=reshape, order=order)
)
@pytest.mark.parametrize(
'axis', ['x', 'y', 'z'],
)
@pytest.mark.parametrize(
'k', [0, 1, 2, 3, 4],
)
def test_rot90_scalar(vector_data_asymm, axis, k):
assert_allclose(
vector_data_asymm.amp.rot90(k=k, axis=axis),
vector_data_asymm.rot90(k=k, axis=axis).amp
)
# -*- coding: utf-8 -*-
"""Testcase for the magdata module."""
# import os
import unittest
# import numpy as np
# from numpy.testing import assert_allclose
# from empyre.fields.field import Field
class TestCaseField(unittest.TestCase):
def test_stupid(self):
print('FOUND')
assert True
def test_simple():
print('METHOD')
assert True
import numpy
import numpy.testing
def assert_allclose(actual, desired, rtol=1e-07, atol=1e-08, equal_nan=True, err_msg='', verbose=True):
return numpy.testing.assert_allclose(
actual=actual,
desired=desired,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
err_msg=err_msg,
verbose=verbose,
)