# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Subpackage containing EMPyRe IO functionality for the Field class."""
from . import llg, numpy, ovf, tec, text, vtk
plugin_list = [llg, numpy, ovf, tec, text, vtk]
__all__ = ['plugin_list']
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for LLG format."""
import logging
import numpy as np
from ...fields.field import Field
_log = logging.getLogger(__name__)
file_extensions = ('.llg',) # Recognised file extensions
def reader(filename, scale=None, vector=None, **kwargs):
_log.debug('Call reader')
if vector is None:
vector = True
assert vector is True, 'Only vector fields can be read from the llg file format!'
SCALE = 1.0E-9 / 1.0E-2 # From cm to nm
data_columns = np.genfromtxt(filename, skip_header=2)
dim = tuple(np.genfromtxt(filename, dtype=int, skip_header=1, skip_footer=len(data_columns[:, 0])))
if scale is None: # Otherwise overwrite!
stride_x = 1 # x varies fastest
stride_y = dim[2] # one y step for dim[2] x steps
stride_z = dim[1] * dim[2] # one z step for one x-y layer (dim[1]*dim[2])
scale_x = (data_columns[stride_x, 0] - data_columns[0, 0]) / SCALE # first column varies in x
scale_y = (data_columns[stride_y, 1] - data_columns[0, 1]) / SCALE # second column varies in y
scale_z = (data_columns[stride_z, 2] - data_columns[0, 2]) / SCALE # third column varies in z
scale = (scale_z, scale_y, scale_x)
x_mag, y_mag, z_mag = data_columns[:, 3:6].T
data = np.stack((x_mag.reshape(dim), y_mag.reshape(dim), z_mag.reshape(dim)), axis=-1)
return Field(data, scale, vector=True)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
assert field.vector and len(field.dim) == 3, 'Only 3D vector fields can be saved to the llg file format!'
SCALE = 1.0E-9 / 1.0E-2 # from nm to cm
# Create 3D meshgrid and reshape it and the field into a list where x varies first:
zzz, yyy, xxx = (np.indices(field.dim) + 0.5) * np.reshape(field.scale, (3, 1, 1, 1)) * SCALE # broadcast shape!
z_coord, y_coord, x_coord = np.ravel(zzz), np.ravel(yyy), np.ravel(xxx) # Turn into vectors!
x_comp, y_comp, z_comp = field.comp # Extract scalar field components!
x_vec, y_vec, z_vec = np.ravel(x_comp.data), np.ravel(y_comp.data), np.ravel(z_comp.data) # Turn into vectors!
data = np.array([x_coord, y_coord, z_coord, x_vec, y_vec, z_vec]).T
# Save data to file:
with open(filename, 'w') as mag_file:
mag_file.write('LLGFileCreator: EMPyRe vector Field\n')
mag_file.write(' {:d} {:d} {:d}\n'.format(*field.dim))
mag_file.writelines('\n'.join(' '.join('{:7.6e}'.format(cell) for cell in row) for row in data))
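# Illustrative usage sketch for this plugin (not part of the original module). It assumes a file
# 'magnetization.llg' exists on disk and that `field` is a 3D vector Field instance; normally these
# functions are reached via EMPyRe's generic load_field/save_field dispatch on the '.llg' extension:
#   >>> field = reader('magnetization.llg')                               # scale (in nm) inferred from the coordinates
#   >>> field_coarse = reader('magnetization.llg', scale=(10., 10., 10.)) # or override the scale explicitly
#   >>> writer('copy.llg', field)                                         # positions written in cm, components unchanged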
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for the numpy format."""
import logging
import numpy as np
from ...fields.field import Field
_log = logging.getLogger(__name__)
file_extensions = ('.npy', '.npz') # Recognised file extensions
def reader(filename, scale=None, vector=None, **kwargs):
_log.debug('Call reader')
if vector is None:
vector = False
if scale is None:
scale = 1.0
return Field(np.load(filename, **kwargs), scale, vector)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
np.save(filename, field.data, **kwargs)
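# Illustrative usage sketch for this plugin (not part of the original module; file names are placeholders
# and `field` is assumed to be a Field instance):
#   >>> writer('data.npy', field)                        # plain np.save of the underlying array
#   >>> field_scalar = reader('data.npy', scale=2.0)     # scalar field by default
#   >>> field_vector = reader('data.npy', vector=True)   # interpret the last axis as vector components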
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for simple text format."""
import logging
import os
import numpy as np
from ...fields.field import Field
_log = logging.getLogger(__name__)
file_extensions = ('.ovf', '.omf', '.ohf', '.obf') # Recognised file extensions
def reader(filename, scale=None, vector=None, segment=None, **kwargs):
"""More info at:
http://math.nist.gov/oommf/doc/userguide11b2/userguide/vectorfieldformat.html
http://math.nist.gov/oommf/doc/userguide12a5/userguide/OVF_2.0_format.html
"""
_log.debug('Calling reader')
if vector is None:
vector = True
assert vector is True, 'Only vector fields can be loaded from ovf-files!'
with open(filename, 'rb') as load_file:
line = load_file.readline()
assert line.startswith(b'# OOMMF') # File has OVF format!
read_header, read_data = False, False
header = {'version': line.split()[-1].decode('utf-8')}
x_mag, y_mag, z_mag = [], [], []
data_mode = None
while True:
# --- READ START OF FILE OR IN BETWEEN SEGMENTS LINE BY LINE ---------------------------
if not read_header and not read_data: # Start of file or between segments!
line = load_file.readline()
if line == b'':
break # End of file is reached!
if line.startswith(b'# Segment count'):
seg_count = int(line.split()[-1]) # Total number of segments (often just 1)!
seg_curr = 0 # Current segment (0: not in first segment, yet!)
if seg_count > 1: # If multiple segments, check if "segment" was set correctly:
assert segment is not None, (f'Multiple ({seg_count}) segments were found! '
'Choose one via the segment parameter!')
elif segment is None: # Only one segment AND parameter not set:
segment = 1 # Default to the first/only segment!
assert 0 < segment <= seg_count, (f'Parameter segment={segment} out of bounds, '
f'use a value between 1 and {seg_count}!')
header['segment_count'] = seg_count
elif line.startswith(b'# Begin: Segment'): # Segment start!
seg_curr += 1
if seg_curr > segment:
break # Stop reading the file!
elif line.startswith(b'# Begin: Header'): # Header start!
read_header = True
elif line.startswith(b'# Begin: Data'): # Data start!
read_data = True
data_mode = ' '.join(line.decode('utf-8').split()[3:])
assert data_mode in ['text', 'Text', 'Binary 4', 'Binary 8'], \
f'Data mode {data_mode} is currently not supported by this reader!'
assert header.get('meshtype') == 'rectangular', \
'Only rectangular grids can be currently read!'
# --- READ HEADER LINE BY LINE ---------------------------------------------------------
elif read_header:
line = load_file.readline()
if line.startswith(b'# End: Header'): # Header is done:
read_header = False
continue
line = line.decode('utf-8') # Decode to use strings here!
line_list = line.split()
if '##' in line_list: # Strip trailing comments:
del line_list[line_list.index('##'):]
if len(line_list) <= 1: # Just '#' or empty line:
continue
key, value = line_list[1].strip(':'), ' '.join(line_list[2:])
if key not in header: # Add new key, value pair if it does not exist yet:
header[key] = value
elif key == 'Desc': # Description can go over several lines:
header['Desc'] = ' '.join([header['Desc'], value])
# --- READ DATA LINE BY LINE -----------------------------------------------------------
elif read_data: # Currently in a data block:
if data_mode in ['text', 'Text']: # Read data as text, line by line:
line = load_file.readline()
if line.startswith(b'# End: Data'):
read_data = False # Stop reading data and search for new segments!
continue
elif seg_curr < segment: # Do nothing with the line if wrong segment!
continue
else:
x, y, z = [float(i) for i in line.split()]
x_mag.append(x)
y_mag.append(y)
z_mag.append(z)
elif 'Binary' in data_mode: # Read data as binary, all bytes at the same time:
# Currently every segment is read until the wanted one is processed. Only that one is returned!
count = int(data_mode.split()[-1]) # Either 4 or 8!
if header['version'] == '1.0': # Big endian float:
dtype = f'>f{count}'
elif header['version'] == '2.0': # Little endian float:
dtype = f'<f{count}'
test = np.fromfile(load_file, dtype=dtype, count=1) # Read test byte!
if count == 4: # Binary 4:
assert test == 1234567.0, 'Wrong test bytes!'
elif count == 8: # Binary 8:
assert test == 123456789012345.0, 'Wrong test bytes!'
dim = (int(header['znodes']), int(header['ynodes']), int(header['xnodes']))
data_raw = np.fromfile(load_file, dtype=dtype, count=3*np.prod(dim))
x_mag, y_mag, z_mag = data_raw[0::3], data_raw[1::3], data_raw[2::3]
read_data = False # Stop reading data and search for new segments (if any).
# --- READING DONE -------------------------------------------------------------------------
# Format after reading:
dim = (int(header['znodes']), int(header['ynodes']), int(header['xnodes']))
x_mag = np.asarray(x_mag).reshape(dim)
y_mag = np.asarray(y_mag).reshape(dim)
z_mag = np.asarray(z_mag).reshape(dim)
data = np.stack((x_mag, y_mag, z_mag), axis=-1) * float(header.get('valuemultiplier', 1))
if scale is None:
unit = header.get('meshunit', 'nm')
if unit == 'unspecified':
unit = 'nm'
_log.info(f'unit: {unit}')
unit_scale = {'m': 1e9, 'mm': 1e6, 'µm': 1e3, 'nm': 1}[unit]
xstep = float(header.get('xstepsize')) * unit_scale
ystep = float(header.get('ystepsize')) * unit_scale
zstep = float(header.get('zstepsize')) * unit_scale
scale = (zstep, ystep, xstep)
return Field(data, scale=scale, vector=True)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
assert field.vector and len(field.dim) == 3, 'Only 3D vector fields can be saved to ovf files!'
with open(filename, 'w') as save_file:
save_file.write('# OOMMF OVF 2.0\n')
save_file.write('# Segment count: 1\n')
save_file.write('# Begin: Segment\n')
# Write Header:
save_file.write('# Begin: Header\n')
name = os.path.split(filename)[1]
save_file.write(f'# Title: PYRAMID-VECTORDATA {name}\n')
save_file.write('# meshtype: rectangular\n')
save_file.write('# meshunit: nm\n')
save_file.write('# valueunit: A/m\n')
save_file.write('# valuemultiplier: 1.\n')
save_file.write('# xmin: 0.\n')
save_file.write('# ymin: 0.\n')
save_file.write('# zmin: 0.\n')
save_file.write(f'# xmax: {field.scale[2] * field.dim[2]}\n')
save_file.write(f'# ymax: {field.scale[1] * field.dim[1]}\n')
save_file.write(f'# zmax: {field.scale[0] * field.dim[0]}\n')
save_file.write('# xbase: 0.\n')
save_file.write('# ybase: 0.\n')
save_file.write('# zbase: 0.\n')
save_file.write(f'# xstepsize: {field.scale[2]}\n')
save_file.write(f'# ystepsize: {field.scale[1]}\n')
save_file.write(f'# zstepsize: {field.scale[0]}\n')
save_file.write(f'# xnodes: {field.dim[2]}\n')
save_file.write(f'# ynodes: {field.dim[1]}\n')
save_file.write(f'# znodes: {field.dim[0]}\n')
save_file.write('# End: Header\n')
# Write data:
save_file.write('# Begin: Data Text\n')
x_mag, y_mag, z_mag = field.comp
x_mag = x_mag.data.ravel()
y_mag = y_mag.data.ravel()
z_mag = z_mag.data.ravel()
for i in range(np.prod(field.dim)):
save_file.write(f'{x_mag[i]:g} {y_mag[i]:g} {z_mag[i]:g}\n')
save_file.write('# End: Data Text\n')
save_file.write('# End: Segment\n')
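# Illustrative usage sketch for this plugin (not part of the original module; file names are placeholders
# and `field` is assumed to be a 3D vector Field instance):
#   >>> field = reader('run.ovf')                  # single-segment file, scale taken from the header
#   >>> field = reader('multi.ovf', segment=2)     # pick one segment of a multi-segment file
#   >>> writer('out.ovf', field)                   # always written as OVF 2.0 text with one segment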
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for simple text format."""
import logging
import re
import numpy as np
from ...fields.field import Field
from ...utils.misc import interp_to_regular_grid
_log = logging.getLogger(__name__)
file_extensions = ('.tec',) # Recognised file extensions
def reader(filename, scale=None, vector=None, **kwargs):
assert scale is None or isinstance(scale, tuple), 'The scale must be a tuple, each entry corresponding to one grid dimension!'
_log.debug('Call reader')
if vector is None:
vector = True
assert vector is True, 'Only vector fields can be loaded from tec-files!'
with open(filename, 'r') as mag_file:
lines = mag_file.readlines() # Read in lines!
match = re.search(R'N=(\d+)', lines[2]) # Extract number of points from third line!
if match:
n_points = int(match.group(1))
else:
raise IOError('File does not seem to match .tec format!')
n_head, n_foot = 3, len(lines) - (3 + n_points)
# Read in data:
data_raw = np.genfromtxt(filename, skip_header=n_head, skip_footer=n_foot)
if scale is None:
raise ValueError('For the interpolation of unstructured grids, the `scale` parameter is required!')
data = interp_to_regular_grid(data_raw[:, :3], data_raw[:, 3:], scale, **kwargs)
return Field(data, scale, vector=vector)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
raise NotImplementedError('A writer for this extension is not yet implemented!')
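# Illustrative usage sketch for this plugin (not part of the original module; 'sim.tec' is a placeholder).
# Because .tec files describe an unstructured point cloud, a scale tuple for the target grid is required:
#   >>> field = reader('sim.tec', scale=(2., 2., 2.))   # interpolated onto a regular (z, y, x) grid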
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for simple text format."""
import logging
import ast
import numpy as np
from ...fields.field import Field
_log = logging.getLogger(__name__)
file_extensions = ('.txt',) # Recognised file extensions
def reader(filename, scale=None, vector=None, **kwargs):
_log.debug('Call reader')
if vector is None:
vector = False
assert vector is False, 'Only scalar 2D fields can currently be read with this file reader!'
with open(filename, 'r') as load_file: # Read data:
empyre_format = load_file.readline().startswith('EMPYRE-FORMAT')
if empyre_format: # File has EMPyRe structure:
scale = ast.literal_eval(load_file.readline()[8:-4]) # [8:-4] takes just the scale string!
data = np.loadtxt(filename, delimiter='\t', skiprows=2) # skips header!
else: # Try default with provided kwargs:
scale = 1.0 if scale is None else scale # Set default if not provided!
data = np.loadtxt(filename, **kwargs)
return Field(data, scale=scale, vector=False)
def writer(filename, field, with_header=True, **kwargs):
_log.debug('Call writer')
assert not field.vector, 'Vector fields can currently not be saved to text!'
assert len(field.dim) == 2, 'Only 2D fields can currently be saved to text!'
if with_header: # write header:
with open(filename, 'w') as save_file:
save_file.write('EMPYRE-FORMAT\n')
save_file.write(f'scale = {field.scale} nm\n')
save_kwargs = {'fmt': '%7.6e', 'delimiter': '\t'}
else:
save_kwargs = kwargs
with open(filename, 'ba') as save_file: # the 'a' is for append!
np.savetxt(save_file, field.data, **save_kwargs)
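# Illustrative usage sketch for this plugin (not part of the original module; 'phase.txt' is a placeholder
# and `field` is assumed to be a scalar 2D Field instance):
#   >>> writer('phase.txt', field)                 # writes the EMPYRE-FORMAT header plus the 2D data
#   >>> field = reader('phase.txt')                # scale is parsed back from the header
#   >>> field = reader('plain.txt', scale=5.0)     # files without the header fall back to np.loadtxt kwargs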
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for simple text format."""
import logging
from numbers import Number
import numpy as np
from ...fields.field import Field
from ...utils.misc import interp_to_regular_grid, restrict_points
from ...vis import colors
_log = logging.getLogger(__name__)
file_extensions = ('.vtk',) # Recognised file extensions
def reader(filename, scale=None, vector=None, bounds=None, **kwargs):
"""More infos at:
overview: https://docs.enthought.com/mayavi/mayavi/data.html
writing: https://vtk.org/Wiki/VTK/Writing_VTK_files_using_python
format: https://www.vtk.org/wp-content/uploads/2015/04/file-formats.pdf
"""
_log.debug('Calling reader')
try:
from tvtk.api import tvtk
except ImportError:
_log.error('This extension requires the tvtk package!')
return
if vector is None:
vector = True
# Setting up reader:
reader = tvtk.DataSetReader(file_name=filename, read_all_scalars=True, read_all_vectors=vector)
reader.update()
# Getting output:
output = reader.output
assert output is not None, 'File reader could not find data or file "{}"!'.format(filename)
# Reading points and vectors:
if isinstance(output, tvtk.ImageData): # tvtk.StructuredPoints is a subclass of tvtk.ImageData!
# Connectivity: implicit; described by: 3D data array and spacing along each axis!
_log.info('geometry: ImageData')
# Load relevant information from output (reverse to get typical Python order z,y,x):
dim = output.dimensions[::-1]
origin = output.origin[::-1]
spacing = output.spacing[::-1]
_log.info(f'dim: {dim}, origin: {origin}, spacing: {spacing}')
assert len(dim) == 3, 'Data has to be three-dimensional!'
if scale is None:
scale = tuple(spacing)
if vector: # Extract vector components and create magnitude array:
vector_array = np.asarray(output.point_data.vectors, dtype=float)
x_mag, y_mag, z_mag = vector_array.T
data = np.stack((x_mag.reshape(dim), y_mag.reshape(dim), z_mag.reshape(dim)), axis=-1)
else: # Extract scalar data and create magnitude array:
scalar_array = np.asarray(output.point_data.scalars, dtype=float)
data = scalar_array.reshape(dim)
elif isinstance(output, tvtk.RectilinearGrid):
# Connectivity: implicit; described by: 3D data array and 1D array of spacing for each axis!
_log.info('geometry: RectilinearGrid')
raise NotImplementedError('RectilinearGrid is currently not supported!')
elif isinstance(output, tvtk.StructuredGrid):
# Connectivity: implicit; described by: 3D data array and 3D position arrays for each axis!
_log.info('geometry: StructuredGrid')
raise NotImplementedError('StructuredGrid is currently not supported!')
elif isinstance(output, tvtk.PolyData):
# Connectivity: explicit; described by: x, y, z positions of vertices and arrays of surface cells!
_log.info('geometry: PolyData')
raise NotImplementedError('PolyData is currently not supported!')
elif isinstance(output, tvtk.UnstructuredGrid):
# Connectivity: explicit; described by: x, y, z positions of vertices and arrays of volume cells!
_log.info('geometry: UnstructuredGrid')
# Load relevant information from output:
point_array = np.asarray(output.points, dtype=float)
if vector:
data_array = np.asarray(output.point_data.vectors, dtype=float)
else:
data_array = np.asarray(output.point_data.scalars, dtype=float)
if scale is None:
raise ValueError('For the interpolation of unstructured grids, the `scale` parameter is required!')
elif isinstance(scale, Number): # Scale is the same for each dimension x, y, z!
scale = (scale,) * 3
elif isinstance(scale, tuple):
assert len(scale) == 3, f'Each dimension (z, y, x) needs a scale, but {scale} was given!'
# Crop data to required range, if necessary
if bounds is not None:
_log.info('Restrict data')
point_array, data_array = restrict_points(point_array, data_array, bounds)
data = interp_to_regular_grid(point_array, data_array, scale, **kwargs)
else:
raise TypeError('Data type of {} not understood!'.format(output))
return Field(data, scale, vector=True)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
assert len(field.dim) == 3, 'Currently only 3D fields can be saved to vtk!'
try:
from tvtk.api import tvtk, write_data
except ImportError:
_log.error('This extension requires the tvtk package!')
return
# Create dataset:
origin = (0, 0, 0)
spacing = (field.scale[2], field.scale[1], field.scale[0])
dimensions = (field.dim[2], field.dim[1], field.dim[0])
sp = tvtk.StructuredPoints(origin=origin, spacing=spacing, dimensions=dimensions)
# Fill with data from field:
if field.vector: # Handle vector fields:
# Put vector components in corresponding array:
vectors = field.data.reshape(-1, 3)
sp.point_data.vectors = vectors
sp.point_data.vectors.name = 'vectors'
# Calculate colors:
x_mag, y_mag, z_mag = field.comp
magvec = np.asarray((x_mag.data.ravel(), y_mag.data.ravel(), z_mag.data.ravel()))
cmap = kwargs.pop('cmap', None)
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(magvec)
point_colors = tvtk.UnsignedIntArray()
point_colors.number_of_components = 3
point_colors.name = 'colors'
point_colors.from_array(rgb)
sp.point_data.scalars = point_colors
sp.point_data.scalars.name = 'colors'
else: # Handle scalar fields:
scalars = field.data.ravel()
sp.point_data.scalars = scalars
sp.point_data.scalars.name = 'scalars'
# Write the data to file:
write_data(sp, filename)
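# Illustrative usage sketch for this plugin (not part of the original module; requires the tvtk package,
# uses placeholder file names and assumes `field` is a 3D Field instance):
#   >>> field = reader('sim.vtk')                        # regular ImageData: scale taken from the spacing
#   >>> field = reader('cloud.vtk', scale=(2., 2., 2.))  # unstructured grids need an explicit scale
#   >>> writer('out.vtk', field)                         # saved as StructuredPoints with color scalars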
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO functionality for Field objects."""
import logging
import os
from ..fields.field import Field
from .field_plugins import plugin_list
__all__ = ['load_field', 'save_field']
_log = logging.getLogger(__name__)
def load_field(filename, scale=None, vector=None, **kwargs):
"""Load supported file into a :class:`~..fields.Field` instance.
The function loads the file according to the extension:
SCALAR???
- hdf5 for HDF5. # TODO: You can use comp_pos here!!!
- EMD Electron Microscopy Dataset format (also HDF5).
- npy or npz for numpy formats.
PHASEMAP
- hdf5 for HDF5.
- rpl for Ripple (useful to export to Digital Micrograph).
- dm3 and dm4 for Digital Micrograph files.
- unf for SEMPER unf binary format.
- txt format.
- npy or npz for numpy formats.
- Many image formats such as png, tiff, jpeg...
VECTOR
- hdf5 for HDF5.
- EMD Electron Microscopy Dataset format (also HDF5).
- llg format.
- ovf format.
- npy or npz for numpy formats.
Any extra keyword is passed to the corresponding reader. For available options see their individual documentation.
Parameters
----------
filename: str
The filename to be loaded.
scale: tuple of float, optional
Scaling along the dimensions of the underlying data. Default is 1.
vector: bool, optional
True if the field should be a vector field, False if it should be interpreted as a scalar field (default).
Returns
-------
field: `Field`
A `Field` object containing the loaded data.
Notes
-----
Falls back to HyperSpy routines for loading data; make sure it is installed if you need the full capabilities.
"""
_log.debug('Calling load_field')
extension = os.path.splitext(filename)[1]
for plugin in plugin_list: # Iterate over all plugins:
if extension in plugin.file_extensions: # Check if extension is recognised:
return plugin.reader(filename, scale=scale, vector=vector, **kwargs)
# If nothing was found, try HyperSpy
_log.debug('Using HyperSpy')
try:
import hyperspy.api as hs
except ImportError:
_log.error('This extension requires the hyperspy package!')
return
comp_pos = kwargs.pop('comp_pos', -1)
return Field.from_signal(hs.load(filename, **kwargs), scale=scale, vector=vector, comp_pos=comp_pos)
def save_field(filename, field, **kwargs):
"""Saves the Field in the specified format.
The function gets the format from the extension:
- hdf5 for HDF5.
- EMD Electron Microscopy Dataset format (also HDF5).
- npy or npz for numpy formats.
If no extension is provided, 'hdf5' is used. Most formats are saved with the HyperSpy package (internally the field
is first converted to a HyperSpy Signal).
Each format accepts a different set of parameters. For details see the specific format documentation.
Parameters
----------
filename : str
Name of the file which the Field is saved into. The extension determines the saving procedure.
field : Field
The field object which is saved into the file.
"""
"""Saves the phasemap in the specified format.
The function gets the format from the extension:
- hdf5 for HDF5.
- rpl for Ripple (useful to export to Digital Micrograph).
- unf for SEMPER unf binary format.
- txt format.
- Many image formats such as png, tiff, jpeg...
If no extension is provided, 'hdf5' is used. Most formats are
saved with the HyperSpy package (internally the phasemap is first
converted to a HyperSpy Signal.
Each format accepts a different set of parameters. For details
see the specific format documentation.
Parameters
----------
filename: str, optional
Name of the file which the phasemap is saved into. The extension
determines the saving procedure.
save_mask: boolean, optional
If True, the `mask` is saved, too. For all formats, except HDF5, a separate file will
be created. HDF5 always saves the `mask` in the metadata, independent of this flag. The
default is False.
save_conf: boolean, optional
If True, the `confidence` is saved, too. For all formats, except HDF5, a separate file
will be created. HDF5 always saves the `confidence` in the metadata, independent of
this flag. The default is False
pyramid_format: boolean, optional
Only used for saving to '.txt' files. If this is True, the grid spacing is saved
in an appropriate header. Otherwise just the phase is written with the
corresponding `kwargs`.
"""
_log.debug('Calling save_field')
extension = os.path.splitext(filename)[1]
for plugin in plugin_list: # Iterate over all plugins:
if extension in plugin.file_extensions: # Check if extension is recognised:
plugin.writer(filename, field, **kwargs)
return
# If nothing was found, try HyperSpy:
_log.debug('Using HyperSpy')
field.to_signal().save(filename, **kwargs)
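# Illustrative usage sketch for the two dispatch functions above (not part of the original module;
# file names are placeholders):
#   >>> field = load_field('magnetization.ovf')      # extension picks the ovf plugin
#   >>> save_field('magnetization.npy', field)       # extension picks the numpy plugin
#   >>> field = load_field('image.tif', scale=2.0)   # unknown extension: HyperSpy fallback (if installed)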
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Subpackage containing utility functionality."""
from .quaternion import *
__all__ = []
__all__.extend(quaternion.__all__)
del quaternion
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides the miscellaneous helper functions."""
import logging
import itertools
from time import time
import numpy as np
from tqdm import tqdm
from scipy.spatial import cKDTree, qhull
from scipy.interpolate import LinearNDInterpolator
__all__ = ['levi_civita', 'interp_to_regular_grid', 'restrict_points']
_log = logging.getLogger(__name__)
def levi_civita(i, j, k):
_log.debug('Calling levi_civita')
return (j-i) * (k-i) * (k-j) / (np.abs(j-i) * np.abs(k-i) * np.abs(k-j))
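# Worked example for levi_civita (illustrative, not part of the original module). The symbol is +1 for
# even permutations of (0, 1, 2) and -1 for odd ones; note that this implementation divides by zero if
# two indices coincide, so it is only meant for pairwise-distinct indices:
#   >>> levi_civita(0, 1, 2)
#   1.0
#   >>> levi_civita(1, 0, 2)
#   -1.0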
def interp_to_regular_grid(points, values, scale, scale_factor=1, step=1, convex=True, distance_upper_bound=None):
"""Interpolate values on points to regular grid.
Parameters
----------
points : np.ndarray, (N, 3)
Array of points, describing the location of the values that should be interpolated. Three columns x, y, z!
values : np.ndarray, (N, c)
Array of values that should be interpolated to the new grid. `c` is the number of components (`1` for scalar
fields, `3` for normal 3D vector fields).
scale : tuple of 3 ints
Scale along each of the 3 spatial dimensions. Usually given in nm.
scale_factor : float, optional
Additional scaling factor that should be used if the original points are not described on a nm-scale. Use this
to convert from nm of `scale` to the unit of the points. By default 1.
step : int, optional
If this is bigger than 1 (the default is 1), only every `step`-th point is taken into account. This can speed up
the calculation, but you'll lose accuracy in your interpolation.
convex : bool, optional
By default True. If this is set to False, additional measures are taken to find holes in the point cloud.
WARNING: this is an experimental feature that should be used with caution!
distance_upper_bound: float, optional
Only used if `convex=False`. Sets the upper bound that determines if a point of the new (interpolated) grid is
too far away from any original point. Such points are assumed to lie in a "hole" and their values are set to
zero. Set this value in nm; it is converted to the local unit of the original points internally. If not set and
`convex=False`, twice the mean of `scale` is used (which can be troublesome if the scales vary drastically).
Returns
-------
interpolation: np.ndarray
Interpolated grid with shape `(zdim, ydim, xdim)` for scalar and `(zdim, ydim, xdim, ncomp)` for vector field.
"""
_log.debug('Calling interp_to_regular_grid')
z_uniq = np.unique(points[:, 2])
_log.info(f'unique positions along z: {len(z_uniq)}')
# Translate scale to local units (not necessarily nm), taken care of by `scale_factor`:
scale = tuple([s * scale_factor for s in scale])
# Determine the size of the point cloud of irregular coordinates:
x_min, x_max = points[:, 0].min(), points[:, 0].max()
y_min, y_max = points[:, 1].min(), points[:, 1].max()
z_min, z_max = points[:, 2].min(), points[:, 2].max()
x_diff, y_diff, z_diff = np.ptp(points[:, 0]), np.ptp(points[:, 1]), np.ptp(points[:, 2])
_log.info(f'x-range: {x_min:.2g} <-> {x_max:.2g} ({x_diff:.2g})')
_log.info(f'y-range: {y_min:.2g} <-> {y_max:.2g} ({y_diff:.2g})')
_log.info(f'z-range: {z_min:.2g} <-> {z_max:.2g} ({z_diff:.2g})')
# Determine dimensions from given grid spacing a:
dim = tuple(np.asarray(np.round((z_diff/scale[0], y_diff/scale[1], x_diff/scale[2])), dtype=int))
assert all(d > 0 for d in dim), f'All dimensions of dim={dim} need to be > 0, please adjust the scale accordingly!'
z = z_min + scale[0] * (np.arange(dim[0]) + 0.5) # +0.5: shift to pixel center!
y = y_min + scale[1] * (np.arange(dim[1]) + 0.5) # +0.5: shift to pixel center!
x = x_min + scale[2] * (np.arange(dim[2]) + 0.5) # +0.5: shift to pixel center!
# Create points for new Euclidean grid; fliplr for (x, y, z) order:
points_euc = np.fliplr(np.asarray(list(itertools.product(z, y, x))))
# Make values 2D (if not already); double .T so that a new axis is added at the END (n, 1):
values = np.atleast_2d(values.T).T
# Prepare interpolated grid:
interpolation = np.empty(dim+(values.shape[-1],), dtype=float)
_log.info(f'Dimensions of new grid: {interpolation.shape}')
# Calculate the Delaunay triangulation (same for every component of multidim./vector fields):
_log.info('Start Delaunay triangulation...')
tick = time()
triangulation = qhull.Delaunay(points[::step])
tock = time()
_log.info(f'Delaunay triangulation complete (took {tock-tick:.2f} s)!')
# Perform the interpolation for each column of `values`:
for i in tqdm(range(values.shape[-1])):
# Create interpolator for the given triangulation and the values of the current column:
interpolator = LinearNDInterpolator(triangulation, values[::step, i], fill_value=0)
# Interpolate:
interpolation[..., i] = interpolator(points_euc).reshape(dim)
# If NOT convex, we have to check for additional holes in the structure (EXPERIMENTAL):
if not convex: # Only necessary if the user expects holes in the (-> nonconvex) distribution:
# Create k-dimensional tree for queries:
_log.info('Create cKDTree...')
tick = time()
tree = cKDTree(points)
tock = time()
_log.info(f'cKDTree creation complete (took {tock-tick:.2f} s)!')
# Query the tree for nearest neighbors, x: points to query, k: number of neighbors, p: norm
# to use (here: 2 - Euclidean), distance_upper_bound: maximum distance that is searched!
if distance_upper_bound is None: # Take the mean of the scale for the upper bound:
distance_upper_bound = 2 * np.mean(scale) # NOTE: could be problematic for wildly varying scale numbers.
else: # Convert to local scale:
distance_upper_bound *= scale_factor
_log.info('Start cKDTree neighbour query...')
tick = time()
distances, _ = tree.query(x=points_euc, k=1, p=2, distance_upper_bound=distance_upper_bound)
tock = time()
_log.info(f'cKDTree neighbour query complete (took {tock-tick:.2f} s)!')
# Create boolean mask that determines which interpolation points have no neighbor near enough:
mask = np.isinf(distances).reshape(dim) # Points further away than upper bound were marked 'inf'!
_log.info(f'{np.sum(mask)} of {points_euc.shape[0]} points were assumed to be in holes of the point cloud!')
# Set these points to zero (NOTE: This can take a looooong time):
interpolation[mask, :] = 0
return np.squeeze(interpolation)
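# Minimal usage sketch for interp_to_regular_grid (illustrative, not part of the original module):
#   >>> rng = np.random.default_rng(seed=42)
#   >>> points = rng.uniform(0, 100, size=(1000, 3))   # irregular x, y, z positions in nm
#   >>> values = rng.uniform(-1, 1, size=(1000, 3))    # one 3-component vector per point
#   >>> grid = interp_to_regular_grid(points, values, scale=(5., 5., 5.))
#   >>> grid.shape                                     # roughly (20, 20, 20, 3)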
def restrict_points(point_array, data_array, bounds):
"""Restrict range of point_array and data_array
Parameters
----------
point_array : np.ndarray, (N, 3)
Array of points, describing the location of the values that should be interpolated. Three columns x, y, z!
data_array : np.ndarray, (N, c)
Array of values that should be interpolated to the new grid. `c` is the number of components (`1` for scalar
fields, `3` for normal 3D vector fields).
bounds : tuple of 6 values
Restrict data range to given bounds, x0, x1, y0, y1, z0, z1.
Returns
-------
point_restricted: np.ndarray
Cut out of the array of points inside the bounds, describing the location of the values that should be
interpolated. Three columns x, y, z!
data_restricted: np.ndarray
Cut-out of the array of values whose corresponding points lie inside the bounds. Same number of columns as
`data_array`!
"""
point_restricted = []
data_restricted = []
for i, pos in enumerate(point_array):
if bounds[0] <= pos[0] <= bounds[1]:
if bounds[2] <= pos[1] <= bounds[3]:
if bounds[4] <= pos[2] <= bounds[5]:
point_restricted.append(pos)
data_restricted.append(data_array[i])
point_restricted = np.array(point_restricted)
data_restricted = np.array(data_restricted)
return point_restricted, data_restricted
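# Minimal usage sketch for restrict_points (illustrative, not part of the original module; the same
# selection could also be expressed with a single boolean mask, the loop above just keeps it readable):
#   >>> bounds = (0., 50., 0., 50., 0., 50.)             # x0, x1, y0, y1, z0, z1
#   >>> points_cut, values_cut = restrict_points(points, values, bounds)
#   >>> grid = interp_to_regular_grid(points_cut, values_cut, scale=(5., 5., 5.))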
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides the :class:`~.Quaternion` class which can be used for rotations."""
import logging
import numpy as np
__all__ = ['Quaternion']
class Quaternion(object):
R"""Class representing a rotation expressed by a quaternion.
A quaternion is a four-dimensional description of a rotation which can also be described by
a rotation vector (`v1`, `v2`, `v3`) and a rotation angle :math:`\theta`. The four components
are calculated to:
.. math::
w = \cos(\theta/2)
x = v_1 \cdot \sin(\theta/2)
y = v_2 \cdot \sin(\theta/2)
z = v_3 \cdot \sin(\theta/2)
Use :func:`~.from_axisangle` and :func:`~.to_axisangle` to convert from and to the axis-angle
representation. Quaternions can be multiplied with other quaternions, which results in a new
rotation, or with a vector, which results in a rotated vector.
Attributes
----------
values : tuple of 4 floats
The four quaternion values `w`, `x`, `y`, `z`.
"""
NORM_TOLERANCE = 1E-6
_log = logging.getLogger(__name__ + '.Quaternion')
@property
def conj(self):
"""The conjugate of the quaternion, representing a tilt in opposite direction."""
w, x, y, z = self.values
return Quaternion((w, -x, -y, -z))
@property
def matrix(self):
"""The rotation matrix representation of the quaternion."""
w, x, y, z = self.values
return np.array([[1 - 2 * (y ** 2 + z ** 2), 2 * (x * y - w * z), 2 * (x * z + w * y)],
[2 * (x * y + w * z), 1 - 2 * (x ** 2 + z ** 2), 2 * (y * z - w * x)],
[2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x ** 2 + y ** 2)]])
def __init__(self, values):
self._log.debug('Calling __init__')
self.values = values
self._normalize()
self._log.debug('Created ' + str(self))
def __mul__(self, other): # self * other
self._log.debug('Calling __mul__')
if isinstance(other, Quaternion): # Quaternion multiplication
return self.dot_quat(self, other)
elif len(other) == 3: # vector multiplication (Caution: normalises!)
q_vec = Quaternion((0,) + tuple(other))
q = self.dot_quat(self.dot_quat(self, q_vec), self.conj)
return q.values[1:]
def _normalize(self):
self._log.debug('Calling _normalize')
mag2 = np.sum([n ** 2 for n in self.values])
if abs(mag2 - 1.0) > self.NORM_TOLERANCE:
mag = np.sqrt(mag2)
self.values = tuple(n / mag for n in self.values)
def dot_quat(self, q1, q2):
"""Multiply two :class:`~.Quaternion` objects to create a new one (always normalized).
Parameters
----------
q1, q2 : :class:`~.Quaternion`
The quaternions which should be multiplied.
Returns
-------
quaternion : :class:`~.Quaternion`
The resulting quaternion.
"""
self._log.debug('Calling dot_quat')
w1, x1, y1, z1 = q1.values
w2, x2, y2, z2 = q2.values
w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
return Quaternion((w, x, y, z))
@classmethod
def from_axisangle(cls, vector, theta):
"""Create a quaternion from an axis-angle representation
Parameters
----------
vector : :class:`~numpy.ndarray` (N=3)
Vector around which the rotation is executed.
theta : float
Rotation angle.
Returns
-------
quaternion : :class:`~.Quaternion`
The resulting quaternion.
"""
cls._log.debug('Calling from_axisangle')
x, y, z = vector
theta /= 2.
w = np.cos(theta)
x *= np.sin(theta)
y *= np.sin(theta)
z *= np.sin(theta)
return cls((w, x, y, z))
def to_axisangle(self):
"""Convert the quaternion to axis-angle-representation.
Returns
-------
vector, theta : :class:`~numpy.ndarray` (N=3), float
Vector around which the rotation is executed and rotation angle.
"""
self._log.debug('Calling to_axisangle')
w, x, y, z = self.values
theta = 2.0 * np.arccos(w)
return np.array((x, y, z)), theta
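# Minimal usage sketch for the Quaternion class (illustrative, not part of the original module):
#   >>> import numpy as np
#   >>> quat = Quaternion.from_axisangle((0, 0, 1), np.pi / 2)   # 90 degree rotation around the z-axis
#   >>> np.allclose(quat * (1, 0, 0), (0, 1, 0))                 # rotates the x-axis onto the y-axis
#   True
#   >>> np.allclose(quat.matrix @ np.array((1, 0, 0)), (0, 1, 0))
#   True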
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Subpackage containing functionality for visualisation of multidimensional fields."""
from . import colors
from .plot2d import *
from .decorators import *
from .tools import *
from .plot3d import *
__all__ = ['colors']
__all__.extend(plot2d.__all__)
__all__.extend(decorators.__all__)
__all__.extend(tools.__all__)
__all__.extend(plot3d.__all__)
del plot2d
del decorators
del tools
del plot3d
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides a number of custom colormaps, which also have capabilities for 3D plotting.
If this is the case, the :class:`~.Colormap3D` colormap class is a parent class. In `cmaps`, a
number of specialised colormaps is available for convenience.
For general questions about colors see:
http://www.poynton.com/PDFs/GammaFAQ.pdf
http://www.poynton.com/PDFs/ColorFAQ.pdf
"""
import logging
import colorsys
import abc
import numpy as np
from PIL import Image
from matplotlib import colors
import matplotlib.pyplot as plt
from matplotlib.ticker import FixedLocator
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
from matplotlib.patches import Circle
from .tools import use_style
__all__ = ['Colormap3D', 'ColormapCubehelix', 'ColormapPerception', 'ColormapHLS', 'ColormapClassic',
'ColormapTransparent', 'cmaps', 'interpolate_color']
_log = logging.getLogger(__name__)
class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta):
"""Colormap subclass for encoding directions with colors.
This abstract class is used as a superclass/interface for 3D vector plotting capabilities.
In general, a circular colormap should be used to encode the in-plane angle (hue). The
perpendicular angle is encoded via luminance variation (up: white, down: black). Finally,
the length of a vector is encoded via saturation. Decreasing vector length causes a desaturated
color. Subclassing colormaps get access to routines to plot a colorwheel (which should
ideally be located in the 50% luminance plane, which depends strongly on the underlying map),
a convenience function to interpolate color tuples and a function to return rgb triples for a
given vector. The :class:`~.Colormap3D` class itself subclasses the matplotlib base colormap.
"""
_log = logging.getLogger(__name__ + '.Colormap3D')
def rgb_from_vector(self, vector, vmax=None):
"""Construct a hls tuple from three coordinates representing a 3D direction.
Parameters
----------
vector: tuple (N=3) or :class:`~numpy.ndarray`
Vector containing the x, y and z component, or a numpy array encompassing the
components as three lists.
vmax: float, optional
Maximum vector length used for normalisation. Defaults to the maximum length found in `vector`.
Returns
-------
rgb: :class:`~numpy.ndarray`
Numpy array containing the calculated color tuples.
"""
self._log.debug('Calling rgb_from_vector')
x, y, z = vector
R = np.sqrt(x ** 2 + y ** 2 + z ** 2)
R_max = vmax if vmax is not None else R.max() + 1E-30
# FIRST color dimension: HUE (1D ring/angular direction)
phi = np.asarray(np.arctan2(y, x))
phi[phi < 0] += 2 * np.pi
hue = phi / (2 * np.pi)
rgba = np.asarray(self(hue))
r, g, b = rgba[..., 0], rgba[..., 1], rgba[..., 2]
# SECOND color dimension: SATURATION (2D, in-plane)
rho = np.sqrt(x ** 2 + y ** 2)
sat = rho / R_max
r, g, b = interpolate_color(sat, (0.5, 0.5, 0.5), np.stack((r, g, b), axis=-1))
# THIRD color dimension: LUMINANCE (3D, color sphere)
theta = np.arccos(z / R_max)
lum = 1 - theta / np.pi # goes from 0 (black) over 0.5 (grey) to 1 (white)!
lum_target = np.where(lum < 0.5, 0, 1) # Separate upper(white)/lower(black) hemispheres!
lum_target = np.stack([lum_target] * 3, axis=-1) # [0, 0, 0] -> black / [1, 1, 1] -> white!
fraction = 2 * np.abs(lum - 0.5) # 0.5: difference from grey, 2: scale to range (0, 1)!
r, g, b = interpolate_color(fraction, np.stack((r, g, b), axis=-1), lum_target)
# Return RGB:
return np.asarray(255 * np.stack((r, g, b), axis=-1), dtype=np.uint8)
def make_colorwheel(self, size=64):
"""Construct a color wheel as an :class:`~PIL.Image` object.
Parameters
----------
size : int, optional
Diameter of the color wheel along both axes in pixels, by default 64.
Returns
-------
img: :class:`~PIL.Image`
The resulting image.
"""
self._log.debug('Calling make_colorwheel')
# Construct the colorwheel:
yy, xx = (np.indices((size, size)) - size/2 + 0.5)
rr = np.hypot(xx, yy)
xx = np.where(rr <= size/2-3, xx, 0)
yy = np.where(rr <= size/2-3, yy, 0)
zz = np.zeros((size, size))
aa = np.where(rr >= size/2-3, 0, 255).astype(dtype=np.uint8)
rgba = np.dstack((self.rgb_from_vector(np.asarray((xx, yy, zz))), aa))
# Create color wheel:
return Image.fromarray(rgba)
def plot_colorwheel(self, axis=None, size=64, arrows=True, grayscale=False, **kwargs):
"""Display a color wheel to illustrate the color coding of vector gradient directions.
Parameters
----------
axis : :class:`~matplotlib.axes.Axes`, optional
Axis to which the colorwheel is added; if None, a new figure and axis are created.
size : int, optional
Diameter of the color wheel along both axes in pixels, by default 64.
arrows : bool, optional
If True (the default), arrows indicating the in-plane directions are added to the wheel.
grayscale : bool, optional
If True, the colorwheel is converted to grayscale (luminance with alpha), by default False.
Returns
-------
img: :class:`matplotlib.image.AxesImage`
The resulting colorwheel.
"""
self._log.debug('Calling plot_colorwheel')
# Construct the colorwheel:
color_wheel = self.make_colorwheel(size=size)
if grayscale:
color_wheel = color_wheel.convert('LA')
# Plot the color wheel:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
fig = plt.figure()
axis = fig.add_subplot(1, 1, 1, aspect='equal')
# Plot:
im = axis.imshow(color_wheel, **kwargs)
xy = size/2 - 0.5, size/2 - 0.5
axis.add_patch(Circle(xy=xy, radius=size/2-2.5, linewidth=2, edgecolor='k', facecolor='none'))
if arrows:
axis.arrow(size/2, size/2+5, 0, 0.1*size, fc='k', ec='k', lw=1, width=2, alpha=0.15)
axis.arrow(size/2, size/2-5, 0, -0.1*size, fc='k', ec='k', lw=1, width=2, alpha=0.15)
axis.arrow(size/2+5, size/2, 0.1*size, 0, fc='k', ec='k', lw=1, width=2, alpha=0.15)
axis.arrow(size/2-5, size/2, -0.1*size, 0, fc='k', ec='k', lw=1, width=2, alpha=0.15)
return im
class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D):
"""A full implementation of Dave Green's "cubehelix" for Matplotlib.
Based on the FORTRAN 77 code provided in D.A. Green, 2011, BASI, 39, 289.
http://adsabs.harvard.edu/abs/2011arXiv1108.5083G
Also see:
http://www.mrao.cam.ac.uk/~dag/CUBEHELIX/
http://davidjohnstone.net/pages/cubehelix-gradient-picker
User can adjust all parameters of the cubehelix algorithm. This enables much greater
flexibility in choosing color maps. Default color map settings produce the standard cubehelix.
Create color map in only blues by setting rot=0 and start=0. Create reverse (white to black)
backwards through the rainbow once by setting rot=1 and reverse=True, etc. Furthermore, the
algorithm was tuned, so that constant luminance values can be used (e.g. to create a truly
isoluminant colorwheel). The `rot` parameter is also tuned to hold true for these cases.
Of the here presented colorwheels, only this one manages to solely navigate through the L*=50
plane, which can be seen here:
https://upload.wikimedia.org/wikipedia/commons/2/21/Lab_color_space.png
Parameters
----------
start : scalar, optional
Sets the starting position in the color space. 0=blue, 1=red,
2=green. Defaults to 0.5.
rot : scalar, optional
The number of rotations through the rainbow. Can be positive
or negative, indicating direction of rainbow. Negative values
correspond to Blue->Red direction. Defaults to -1.5.
gamma : scalar, optional
The gamma correction for intensity. Defaults to 1.0.
reverse : boolean, optional
Set to True to reverse the color map. Will go from black to
white. Good for density plots where shade~density. Defaults to False.
nlev : scalar, optional
Defines the number of discrete levels to render colors at.
Defaults to 256.
sat : scalar, optional
The saturation intensity factor. Defaults to 1.2
NOTE: this was formerly known as `hue` parameter
minSat : scalar, optional
Sets the minimum-level saturation. Defaults to 1.2.
maxSat : scalar, optional
Sets the maximum-level saturation. Defaults to 1.2.
startHue : scalar, optional
Sets the starting color, ranging from [0, 360], as in
D3 version by @mbostock.
NOTE: overrides values in start parameter.
endHue : scalar, optional
Sets the ending color, ranging from [0, 360], as in
D3 version by @mbostock
NOTE: overrides values in rot parameter.
minLight : scalar, optional
Sets the minimum lightness value. Defaults to 0.
maxLight : scalar, optional
Sets the maximum lightness value. Defaults to 1.
Returns
-------
matplotlib.colors.LinearSegmentedColormap object
Revisions
---------
2014-04 (@jradavenport) Ported from IDL version
2014-04 (@jradavenport) Added kwargs to enable similar to D3 version,
changed name of `hue` parameter to `sat`.
2016-11 (@jan.caron) Added support for isoluminant cubehelices while making sure
`rot` works as intended. Decoded the plane-vectors a bit.
"""
_log = logging.getLogger(__name__ + '.ColormapCubehelix')
def __init__(self, start=0.5, rot=-1.5, gamma=1.0, reverse=False, nlev=256,
minSat=1.2, maxSat=1.2, minLight=0., maxLight=1., **kwargs):
self._log.debug('Calling __init__')
# Override start and rot if startHue and endHue are set:
if kwargs:
if 'startHue' in kwargs:
start = (kwargs.get('startHue') / 360. - 1.) * 3.
if 'endHue' in kwargs:
rot = kwargs.get('endHue') / 360. - start / 3. - 1.
if 'sat' in kwargs:
minSat = kwargs.get('sat')
maxSat = kwargs.get('sat')
self.nlev = nlev
# Set up the parameters:
self.fract = np.linspace(minLight, maxLight, nlev)
angle = 2.0 * np.pi * (start / 3.0 + rot * np.linspace(0, 1, nlev))
self.fract = self.fract**gamma
satar = np.linspace(minSat, maxSat, nlev)
amp = np.asarray(satar * self.fract * (1. - self.fract) / 2)
# Set RGB color coefficients (Luma is calculated in RGB Rec.601, so choose those); the original version of
# Dave Green used (0.30, 0.59, 0.11) and Rec.709 would be c709 = (0.2126, 0.7152, 0.0722), but with either of
# those, this function would not produce the correct YPbPr Luma.
c601 = (0.299, 0.587, 0.114)
cr, cg, cb = c601
cw = -0.90649 # Chosen to comply with Dave Green's implementation.
k = -1.6158 / cr / cw # k has to balance out cw so nothing gets out of RGB gamut (> 1).
# Calculate the vectors v and w spanning the plane of constant perceived intensity. v and w have to solve
# v x w = k(cr, cg, cb) (normal vector of the described plane) and
# v * w = 0 (scalar product, v and w have to be perpendicular).
# 6 unknowns and 4 equations --> choose wb = 0 and wg = cw (constant).
v = np.array((k * cr ** 2 * cb / (cw * (cr ** 2 + cg ** 2)),
k * cr * cg * cb / (cw * (cr ** 2 + cg ** 2)), -k * cr / cw))
w = np.array((-cw * cg / cr, cw, 0))
# Calculate components:
self.red = self.fract + amp * (v[0] * np.cos(angle) + w[0] * np.sin(angle))
self.grn = self.fract + amp * (v[1] * np.cos(angle) + w[1] * np.sin(angle))
self.blu = self.fract + amp * (v[2] * np.cos(angle) + w[2] * np.sin(angle))
# Original formulas with original v and w:
# self.red = self.fract + amp * (-0.14861 * np.cos(angle) + 1.78277 * np.sin(angle))
# self.grn = self.fract + amp * (-0.29227 * np.cos(angle) - 0.90649 * np.sin(angle))
# self.blu = self.fract + amp * (1.97294 * np.cos(angle))
# Find where RGB values are outside the range [0, 1] and clip:
self.red = np.clip(self.red, 0, 1)
self.grn = np.clip(self.grn, 0, 1)
self.blu = np.clip(self.blu, 0, 1)
# Optional color reverse:
if reverse is True:
self.red = self.red[::-1]
self.blu = self.blu[::-1]
self.grn = self.grn[::-1]
# Put into the tuple & dictionary structures needed:
rr, bb, gg = [], [], []
for k in range(0, int(nlev)):
rr.append((float(k) / (nlev - 1), self.red[k], self.red[k]))
bb.append((float(k) / (nlev - 1), self.blu[k], self.blu[k]))
gg.append((float(k) / (nlev - 1), self.grn[k], self.grn[k]))
cdict = {'red': rr, 'blue': bb, 'green': gg}
super().__init__('cubehelix', cdict, N=256)
self._log.debug('Created ' + str(self))
def plot_helix(self, figsize=None, **kwargs):
"""Display the RGB and luminance plots for the chosen cubehelix.
Parameters
----------
figsize : tuple of floats (N=2)
Size of the plot figure.
Returns
-------
None
"""
self._log.debug('Calling plot_helix')
with use_style('empyre-plot'):
fig = plt.figure(figsize=figsize, constrained_layout=True)
gs = fig.add_gridspec(2, 1, height_ratios=[8, 1])
# Main plot:
axis = plt.subplot(gs[0])
axis.plot(self.fract, 'k')
axis.plot(self.red, 'r')
axis.plot(self.grn, 'g')
axis.plot(self.blu, 'b')
axis.set_xlim(0, self.nlev)
axis.set_ylim(0, 1)
axis.set_title('Cubehelix')
axis.set_xlabel('Color index')
axis.set_ylabel('Brightness / RGB')
axis.xaxis.set_major_locator(FixedLocator(locs=np.linspace(0, self.nlev, 5)))
axis.yaxis.set_major_locator(FixedLocator(locs=[0, 0.5, 1]))
# Colorbar horizontal:
caxis = plt.subplot(gs[1])
rgb = self(np.linspace(0, 1, 256))[None, ...]
rgb = np.asarray(255.9999 * rgb, dtype=np.uint8)
rgb = np.repeat(rgb, 30, axis=0)
im = Image.fromarray(rgb)
caxis.imshow(im, aspect='auto')
caxis.tick_params(axis='both', which='both', labelleft=False, labelbottom=False,
left=False, right=False, top=False, bottom=False)
class ColormapPerception(colors.LinearSegmentedColormap, Colormap3D):
"""A perceptual colormap based on face-based luminance matching.
Based on a publication by Kindlmann et al.
http://www.cs.utah.edu/~gk/papers/vis02/FaceLumin.pdf
This colormap tries to achieve an isoluminant perception by using a list of colors acquired
through face recognition studies. It is a lot better than the HLS colormap, but still not
completely isoluminant (despite its name). Also it appears a bit dark.
"""
_log = logging.getLogger(__name__ + '.ColormapPerception')
CDICT = {'red': [(0/6, 0.847, 0.847),
(1/6, 0.527, 0.527),
(2/6, 0.000, 0.000),
(3/6, 0.000, 0.000),
(4/6, 0.316, 0.316),
(5/6, 0.718, 0.718),
(6/6, 0.847, 0.847)],
'green': [(0/6, 0.057, 0.057),
(1/6, 0.527, 0.527),
(2/6, 0.592, 0.592),
(3/6, 0.559, 0.559),
(4/6, 0.316, 0.316),
(5/6, 0.000, 0.000),
(6/6, 0.057, 0.057)],
'blue': [(0/6, 0.057, 0.057),
(1/6, 0.000, 0.000),
(2/6, 0.000, 0.000),
(3/6, 0.559, 0.559),
(4/6, 0.991, 0.991),
(5/6, 0.718, 0.718),
(6/6, 0.057, 0.057)]}
def __init__(self):
self._log.debug('Calling __init__')
super().__init__('perception', self.CDICT, N=256)
self._log.debug('Created ' + str(self))
class ColormapHLS(colors.ListedColormap, Colormap3D):
"""Colormap subclass for encoding directions with colors.
This class is a subclass of the :class:`~matplotlib.pyplot.colors.ListedColormap`
class. The class follows the HSL ('hue', 'saturation', 'lightness') 'Double Hexcone' Model
with the saturation always set to 1 (moving on the surface of the color
cylinder) with a lightness of 0.5 (full color). The three prime colors (`rgb`) are spaced
equidistant with 120° space in between, according to a triadic arrangement.
Even though the lightness is constant in the plane, the luminance (a weighted sum of the RGB
components that models human perception) is not, which can lead to artifacts like reliefs.
Converting the map to grayscale shows spokes at the secondary colors.
For more information see:
https://vis4.net/blog/posts/avoid-equidistant-hsv-colors/
http://www.workwithcolor.com/color-luminance-2233.htm
http://blog.asmartbear.com/color-wheels.html
"""
_log = logging.getLogger(__name__ + '.ColormapHLS')
def __init__(self):
self._log.debug('Calling __init__')
h = np.linspace(0, 1, 256)
l = 0.5 * np.ones_like(h)
s = np.ones_like(h)
r, g, b = np.vectorize(colorsys.hls_to_rgb)(h, l, s)
colors = [(r[i], g[i], b[i]) for i in range(len(r))]
super().__init__(colors, 'hls', N=256)
self._log.debug('Created ' + str(self))
class ColormapClassic(colors.LinearSegmentedColormap, Colormap3D):
"""Colormap subclass for encoding directions with colors.
This class is a subclass of the :class:`~matplotlib.pyplot.colors.LinearSegmentedColormap`
class. The class follows the HSL ('hue', 'saturation', 'lightness') 'Double
Hexcone' Model with the saturation always set to 1 (moving on the surface of the color
cylinder) with a luminance of 0.5 (full color). The colors follow a tetradic arrangement with
four colors (red, green, blue and yellow) arranged with 90° spacing in between.
"""
_log = logging.getLogger(__name__ + '.ColormapClassic')
CDICT = {'red': [(0/4, 1.0, 1.0),
(1/4, 0.0, 0.0),
(2/4, 0.0, 0.0),
(3/4, 1.0, 1.0),
(4/4, 1.0, 1.0)],
'green': [(0/4, 0.0, 0.0),
(1/4, 0.0, 0.0),
(2/4, 1.0, 1.0),
(3/4, 1.0, 1.0),
(4/4, 0.0, 0.0)],
'blue': [(0/4, 0.0, 0.0),
(1/4, 1.0, 1.0),
(2/4, 0.0, 0.0),
(3/4, 0.0, 0.0),
(4/4, 0.0, 0.0)]}
def __init__(self):
self._log.debug('Calling __init__')
super().__init__('classic', self.CDICT, N=256)
self._log.debug('Created ' + str(self))
class ColormapTransparent(colors.LinearSegmentedColormap):
"""Colormap subclass for including transparency.
This class is a subclass of the :class:`~matplotlib.pyplot.colors.LinearSegmentedColormap`
class with integrated support for transparency. The colormap is unicolor and varies only in
transparency.
Attributes
----------
r: float, optional
Intensity of red in the colormap. Has to be between 0. and 1.
g: float, optional
Intensity of green in the colormap. Has to be between 0. and 1.
b: float, optional
Intensity of blue in the colormap. Has to be between 0. and 1.
alpha_range : list (N=2) of float, optional
Start and end alpha value. Has to be between 0. and 1.
"""
_log = logging.getLogger(__name__ + '.ColormapTransparent')
def __init__(self, r=0., g=0., b=0., alpha_range=None):
self._log.debug('Calling __init__')
if alpha_range is None:
alpha_range = [0., 1.]
red = [(0., 0., r), (1., r, 1.)]
green = [(0., 0., g), (1., g, 1.)]
blue = [(0., 0., b), (1., b, 1.)]
alpha = [(0., 0., alpha_range[0]), (1., alpha_range[1], 1.)]
cdict = {'red': red, 'green': green, 'blue': blue, 'alpha': alpha}
super().__init__('transparent', cdict, N=256)
self._log.debug('Created ' + str(self))
def interpolate_color(fraction, start, end):
"""Interpolate linearly between two color tuples (e.g. RGB).
Parameters
----------
fraction: float or :class:`~numpy.ndarray`
Interpolation fraction between 0 and 1, which determines the position of the
interpolation between `start` and `end`.
start: tuple (N=3) or :class:`~numpy.ndarray`
Start of the interpolation as a tuple of three numbers or a numpy array, where the last
dimension should have length 3 and contain the color tuples.
end: tuple (N=3) or :class:`~numpy.ndarray`
End of the interpolation as a tuple of three numbers or a numpy array, where the last
dimension should have length 3 and contain the color tuples.
Returns
-------
result: tuple (N=3) or :class:`~numpy.ndarray`
Result of the interpolation as a tuple of three numbers or a numpy array, where the
last dimension has length 3 and contains the color tuples.
"""
_log.debug('Calling interpolate_color')
start, end = np.asarray(start), np.asarray(end)
r1 = start[..., 0] + (end[..., 0] - start[..., 0]) * fraction
r2 = start[..., 1] + (end[..., 1] - start[..., 1]) * fraction
r3 = start[..., 2] + (end[..., 2] - start[..., 2]) * fraction
return r1, r2, r3
class CMapNamespace(object):
def __init__(self):
self.cubehelix = ColormapCubehelix()
self.cubehelix_r = ColormapCubehelix(reverse=True)
self.cyclic_cubehelix = ColormapCubehelix(start=1, rot=1, minLight=0.5, maxLight=0.5, sat=2)
self.cyclic_perception = ColormapPerception()
self.cyclic_hls = ColormapHLS()
self.cyclic_classic = ColormapClassic()
self.transparent_black = ColormapTransparent(0, 0, 0, [0, 1.])
self.transparent_white = ColormapTransparent(1, 1, 1, [0, 1.])
self.transparent_confidence = ColormapTransparent(0.2, 0.3, 0.2, [0.75, 0.])
def __getitem__(self, key):
return self.__dict__[key]
def add_cmap_dict(self, **kwargs):
self.__dict__.update(kwargs)
cmaps = CMapNamespace()
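# Minimal usage sketch for the colormap namespace (illustrative, not part of the original module):
#   >>> import numpy as np
#   >>> vectors = np.array([[1., 0.], [0., 1.], [0., 0.]])       # two vectors: along +x and along +y
#   >>> cmaps.cyclic_cubehelix.rgb_from_vector(vectors)          # two uint8 RGB triples
#   >>> cmaps.cyclic_cubehelix.plot_colorwheel(size=128)         # matplotlib image of the colorwheel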
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides functions that decorate exisiting `matplotlib` plots."""
import logging
from collections.abc import Iterable
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patheffects
from matplotlib.patches import Circle
from matplotlib.offsetbox import TextArea, AnchoredOffsetbox
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from . import colors
from .tools import use_style
from ..fields.field import Field
__all__ = ['scalebar', 'colorwheel', 'annotate', 'quiverkey', 'coords', 'colorbar']
_log = logging.getLogger(__name__)
def scalebar(axis=None, unit='nm', loc='lower left', **kwargs):
"""Add a scalebar to the axis.
Parameters
----------
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the scalebar is added, by default None, which will pick the last used axis via `gca`.
unit: str, optional
String that determines the unit of the scalebar, defaults to 'nm'.
loc : str or pair of floats, optional
The location of the scalebar, defaults to 'lower left'. See `matplotlib.legend` for possible settings.
Returns
-------
aoffbox : :class:`~matplotlib.offsetbox.AnchoredOffsetbox`
The box containing the scalebar.
Notes
-----
Additional kwargs are passed to `mpl_toolkits.axes_grid1.anchored_artists.AnchoredSizeBar`.
"""
_log.debug('Calling scalebar')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
# Transform axis borders (1, 1) to data borders to get number of pixels in y and x:
transform = axis.transData
bb0 = axis.transLimits.inverted().transform((0, 0))
bb1 = axis.transLimits.inverted().transform((1, 1))
data_extent = (int(abs(bb1[1] - bb0[1])), int(abs(bb1[0] - bb0[0])))
# Calculate bar length:
bar_length = data_extent[1] / 4 # 25% of the data width!
thresholds = [1, 5, 10, 50, 100, 500, 1000]
for t in thresholds: # For larger grids (real images), multiples of threshold look better!
if bar_length > t: # Round down to the next lowest multiple of t:
bar_length = (bar_length // t) * t
# Set parameters for scale bar:
label = f'{bar_length:g} {unit}'
# Set defaults:
kwargs.setdefault('borderpad', 0.2)
kwargs.setdefault('pad', 0.2)
kwargs.setdefault('sep', 5)
kwargs.setdefault('color', 'w')
kwargs.setdefault('size_vertical', data_extent[0]*0.01)
kwargs.setdefault('frameon', False)
kwargs.setdefault('label_top', True)
kwargs.setdefault('fill_bar', True)
# Create scale bar:
scalebar = AnchoredSizeBar(transform, bar_length, label, loc, **kwargs)
scalebar.txt_label._text._color = 'w' # Overwrite AnchoredSizeBar color!
# Set stroke patheffect:
effect_txt = [patheffects.withStroke(linewidth=2, foreground='k')]
scalebar.txt_label._text.set_path_effects(effect_txt)
effect_bar = [patheffects.withStroke(linewidth=3, foreground='k')]
scalebar.size_bar._children[0].set_path_effects(effect_bar)
# Add scale bar to axis and return:
axis.add_artist(scalebar)
return scalebar
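# Usage sketch (added comment, not part of the original file): after an image with data coordinates in nm
# has been plotted (e.g. via `empyre.vis.imshow`, import path assumed),
#     scalebar(unit='nm', loc='lower left')
# picks up the current axis via `gca` and draws a bar spanning roughly a quarter of the data width.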
def colorwheel(axis=None, cmap=None, ax_size='20%', loc='upper right', **kwargs):
"""Add a colorwheel to the axis on the upper right corner.
Parameters
----------
axis : :class:`~matplotlib.axes.Axes`, optional
Axis to which the colorwheel is added, by default None, which will pick the last used axis via `gca`.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the colorwheel, defaults to `None`, which chooses the
`.colors.cmaps.cyclic_cubehelix` colormap. Needs to be a :class:`~.colors.Colormap3D` to work correctly.
ax_size : str or float, optional
String or float determining the size of the inset axis used, defaults to `20%`.
loc : str or pair of floats, optional
The location of the colorwheel, defaults to 'upper right'. See `matplotlib.legend` for possible settings.
Returns
-------
axis : :class:`~matplotlib.image.AxesImage`
The colorwheel image that was created.
Notes
-----
Additional kwargs are passed to :class:`~.colors.Colormap3D.plot_colorwheel` of the :class:`~.colors.Colormap3D`.
"""
_log.debug('Calling colorwheel')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
ins_axes = inset_axes(axis, width=ax_size, height=ax_size, loc=loc)
ins_axes.axis('off')
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
plt.sca(axis) # Set focus back to parent axis!
return cmap.plot_colorwheel(axis=ins_axes, **kwargs)
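# Usage sketch (added comment, not part of the original file): typically combined with `colorvec` plots,
#     colorwheel(loc='upper right', ax_size='25%')
# adds an inset wheel of the default cyclic cubehelix colormap to the current axis.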
def annotate(label, axis=None, loc='upper left'):
"""Add an annotation to the axis on the upper left corner.
Parameters
----------
label : string
The text of the annotation.
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the annotation is added, by default None, which will pick the last used axis via `gca`.
loc : str or pair of floats, optional
The location of the annotation, defaults to 'upper left'. See `matplotlib.legend` for possible settings.
Returns
-------
aoffbox : :class:`~matplotlib.offsetbox.AnchoredOffsetbox`
The box containing the annotation.
"""
_log.debug('Calling annotate')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
# Create text:
txt = TextArea(label, textprops={'color': 'w'})
txt.set_clip_box(axis.bbox)
txt._text.set_path_effects([patheffects.withStroke(linewidth=2, foreground='k')])
# Pack into and add AnchoredOffsetBox:
aoffbox = AnchoredOffsetbox(loc=loc, pad=0.5, borderpad=0.1, child=txt, frameon=False)
axis.add_artist(aoffbox)
return aoffbox
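# Usage sketch (added comment, not part of the original file):
#     annotate('(a)', loc='upper left')
# adds a white, black-outlined panel label to the current axis.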
def quiverkey(quiv, field, axis=None, unit='', loc='lower right', **kwargs):
"""Add a quiver key to an axis.
Parameters
----------
quiv : Quiver instance
The quiver instance returned by a call to quiver.
field : `Field` or ndarray
The vector data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the quiverkey is added, by default None, which will pick the last used axis via `gca`.
unit: str, optional
String that determines the unit of the quiverkey, defaults to ''.
loc : str or pair of floats, optional
The location of the quiverkey, defaults to 'lower right'. See `matplotlib.legend` for possible settings.
Returns
-------
qk: Quiverkey
The generated quiverkey.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.quiverkey`.
"""
_log.debug('Calling quiverkey')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=True)
length = field.amp.data.max()
shift = 1 / field.squeeze().dim[1] # equivalent to one pixel distance in axis coords!
label = f'{length:.3g} {unit}'
if loc in ('upper right', 1):
X, Y, labelpos = 0.95-shift, 0.95-shift/4, 'W'
elif loc in ('upper left', 2):
X, Y, labelpos = 0.05+shift, 0.95-shift/4, 'E'
elif loc in ('lower left', 3):
X, Y, labelpos = 0.05+shift, 0.05+shift/4, 'E'
elif loc in ('lower right', 4):
X, Y, labelpos = 0.95-shift, 0.05+shift/4, 'W'
else:
raise ValueError('Quiverkey can only be placed in one of the corners (number 1 - 4 or associated strings)!')
# Set defaults:
kwargs.setdefault('coordinates', 'axes')
kwargs.setdefault('facecolor', 'w')
kwargs.setdefault('edgecolor', 'k')
kwargs.setdefault('labelcolor', 'w')
kwargs.setdefault('linewidth', 1)
kwargs.setdefault('clip_box', axis.bbox)
kwargs.setdefault('clip_on', True)
# Plot:
qk = axis.quiverkey(quiv, X, Y, U=1, label=label, labelpos=labelpos, **kwargs)
qk.text.set_path_effects([patheffects.withStroke(linewidth=2, foreground='k')])
return qk
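# Usage sketch (added comment, not part of the original file), assuming `field` is a 2D vector Field that
# was plotted beforehand with `quiver` from `empyre.vis` (import path assumed):
#     quiv = quiver(field)
#     quiverkey(quiv, field, unit='A/m')
# places a reference arrow corresponding to the maximum amplitude of `field` in the lower right corner.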
def coords(axis=None, coords=('x', 'y'), loc='lower left', **kwargs):
"""Add coordinate arrows to an axis.
Parameters
----------
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis to which the coordinates are added, by default None, which will pick the last used axis via `gca`.
coords : tuple or int, optional
Tuple of strings determining the labels, by default ('x', 'y'). Can also be `2` or `3` which expands to
('x', 'y') or ('x', 'y', 'z'). The length of `coords` determines the number of arrows (2 or 3).
loc : str, optional
The location of the coordinate arrows, defaults to 'lower left'. See `matplotlib.legend` for possible settings.
Returns
-------
ins_axes : :class:`~matplotlib.axes.Axes`
The created inset axes containing the coordinates.
"""
_log.debug('Calling coords')
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
ins_ax = inset_axes(axis, width='5%', height='5%', loc=loc, borderpad=2.2)
if coords == 3:
coords = ('x', 'y', 'z')
elif coords == 2:
coords = ('x', 'y')
effects = [patheffects.withStroke(linewidth=2, foreground='k')]
kwargs.setdefault('fc', 'w')
kwargs.setdefault('ec', 'k')
kwargs.setdefault('head_width', 0.6)
kwargs.setdefault('head_length', 0.7)
kwargs.setdefault('linewidth', 1)
kwargs.setdefault('width', 0.2)
if len(coords) == 3:
ins_ax.arrow(x=0.5, y=0.5, dx=-1.05, dy=-0.75, clip_on=False, **kwargs)
ins_ax.arrow(x=0.5, y=0.5, dx=0.96, dy=-0.75, clip_on=False, **kwargs)
ins_ax.arrow(x=0.5, y=0.5, dx=0, dy=1.35, clip_on=False, **kwargs)
ins_ax.annotate(coords[0], xy=(0, 0), xytext=(-1.0, 0.3), path_effects=effects, color='w')
ins_ax.annotate(coords[1], xy=(0, 0), xytext=(1.7, 0.3), path_effects=effects, color='w')
ins_ax.annotate(coords[2], xy=(0, 0), xytext=(0.8, 1.5), path_effects=effects, color='w')
ins_ax.add_artist(Circle((0.5, 0.5), 0.12, fc='w', ec='k', linewidth=1, clip_on=False))
elif len(coords) == 2:
ins_ax.arrow(x=-0.5, y=-0.5, dx=1.5, dy=0, clip_on=False, **kwargs)
ins_ax.arrow(x=-0.5, y=-0.5, dx=0, dy=1.5, clip_on=False, **kwargs)
ins_ax.annotate(coords[0], xy=(0, 0), xytext=(1.3, -0.05), path_effects=effects, color='w')
ins_ax.annotate(coords[1], xy=(0, 0), xytext=(-0.1, 1.1), path_effects=effects, color='w')
ins_ax.add_artist(Circle((-0.5, -0.5), 0.12, fc='w', ec='k', linewidth=1, clip_on=False))
ins_ax.axis('off')
plt.sca(axis)
return ins_ax
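# Usage sketch (added comment, not part of the original file):
#     coords(coords=3, loc='lower left')
# draws a small x/y/z tripod, while the default `coords=('x', 'y')` draws two in-plane arrows only.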
def colorbar(im, fig=None, cbar_axis=None, axes=None, position='right', pad=0.02, thickness=0.03, label=None,
constrain_ticklabels=True, ticks=None, ticklabels=None):
"""Creates a colorbar, aligned with figure axes.
Parameters
----------
im : matplotlib object, mappable
Mappable matplotlib object.
fig : matplotlib.figure object, optional
The figure object that contains the matplotlib axes and artists, by default None, which will pick the last used
figure via `gcf`.
cbar_axis : matplotlib.axes, optional
The axis the colorbar is drawn into, by default None, which will construct a new axis next to `axes`,
determined by `position`, `pad` and `thickness`.
axes : matplotlib.axes or list of matplotlib.axes
The axes object(s), where the colorbar is drawn, by default None, which will pick the last used axis via `gca`.
Only provide those axes, which the colorbar should span over.
position : str, optional
The position defines the location of the colorbar. One of 'top', 'bottom', 'left' or 'right' (default).
pad : float, optional
Defines the spacing between the axes and colorbar axis. Is given in figure fraction.
thickness : float, optional
Thickness of the colorbar given in figure fraction.
label : string, optional
Colorbar label, defaults to None.
constrain_ticklabels : bool, optional
Allows to slightly shift the outermost ticklabels, such that they do not exceed the cbar axis, defaults to True.
ticks : list, np.ndarray, optional
List of cbar ticks, defaults to None.
ticklabels : list, np.ndarray, optional
List of cbar ticklabels, defaults to None.
Returns
-------
cbar : :class:`~matplotlib.Colorbar`
The created colorbar.
Notes
-----
Based on a modified snippet by Florian Winkler. Note that this function TURNS OFF constrained layout, therefore it
should be the final command before finishing or saving a figure. The colorbar will be outside the original bounds
of your constructed figure. If you set the size, e.g. with `~empyre.vis.tools.new`, make sure to account for the
additional space by setting the `width_scale` to something smaller than 1 (e.g. 0.9).
"""
_log.debug('Calling colorbar')
assert position in ('left', 'right', 'top', 'bottom'), "position has to be 'left', 'right', 'top' or 'bottom'!"
if fig is None: # If no figure is set, find the current or create a new one:
fig = plt.gcf()
fig.canvas.draw() # Trigger a draw so that a potential constrained_layout is executed once!
fig.set_constrained_layout(False) # we don't want the layout to change after this point!
if axes is None: # If no axis is set, find the current or create a new one:
axes = plt.gca()
if not isinstance(axes, Iterable):
axes = (axes,) # Make sure it is an iterable (e.g. tuple)!
# Save previously active axis for later:
previous_axis = plt.gca()
if cbar_axis is None: # Construct a new cbar_axis:
x_coords, y_coords = [], []
# Find bounds of all individual axes:
for ax in np.ravel(axes): # ravel needed for arrays of axes:
points = ax.get_position().get_points()
x_coords.extend(points[:, 0])
y_coords.extend(points[:, 1])
# Find outer bounds of plotting area:
left, right = min(x_coords), max(x_coords)
bottom, top = min(y_coords), max(y_coords)
# Determine where the colorbar will be placed:
if position == 'right':
bounds = [right+pad, bottom, thickness, top-bottom]
elif position == 'left':
bounds = [left-pad-thickness, bottom, thickness, top-bottom]
elif position == 'top':
bounds = [left, top+pad, right-left, thickness]
elif position == 'bottom':
bounds = [left, bottom-pad-thickness, right-left, thickness]
cbar_axis = fig.add_axes(bounds)
# Create the colorbar:
with use_style('empyre-image'):
if position in ('left', 'right'):
cb = plt.colorbar(im, cax=cbar_axis, orientation='vertical')
cb.ax.yaxis.set_ticks_position(position)
cb.ax.yaxis.set_label_position(position)
elif position in ('top', 'bottom'):
cb = plt.colorbar(im, cax=cbar_axis, orientation='horizontal')
cb.ax.xaxis.set_ticks_position(position)
cb.ax.xaxis.set_label_position(position)
# Colorbar label
if label is not None:
cb.set_label(f'{label}')
# Set ticks and ticklabels (if specified):
if ticks is not None:
cb.set_ticks(ticks)
if ticklabels is not None:
cb.set_ticklabels(ticklabels)
# Constrain tick labels (if wanted):
if constrain_ticklabels:
if position == 'top' or position == 'bottom':
t = cb.ax.get_xticklabels()
t[0].set_horizontalalignment('left')
t[-1].set_horizontalalignment('right')
elif position == 'left' or position == 'right':
t = cb.ax.get_yticklabels()
t[0].set_verticalalignment('bottom')
t[-1].set_verticalalignment('top')
# Set focus back from colorbar to previous axis and return colorbar:
plt.sca(previous_axis)
return cb
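# Usage sketch (added comment, not part of the original file), assuming `im` is the mappable returned by
# an earlier `imshow` call:
#     cb = colorbar(im, position='right', label='phase (rad)')
# Because constrained layout is switched off here, this should be the last command before saving the figure.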
### MATPLOTLIB STYLESHEET FOR EMPYRE IMAGES
font.family : serif ## default font family (use serifs)
font.serif : cm ## Computer Modern (LaTeX font)
xtick.top : False ## draw ticks on the top side
xtick.bottom : False ## draw ticks on the bottom side
xtick.labeltop : False ## draw label on the top
xtick.labelbottom : False ## draw label on the bottom
ytick.left : False ## draw ticks on the left side
ytick.right : False ## draw ticks on the right side
ytick.labelleft : False ## draw tick labels on the left side
ytick.labelright : False ## draw tick labels on the right side
figure.figsize : 3.6, 3.6 ## figure size in inches
figure.dpi : 200 ## figure dots per inch
figure.constrained_layout.use : True ## use constrained layout
image.origin : lower ## lower | upper
### MATPLOTLIB STYLESHEET FOR EMPYRE PLOTS
font.family : serif ## default font family (use serifs)
font.serif : cm ## Computer Modern (LaTeX font)
figure.figsize : 3.6, 2.2 ## figure size in inches
figure.dpi : 200 ## figure dots per inch
figure.constrained_layout.use : True ## use constrained layout
### MATPLOTLIB STYLESHEET FOR SAVING EMPYRE IMAGES AND PLOTS
font.family : serif ## default font family (use serifs)
font.serif : cm ## Computer Modern (LaTeX font)
savefig.dpi : 200 ## figure dots per inch or 'figure'
savefig.facecolor : white ## figure facecolor when saving
savefig.edgecolor : white ## figure edgecolor when saving
savefig.format : pdf ## png, ps, pdf, svg
savefig.bbox : tight ## 'tight' or 'standard'.
savefig.pad_inches : 0.01 ## Padding to be used when bbox is set to 'tight'
savefig.jpeg_quality : 95 ## when a jpeg is saved, the default quality parameter.
savefig.directory : ~ ## default directory in savefig dialog box, leave empty to always use cwd
savefig.transparent : False ## controls whether figures are saved with transp. background by default
savefig.orientation : portrait ## Orientation of saved figure
# -*- coding: utf-8 -*-
# Copyright 2019 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides functions for 2D plots that often wrap functions from `maptlotlib.pyplot`."""
import logging
import warnings
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from PIL import Image
from . import colors
from .tools import use_style
from ..fields.field import Field
__all__ = ['imshow', 'contour', 'colorvec', 'cosine_contours', 'quiver']
_log = logging.getLogger(__name__)
DIVERGING_CMAPS = ['PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn',
'Spectral', 'coolwarm', 'bwr', 'seismic', # all divergent maps from matplotlib!
'balance', 'delta', 'curl', 'diff', 'tarn'] # all divergent maps from cmocean!
# TODO: add seaborn and more?
def imshow(field, axis=None, cmap=None, **kwargs):
"""Display an image on a 2D regular raster. Wrapper for `matplotlib.pyplot.imshow`.
Parameters
----------
field : `Field` or ndarray
The image data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last used axis via `gca`.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the display, either as a string or object, by default None, which will pick
`cmocean.cm.balance` if available. `imshow` will automatically detect if a divergent colormap is used and will
make sure that zero is pinned to the symmetry point of the colormap (this is done by creating a new colormap
with custom range under the hood).
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to :meth:`~matplotlib.pyplot.imshow`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
"""
_log.debug('Calling imshow')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!'
# Determine colormap and related important properties and flags:
if cmap is None:
try:
import cmocean
cmap = cmocean.cm.balance
except ImportError:
_log.info('cmocean not found, falling back to RdBu_r!')
cmap = plt.get_cmap('RdBu_r') # '_r' for reverse!
elif isinstance(cmap, str): # make sure we have a Colormap object (and not a string):
cmap = plt.get_cmap(cmap)
if cmap.name.replace('_r', '') in DIVERGING_CMAPS: # 'replace' also matches reversed cmaps!
kwargs.setdefault('norm', TwoSlopeNorm(0)) # Diverging colormap should have zero at the symmetry point!
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitly):
dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
return axis.imshow(squeezed_field, cmap=cmap, **kwargs)
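# Usage sketch (added comment, not part of the original file): a plain numpy array is wrapped into a scalar
# Field with `scale=1.0` automatically, e.g. (with hypothetical test data)
#     imshow(np.random.randn(64, 64))
# Since the default colormap is diverging, zero is pinned to its symmetry point via `TwoSlopeNorm`.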
def contour(field, axis=None, **kwargs):
"""Plot contours. Wrapper for `matplotlib.pyplot.contour`.
Parameters
----------
field : `Field` or ndarray
The contour data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the contour should be added, by default None, which will pick the last used axis via `gca`.
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.contour`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
"""
_log.debug('Calling contour')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!'
# Create coordinates (respecting the field scale, +0.5: pixel center!):
vv, uu = (np.indices(squeezed_field.dim) + 0.5) * np.asarray(squeezed_field.scale)[:, None, None]
# Set kwargs defaults without overriding possible user input:
kwargs.setdefault('levels', [0.5])
kwargs.setdefault('colors', 'k')
kwargs.setdefault('linestyles', 'dotted')
kwargs.setdefault('linewidths', 2)
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
axis.set_aspect('equal')
return axis.contour(uu, vv, squeezed_field.data, **kwargs)
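# Usage sketch (added comment, not part of the original file): with the default settings the 0.5 level is
# drawn as a dotted black line, which is useful to outline a mask, e.g.
#     contour(field.mask, linewidths=1)    # `field` is a hypothetical Field object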
def colorvec(field, axis=None, **kwargs):
"""Plot an image of a 2D vector field with up to 3 components by color encoding the vector direction.
In-plane directions are encoded via hue ("color wheel"), making sure that all in-plane directions are isoluminant
(i.e. a greyscale version would be a homogeneously medium grey image). Out-of-plane directions are encoded via
brightness, with upwards pointing vectors being white and downward pointing vectors being black. The length of the
vectors is encoded via saturation, with full saturation being fully chromatic (in-plane) or fully white/black
(up/down). The center of the "color sphere" is desaturated to a medium gray and encodes vectors with length zero.
Parameters
----------
field : `Field` or ndarray
The image data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last used axis via `gca`.
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.imshow`.
Note that the y-axis of the plot is flipped in comparison to :meth:`~matplotlib.pyplot.imshow`, i.e. that the
origin is `'lower'` in this case instead of `'upper'`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
Even though squeezing takes place, `colorvec` "remembers" the original orientation of the slice! This is important
if you want to plot a slice that should not represent the xy-plane. The colors chosen will respect the original
orientation of your slice, e.g. a vortex in the xz-plane will include black and white colors (up/down) if the
`Field` object given as the `field` parameter has `dim=(128, 1, 128)`. If you want to plot a slice of a 3D vector
with 3 components and make use of this functionality, make sure to not use an integer as an index, as that will
drop the dimension BEFORE it is passed to `colorvec`, which will have no way of knowing which dimension was dropped.
Instead, make sure to use a slice of length one (example with `dim=(128, 128, 128)`):
>>> colorvec(field[:, 15, :]) # Wrong! Shape: (128, 128), interpreted as xy-plane!
>>> colorvec(field[:, 15:16, :]) # Right! Shape: (128, 1, 128), passed as 3D to `colorvec`, squeezed internally!
"""
_log.debug('Calling colorvec')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=True)
assert field.vector, 'Can only plot vector fields!'
assert len(field.dim) <= 3, 'Unusable for vector fields with dimension higher than 3!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions!'
# Extract vector components (fill 3rd component with zeros if field.comp is only 2):
comp = squeezed_field.comp
x_comp = comp[0]
y_comp = comp[1]
z_comp = comp[2] if (squeezed_field.ncomp == 3) else np.zeros(squeezed_field.dim)
# Calculate image with color encoded directions:
cmap = kwargs.pop('cmap', None)
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(np.stack((x_comp, y_comp, z_comp), axis=0))
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitly):
dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
return axis.imshow(Image.fromarray(rgb), **kwargs)
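# Usage sketch (added comment, not part of the original file), assuming `field` is a 2D vector Field with
# two in-plane components:
#     colorvec(field)    # hue encodes the in-plane angle, brightness the out-of-plane component
#     colorwheel()       # optional matching colorwheel from `empyre.vis.decorators` (import path assumed)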
def cosine_contours(field, axis=None, gain='auto', cmap=None, **kwargs):
"""Plots the cosine of the (amplified) field. Wrapper for `matplotlib.pyplot.imshow`.
Parameters
----------
field : `Field` or ndarray
The contour data as a `Field` or a numpy array (in the latter case, `vector=False` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the contour should be added, by default None, which will pick the last used axis via `gca`.
gain : float or 'auto', optional
Gain factor with which the `Field` is amplified before taking the cosine, by default 'auto', which calculates a
gain factor that would produce roughly 4 cosine contours.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the display, either as a string or object, by default None, which will pick
`colors.cmaps['transparent_black']` that will alternate between regions with alpha=0, showing layers below and
black contours.
Returns
-------
axis : `matplotlib.axes.Axes`
The plotting axis.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.imshow`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
"""
_log.debug('Calling cosine_contours')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=False)
assert not field.vector, 'Can only plot scalar fields!'
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions (Squeezing did not help)!'
# Determine colormap and related important properties and flags:
if cmap is None:
cmap = colors.cmaps['transparent_black']
# Calculate gain if 'auto' is selected:
if gain == 'auto':
gain = 4 * 2*np.pi / (squeezed_field.amp.data.max() + 1E-30) # 4: roughly 4 contours!
gain = round(gain, -int(np.floor(np.log10(abs(gain))))) # Round to last significant digit!
_log.info(f'Automatically calculated a gain of: {gain}')
# Calculate the contours:
contours = np.cos(gain * squeezed_field) # Range: [-1, 1]
contours += 1 # Shift to positive values
contours /= 2 # Rescale to [0, 1]
# Set extent in data coordinates (left, right, bottom, top) to kwargs (if not set explicitly):
dim_v, dim_u = squeezed_field.dim
s_v, s_u = squeezed_field.scale
kwargs.setdefault('extent', (0, dim_u * s_u, 0, dim_v * s_v))
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
return axis.imshow(contours, cmap=cmap, **kwargs)
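# Usage sketch (added comment, not part of the original file): typically used for phase images, e.g. with a
# hypothetical scalar Field `phase_map`:
#     cosine_contours(phase_map, gain='auto')    # roughly 4 transparent-to-black contours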
def quiver(field, axis=None, color_angles=False, cmap=None, n_bin='auto', bin_with_mask=True, **kwargs):
"""Plot a 2D field of arrows. Wrapper for `matplotlib.pyplot.imshow`.
Parameters
----------
field : `Field` or ndarray
The vector data as a `Field` or a numpy array (in the latter case, `vector=True` and `scale=1.0` are assumed).
axis : `matplotlib.axes.Axes` object, optional
The axis to which the image should be added, by default None, which will pick the last used axis via `gca`.
color_angles : bool, optional
Switch that turns on color encoding of the arrows, by default False. Encoding works the same as for the
`colorvec` function (see for details). If False, arrows are uniformly colored white with black border. In both
cases, the amplitude is encoded via the transparency of the arrow.
cmap : str or `matplotlib.colors.Colormap`, optional
The Colormap that should be used for the arrows, either as a string or object, by default None. Will only be
used if `color_angles=True`.
n_bin : float or 'auto', optional
Number of entries along each axis over which the average is taken, by default 'auto', which automatically
determines a bin size resulting in roughly 16 arrows along the largest dimension. Usually sensible to leave
this on to not clutter the image with too many arrows (also due to performance). Can be turned off by setting
`n_bin=1`. Uses the `..fields.field.Field.bin` method.
bin_with_mask : bool, optional
If True (default) and if `n_bin>1`, entries of the constructed binned `Field` that average over regions
outside the `..fields.field.Field.mask` will not be assigned an arrow and stay empty instead. This prevents
erroneous "fade-out" effects of the arrows that would occur even for homogeneous objects.
Returns
-------
quiv : Quiver instance
The quiver instance that was created.
Notes
-----
Additional kwargs are passed to `matplotlib.pyplot.quiver`.
Uses the `empyre-image` stylesheet settings for plotting (and axis creation if none exists, yet).
Fields are squeezed before plotting, so non-2D fields work as long as their superfluous dimensions have length 1.
Even though squeezing takes place, `quiver` "remembers" the original orientation of the slice and which dimensions
were squeezed! See `colorvec` for more information and an example (the same principles apply here, too).
The transparency of the arrows denotes the 3D(!) amplitude; if you see dots in the plot, the amplitude is not
zero, the vector simply points out of the current plane!
"""
_log.debug('Calling quiver')
if not isinstance(field, Field): # Try to convert input to Field if it is not already one:
field = Field(data=np.asarray(field), scale=1.0, vector=True)
assert field.vector, 'Can only plot vector fields!'
assert len(field.dim) <= 3, 'Unusable for vector fields with dimension higher than 3!'
if len(field.dim) < field.ncomp:
warnings.warn('Assignment of vector components to dimensions is ambiguous! '
f'`ncomp` ({field.ncomp}) should match `len(dim)` ({len(field.dim)})! '
'If you want to plot a slice of a 3D volume, make sure to use `from:to` notation!')
# Get squeezed data and make sure it's 2D scalar:
squeezed_field = field.squeeze()
assert len(squeezed_field.dim) == 2, 'Cannot plot more than 2 dimensions (Squeezing did not help)!'
# Determine binning size if necessary:
if n_bin == 'auto':
n_bin = int(np.max((1, np.max(squeezed_field.dim) / 16)))
# Save old limits in case binning has to use padding:
u_lim = squeezed_field.dim[1] * squeezed_field.scale[1]
v_lim = squeezed_field.dim[0] * squeezed_field.scale[0]
# Bin if necessary:
if n_bin > 1:
field_mask = squeezed_field.mask # Get mask BEFORE binning!
squeezed_field = squeezed_field.bin(n_bin)
if bin_with_mask: # Excludes regions where in and outside are binned together!
mask = (field_mask.bin(n_bin) == 1)
squeezed_field *= mask
# Extract normalized vector components (fill 3rd component with zeros if field.comp is only 2):
normalised_comp = (squeezed_field / squeezed_field.amp.data.max()).comp
amplitude = squeezed_field.amp.data / squeezed_field.amp.data.max()
x_comp = normalised_comp[0].data
y_comp = normalised_comp[1].data
z_comp = normalised_comp[2].data if (field.ncomp == 3) else np.zeros(squeezed_field.dim)
# Create coordinates (respecting the field scale, +0.5: pixel center!):
vv, uu = (np.indices(squeezed_field.dim) + 0.5) * np.asarray(squeezed_field.scale)[:, None, None]
# Calculate the arrow colors:
if color_angles: # Color angles according to calculated RGB values (only with circular colormaps):
_log.debug('Encoding angles')
if cmap is None:
cmap = colors.cmaps.cyclic_cubehelix
rgb = cmap.rgb_from_vector(np.asarray((x_comp, y_comp, z_comp))) / 255
rgba = np.concatenate((rgb, amplitude[..., None]), axis=-1)
kwargs.setdefault('color', rgba.reshape(-1, 4))
else: # Color amplitude with numeric values, according to cmap, overrides 'color':
_log.debug('Encoding amplitudes')
if cmap is None:
cmap = colors.cmaps['transparent_white']
C = amplitude # Numeric values, used with cmap!
# Check which (if any) indices were squeezed to find out which components are passed to quiver: # TODO: TEST!!!
squeezed_indices = np.flatnonzero(np.asarray(field.dim) == 1)
if squeezed_indices.size == 0:  # No dimension was squeezed, the field was 2D from the start (xy-plane):
u_comp = x_comp
v_comp = y_comp
elif squeezed_indices[0] == 0: # Slice of the xy-plane with z squeezed:
u_comp = x_comp
v_comp = y_comp
elif squeezed_indices[0] == 1: # Slice of the xz-plane with y squeezed:
u_comp = x_comp
v_comp = z_comp
elif squeezed_indices[0] == 2: # Slice of the zy-plane with x squeezed:
u_comp = y_comp
v_comp = z_comp
# Set specific defaults for quiver kwargs:
kwargs.setdefault('edgecolor', colors.cmaps['transparent_black'](amplitude).reshape(-1, 4))
kwargs.setdefault('scale', 1/np.max(squeezed_field.scale))
kwargs.setdefault('width', np.max(squeezed_field.scale))
kwargs.setdefault('clim', (0, 1))
kwargs.setdefault('pivot', 'middle')
kwargs.setdefault('units', 'xy')
kwargs.setdefault('scale_units', 'xy')
kwargs.setdefault('minlength', 0.05)
kwargs.setdefault('headlength', 2)
kwargs.setdefault('headaxislength', 2)
kwargs.setdefault('headwidth', 2)
kwargs.setdefault('minshaft', 2)
kwargs.setdefault('linewidths', 1)
# Plot with the empyre style context:
with use_style('empyre-image'): # Only works on axes created WITHIN context!
if axis is None: # If no axis is set, find the current or create a new one:
axis = plt.gca()
axis.set_xlim(0, u_lim)
axis.set_ylim(0, v_lim)
axis.set_aspect('equal')
if color_angles:
return axis.quiver(uu, vv, np.asarray(u_comp), np.asarray(v_comp), cmap=cmap, **kwargs)
else:
return axis.quiver(uu, vv, np.asarray(u_comp), np.asarray(v_comp), C, cmap=cmap, **kwargs)
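# Usage sketch (added comment, not part of the original file), assuming `field` is a 2D vector Field:
#     quiv = quiver(field, n_bin=4, color_angles=True)    # average 4x4 blocks, colour arrows by direction
#     quiverkey(quiv, field, unit='A/m')                   # optional key from `empyre.vis.decorators`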