diff --git a/.gitignore b/.gitignore index 9f11b755a17d8192c60f61cb17b8902dffbd9f23..8b070234fadec9752fbf9f28f9224568d0bac20f 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .idea/ +/desktop.ini diff --git a/desktop.ini b/desktop.ini deleted file mode 100644 index de5592493da8142401e946bd4e579631abbedf9d..0000000000000000000000000000000000000000 --- a/desktop.ini +++ /dev/null @@ -1,6 +0,0 @@ -[.ShellClassInfo] -IconResource=C:\Users\Jan\Home\PhD Thesis\Projects\Pyramid\docs\icon.ico,0 -[ViewState] -Mode= -Vid= -FolderType=Generic diff --git a/doc/.static/sphinxdoc.css b/docs/.static/sphinxdoc.css similarity index 100% rename from doc/.static/sphinxdoc.css rename to docs/.static/sphinxdoc.css diff --git a/doc/Makefile b/docs/Makefile similarity index 100% rename from doc/Makefile rename to docs/Makefile diff --git a/doc/Pyramid Logo.png b/docs/Pyramid Logo.png similarity index 100% rename from doc/Pyramid Logo.png rename to docs/Pyramid Logo.png diff --git a/doc/conf.py b/docs/conf.py similarity index 100% rename from doc/conf.py rename to docs/conf.py diff --git a/doc/icon.ico b/docs/icon.ico similarity index 100% rename from doc/icon.ico rename to docs/icon.ico diff --git a/doc/index.rst b/docs/index.rst similarity index 98% rename from doc/index.rst rename to docs/index.rst index 2c6a07f14060d1c6d3c4f5ea7e1907329b5d3e06..ef178cbe4bac41e4018ec8097529a3f2ab290e89 100644 --- a/doc/index.rst +++ b/docs/index.rst @@ -12,8 +12,6 @@ Contents: :maxdepth: 4 pyramid.rst - - Indices and tables diff --git a/doc/make.bat b/docs/make.bat similarity index 100% rename from doc/make.bat rename to docs/make.bat diff --git a/doc/pyramid.rst b/docs/pyramid.rst similarity index 100% rename from doc/pyramid.rst rename to docs/pyramid.rst diff --git a/pyramid/__init__.py b/pyramid/__init__.py index f2e77a770689c4c708a298e634ab8835bc33615f..e655bb04f334db0f59d8feb6c4fe3df79347e1b1 100644 --- a/pyramid/__init__.py +++ b/pyramid/__init__.py @@ -1,87 +1,89 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Package for the creation and reconstruction of magnetic distributions and resulting phase maps. - -Modules -------- -magcreator - Create simple magnetic distributions. -magdata - Class for the storage of magnetization data. -projector - Class for projecting given magnetization distributions. -kernel - Class for the kernel matrix representing one magnetized pixel. -phasemapper - Create magnetic and electric phase maps from magnetization data. -phasemap - Class for the storage of phase data. -analytic - Create phase maps for magnetic distributions with analytic solutions. -dataset - Class for collecting pairs of phase maps and corresponding projectors. -forwardmodel - Class which represents a phase mapping strategy. -costfunction - Class for the evaluation of the cost of a function. -reconstruction - Reconstruct magnetic distributions from given phasemaps. -regularisator - Class to instantiate different regularisation strategies. -ramp - Class which is used to add polynomial ramps to phasemaps. -diagnostics - Class to calculate diagnostics -quaternion - Class which is used for easy rotations in the Projector classes. -colormap - Class which implements a custom direction encoding colormap. - -""" - -from . import analytic -from . import reconstruction -from . import fieldconverter -from . import magcreator -from . import colors -from . import plottools -from . import utils -from .costfunction import * -from .dataset import * -from .diagnostics import * -from .fielddata import * -from .forwardmodel import * -from .kernel import * -from .phasemap import * -from .phasemapper import * -from .projector import * -from .regularisator import * -from .ramp import * -from .quaternion import * -from .file_io import * -from .version import version as __version__ -from .version import hg_revision as __hg_revision__ - -import logging -_log = logging.getLogger(__name__) -_log.info("Starting Pyramid V{} HG{}".format(__version__, __hg_revision__)) -del logging - -__all__ = ['analytic', 'magcreator', 'reconstruction', 'fieldconverter', - 'load_phasemap', 'load_vectordata', 'load_scalardata', 'load_projector', 'load_dataset', - 'colors', 'utils'] -__all__.extend(costfunction.__all__) -__all__.extend(dataset.__all__) -__all__.extend(diagnostics.__all__) -__all__.extend(forwardmodel.__all__) -__all__.extend(kernel.__all__) -__all__.extend(fielddata.__all__) -__all__.extend(phasemap.__all__) -__all__.extend(phasemapper.__all__) -__all__.extend(projector.__all__) -__all__.extend(regularisator.__all__) -__all__.extend(ramp.__all__) -__all__.extend(quaternion.__all__) -__all__.extend(file_io.__all__) +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Package for the creation and reconstruction of magnetic distributions and resulting phase maps. + +Modules +------- +magcreator + Create simple magnetic distributions. +magdata + Class for the storage of magnetization data. +projector + Class for projecting given magnetization distributions. +kernel + Class for the kernel matrix representing one magnetized pixel. +phasemapper + Create magnetic and electric phase maps from magnetization data. +phasemap + Class for the storage of phase data. +analytic + Create phase maps for magnetic distributions with analytic solutions. +dataset + Class for collecting pairs of phase maps and corresponding projectors. +forwardmodel + Class which represents a phase mapping strategy. +costfunction + Class for the evaluation of the cost of a function. +reconstruction + Reconstruct magnetic distributions from given phasemaps. +regularisator + Class to instantiate different regularisation strategies. +ramp + Class which is used to add polynomial ramps to phasemaps. +diagnostics + Class to calculate diagnostics +quaternion + Class which is used for easy rotations in the Projector classes. +colormap + Class which implements a custom direction encoding colormap. + +""" + +from . import analytic +from . import reconstruction +from . import fieldconverter +from . import magcreator +from . import colors +from . import plottools +from . import utils +from .costfunction import * +from .dataset import * +from .diagnostics import * +from .fielddata import * +from .forwardmodel import * +from .kernel import * +from .phasemap import * +from .phasemapper import * +from .projector import * +from .regularisator import * +from .ramp import * +from .quaternion import * +from .file_io import * +from .version import version as __version__ +from .version import hg_revision as __hg_revision__ + +import logging +_log = logging.getLogger(__name__) +_log.info("Starting Pyramid V{} HG{}".format(__version__, __hg_revision__)) +del logging + +__all__ = ['analytic', 'magcreator', 'reconstruction', 'fieldconverter', + 'load_phasemap', 'load_vectordata', 'load_scalardata', 'load_projector', 'load_dataset', + 'colors', 'utils'] +__all__.extend(costfunction.__all__) +__all__.extend(dataset.__all__) +__all__.extend(diagnostics.__all__) +__all__.extend(forwardmodel.__all__) +__all__.extend(kernel.__all__) +__all__.extend(fielddata.__all__) +__all__.extend(phasemap.__all__) +__all__.extend(phasemapper.__all__) +__all__.extend(projector.__all__) +__all__.extend(regularisator.__all__) +__all__.extend(ramp.__all__) +__all__.extend(quaternion.__all__) +__all__.extend(file_io.__all__) + +# TODO: Test for different systems! diff --git a/pyramid/analytic.py b/pyramid/analytic.py index 6339f58d26f11ae29e1e40bc72fca7b9bc27a3f3..6e3131353b3887f8c5390f92e39fe8f1f2fe3eac 100644 --- a/pyramid/analytic.py +++ b/pyramid/analytic.py @@ -1,229 +1,229 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Create phase maps for magnetic distributions with analytic solutions. - -This module provides methods for the calculation of the magnetic phase for simple geometries for -which the analytic solutions are known. These can be used for comparison with the phase -calculated by the functions from the :mod:`~pyramid.phasemapper` module. - -""" - -import logging - -import numpy as np -from numpy import pi - -from pyramid.phasemap import PhaseMap - -__all__ = ['phase_mag_slab', 'phase_mag_slab', 'phase_mag_sphere', 'phase_mag_vortex'] -_log = logging.getLogger(__name__) - - -PHI_0 = 2067.83 # magnetic flux in T*nm² - - -def phase_mag_slab(dim, a, phi, center, width, b_0=1): - """Calculate the analytic magnetic phase for a homogeneously magnetized slab. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - a : float - The grid spacing in nm. - phi : float - The azimuthal angle, describing the direction of the magnetization. - center : tuple (N=3) - The center of the slab in pixel coordinates `(z, y, x)`. - width : tuple (N=3) - The width of the slab in pixel coordinates `(z, y, x)`. - b_0 : float, optional - The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. - The default is 1. - - Returns - ------- - phasemap : :class:`~numpy.ndarray` (N=2) - The phase as a 2-dimensional array. - - """ - _log.debug('Calling phase_mag_slab') - - # Function for the phase: - def _phi_mag(x, y): - def _F_0(x, y): - A = np.log(x ** 2 + y ** 2 + 1E-30) - B = np.arctan(x / (y + 1E-30)) - return x * A - 2 * x + 2 * y * B - - return coeff * Lz * (- np.cos(phi) * (_F_0(x - x0 - Lx / 2, y - y0 - Ly / 2) - - _F_0(x - x0 + Lx / 2, y - y0 - Ly / 2) - - _F_0(x - x0 - Lx / 2, y - y0 + Ly / 2) + - _F_0(x - x0 + Lx / 2, y - y0 + Ly / 2)) - + np.sin(phi) * (_F_0(y - y0 - Ly / 2, x - x0 - Lx / 2) - - _F_0(y - y0 + Ly / 2, x - x0 - Lx / 2) - - _F_0(y - y0 - Ly / 2, x - x0 + Lx / 2) + - _F_0(y - y0 + Ly / 2, x - x0 + Lx / 2))) - - # Process input parameters: - z_dim, y_dim, x_dim = dim - y0 = a * center[1] # y0, x0 define the center of a pixel, - x0 = a * center[2] # hence: (cellindex + 0.5) * grid spacing - Lz, Ly, Lx = a * width[0], a * width[1], a * width[2] - coeff = - b_0 / (4 * PHI_0) # Minus because of negative z-direction - # Create grid: - x = np.linspace(a / 2, x_dim * a - a / 2, num=x_dim) - y = np.linspace(a / 2, y_dim * a - a / 2, num=y_dim) - xx, yy = np.meshgrid(x, y) - # Return phase: - return PhaseMap(a, _phi_mag(xx, yy)) - - -def phase_mag_disc(dim, a, phi, center, radius, height, b_0=1): - """Calculate the analytic magnetic phase for a homogeneously magnetized disc. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - a : float - The grid spacing in nm. - phi : float - The azimuthal angle, describing the direction of the magnetization. - center : tuple (N=3) - The center of the disc in pixel coordinates `(z, y, x)`. - radius : float - The radius of the disc in pixel coordinates. - height : float - The height of the disc in pixel coordinates. - b_0 : float, optional - The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. - The default is 1. - - Returns - ------- - phasemap : :class:`~numpy.ndarray` (N=2) - The phase as a 2-dimensional array. - - """ - _log.debug('Calling phase_mag_disc') - - # Function for the phase: - def _phi_mag(x, y): - r = np.hypot(x - x0, y - y0) - result = coeff * Lz * ((y - y0) * np.cos(phi) - (x - x0) * np.sin(phi)) - result *= np.where(r <= R, 1, (R / (r + 1E-30)) ** 2) - return result - - # Process input parameters: - z_dim, y_dim, x_dim = dim - y0 = a * center[1] - x0 = a * center[2] - Lz = a * height - R = a * radius - coeff = pi * b_0 / (2 * PHI_0) # Minus is gone because of negative z-direction - # Create grid: - x = np.linspace(a / 2, x_dim * a - a / 2, num=x_dim) - y = np.linspace(a / 2, y_dim * a - a / 2, num=y_dim) - xx, yy = np.meshgrid(x, y) - # Return phase: - return PhaseMap(a, _phi_mag(xx, yy)) - - -def phase_mag_sphere(dim, a, phi, center, radius, b_0=1): - """Calculate the analytic magnetic phase for a homogeneously magnetized sphere. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - a : float - The grid spacing in nm. - phi : float - The azimuthal angle, describing the direction of the magnetization. - center : tuple (N=3) - The center of the sphere in pixel coordinates `(z, y, x)`. - radius : float - The radius of the sphere in pixel coordinates. - b_0 : float, optional - The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. - The default is 1. - - Returns - ------- - phasemap : :class:`~numpy.ndarray` (N=2) - The phase as a 2-dimensional array. - - """ - _log.debug('Calling phase_mag_sphere') - - # Function for the phase: - def _phi_mag(x, y): - r = np.hypot(x - x0, y - y0) - result = coeff * R ** 3 / (r + 1E-30) ** 2 * ( - (y - y0) * np.cos(phi) - (x - x0) * np.sin(phi)) - result *= np.where(r > R, 1, (1 - (1 - (r / R) ** 2) ** (3. / 2.))) - return result - - # Process input parameters: - z_dim, y_dim, x_dim = dim - y0 = a * center[1] - x0 = a * center[2] - R = a * radius - coeff = 2. / 3. * pi * b_0 / PHI_0 # Minus is gone because of negative z-direction - # Create grid: - x = np.linspace(a / 2, x_dim * a - a / 2, num=x_dim) - y = np.linspace(a / 2, y_dim * a - a / 2, num=y_dim) - xx, yy = np.meshgrid(x, y) - # Return phase: - return PhaseMap(a, _phi_mag(xx, yy)) - - -def phase_mag_vortex(dim, a, center, radius, height, b_0=1): - """Calculate the analytic magnetic phase for a vortex state disc. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - a : float - The grid spacing in nm. - center : tuple (N=3) - The center of the disc in pixel coordinates `(z, y, x)`, which is also the vortex center. - radius : float - The radius of the disc in pixel coordinates. - height : float - The height of the disc in pixel coordinates. - b_0 : float, optional - The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. - The default is 1. - - Returns - ------- - phasemap : :class:`~numpy.ndarray` (N=2) - The phase as a 2-dimensional array. - - """ - _log.debug('Calling phase_mag_vortex') - - # Function for the phase: - def _phi_mag(x, y): - r = np.hypot(x - x0, y - y0) - result = coeff * np.where(r <= R, r - R, 0) - return result - - # Process input parameters: - z_dim, y_dim, x_dim = dim - y0 = a * center[1] - x0 = a * center[2] - Lz = a * height - R = a * radius - coeff = - pi * b_0 * Lz / PHI_0 # Minus because of negative z-direction - # Create grid: - x = np.linspace(a / 2, x_dim * a - a / 2, num=x_dim) - y = np.linspace(a / 2, y_dim * a - a / 2, num=y_dim) - xx, yy = np.meshgrid(x, y) - # Return phase: - return PhaseMap(a, _phi_mag(xx, yy)) +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Create phase maps for magnetic distributions with analytic solutions. + +This module provides methods for the calculation of the magnetic phase for simple geometries for +which the analytic solutions are known. These can be used for comparison with the phase +calculated by the functions from the :mod:`~pyramid.phasemapper` module. + +""" + +import logging + +import numpy as np +from numpy import pi + +from pyramid.phasemap import PhaseMap + +__all__ = ['phase_mag_slab', 'phase_mag_slab', 'phase_mag_sphere', 'phase_mag_vortex'] +_log = logging.getLogger(__name__) + + +PHI_0 = 2067.83 # magnetic flux in T*nm² + + +def phase_mag_slab(dim, a, phi, center, width, b_0=1): + """Calculate the analytic magnetic phase for a homogeneously magnetized slab. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + a : float + The grid spacing in nm. + phi : float + The azimuthal angle, describing the direction of the magnetization. + center : tuple (N=3) + The center of the slab in pixel coordinates `(z, y, x)`. + width : tuple (N=3) + The width of the slab in pixel coordinates `(z, y, x)`. + b_0 : float, optional + The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. + The default is 1. + + Returns + ------- + phasemap : :class:`~numpy.ndarray` (N=2) + The phase as a 2-dimensional array. + + """ + _log.debug('Calling phase_mag_slab') + + # Function for the phase: + def _phi_mag(x, y): + def _F_0(x, y): + A = np.log(x ** 2 + y ** 2 + 1E-30) + B = np.arctan(x / (y + 1E-30)) + return x * A - 2 * x + 2 * y * B + + return coeff * Lz * (- np.cos(phi) * (_F_0(x - x0 - Lx / 2, y - y0 - Ly / 2) - + _F_0(x - x0 + Lx / 2, y - y0 - Ly / 2) - + _F_0(x - x0 - Lx / 2, y - y0 + Ly / 2) + + _F_0(x - x0 + Lx / 2, y - y0 + Ly / 2)) + + np.sin(phi) * (_F_0(y - y0 - Ly / 2, x - x0 - Lx / 2) - + _F_0(y - y0 + Ly / 2, x - x0 - Lx / 2) - + _F_0(y - y0 - Ly / 2, x - x0 + Lx / 2) + + _F_0(y - y0 + Ly / 2, x - x0 + Lx / 2))) + + # Process input parameters: + z_dim, y_dim, x_dim = dim + y0 = a * center[1] # y0, x0 define the center of a pixel, + x0 = a * center[2] # hence: (cellindex + 0.5) * grid spacing + Lz, Ly, Lx = a * width[0], a * width[1], a * width[2] + coeff = - b_0 / (4 * PHI_0) # Minus because of negative z-direction + # Create grid: + x = np.linspace(a / 2, x_dim * a - a / 2, num=x_dim) + y = np.linspace(a / 2, y_dim * a - a / 2, num=y_dim) + xx, yy = np.meshgrid(x, y) + # Return phase: + return PhaseMap(a, _phi_mag(xx, yy)) + + +def phase_mag_disc(dim, a, phi, center, radius, height, b_0=1): + """Calculate the analytic magnetic phase for a homogeneously magnetized disc. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + a : float + The grid spacing in nm. + phi : float + The azimuthal angle, describing the direction of the magnetization. + center : tuple (N=3) + The center of the disc in pixel coordinates `(z, y, x)`. + radius : float + The radius of the disc in pixel coordinates. + height : float + The height of the disc in pixel coordinates. + b_0 : float, optional + The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. + The default is 1. + + Returns + ------- + phasemap : :class:`~numpy.ndarray` (N=2) + The phase as a 2-dimensional array. + + """ + _log.debug('Calling phase_mag_disc') + + # Function for the phase: + def _phi_mag(x, y): + r = np.hypot(x - x0, y - y0) + result = coeff * Lz * ((y - y0) * np.cos(phi) - (x - x0) * np.sin(phi)) + result *= np.where(r <= R, 1, (R / (r + 1E-30)) ** 2) + return result + + # Process input parameters: + z_dim, y_dim, x_dim = dim + y0 = a * center[1] + x0 = a * center[2] + Lz = a * height + R = a * radius + coeff = pi * b_0 / (2 * PHI_0) # Minus is gone because of negative z-direction + # Create grid: + x = np.linspace(a / 2, x_dim * a - a / 2, num=x_dim) + y = np.linspace(a / 2, y_dim * a - a / 2, num=y_dim) + xx, yy = np.meshgrid(x, y) + # Return phase: + return PhaseMap(a, _phi_mag(xx, yy)) + + +def phase_mag_sphere(dim, a, phi, center, radius, b_0=1): + """Calculate the analytic magnetic phase for a homogeneously magnetized sphere. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + a : float + The grid spacing in nm. + phi : float + The azimuthal angle, describing the direction of the magnetization. + center : tuple (N=3) + The center of the sphere in pixel coordinates `(z, y, x)`. + radius : float + The radius of the sphere in pixel coordinates. + b_0 : float, optional + The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. + The default is 1. + + Returns + ------- + phasemap : :class:`~numpy.ndarray` (N=2) + The phase as a 2-dimensional array. + + """ + _log.debug('Calling phase_mag_sphere') + + # Function for the phase: + def _phi_mag(x, y): + r = np.hypot(x - x0, y - y0) + result = coeff * R ** 3 / (r + 1E-30) ** 2 * ( + (y - y0) * np.cos(phi) - (x - x0) * np.sin(phi)) + result *= np.where(r > R, 1, (1 - (1 - (r / R) ** 2) ** (3. / 2.))) + return result + + # Process input parameters: + z_dim, y_dim, x_dim = dim + y0 = a * center[1] + x0 = a * center[2] + R = a * radius + coeff = 2. / 3. * pi * b_0 / PHI_0 # Minus is gone because of negative z-direction + # Create grid: + x = np.linspace(a / 2, x_dim * a - a / 2, num=x_dim) + y = np.linspace(a / 2, y_dim * a - a / 2, num=y_dim) + xx, yy = np.meshgrid(x, y) + # Return phase: + return PhaseMap(a, _phi_mag(xx, yy)) + + +def phase_mag_vortex(dim, a, center, radius, height, b_0=1): + """Calculate the analytic magnetic phase for a vortex state disc. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + a : float + The grid spacing in nm. + center : tuple (N=3) + The center of the disc in pixel coordinates `(z, y, x)`, which is also the vortex center. + radius : float + The radius of the disc in pixel coordinates. + height : float + The height of the disc in pixel coordinates. + b_0 : float, optional + The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. + The default is 1. + + Returns + ------- + phasemap : :class:`~numpy.ndarray` (N=2) + The phase as a 2-dimensional array. + + """ + _log.debug('Calling phase_mag_vortex') + + # Function for the phase: + def _phi_mag(x, y): + r = np.hypot(x - x0, y - y0) + result = coeff * np.where(r <= R, r - R, 0) + return result + + # Process input parameters: + z_dim, y_dim, x_dim = dim + y0 = a * center[1] + x0 = a * center[2] + Lz = a * height + R = a * radius + coeff = - pi * b_0 * Lz / PHI_0 # Minus because of negative z-direction + # Create grid: + x = np.linspace(a / 2, x_dim * a - a / 2, num=x_dim) + y = np.linspace(a / 2, y_dim * a - a / 2, num=y_dim) + xx, yy = np.meshgrid(x, y) + # Return phase: + return PhaseMap(a, _phi_mag(xx, yy)) diff --git a/pyramid/colors.py b/pyramid/colors.py index 92afbfcd9d217992a8164785cddd19f13f859e93..6ff3918c5c70c03d2d14248bb6407a0ab992fa0f 100644 --- a/pyramid/colors.py +++ b/pyramid/colors.py @@ -1,1191 +1,1191 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""This module provides a number of custom colormaps, which also have capabilities for 3D plotting. -If this is the case, the :class:`~.Colormap3D` colormap class is a parent class. In `cmaps`, a -number of specialised colormaps is available for convenience. If the default for circular colormaps -(used for 3D plotting) should be changed, set it via `CMAP_CMAP_ANGULAR_DEFAULT`. -For general questions about colors see: -http://www.poynton.com/PDFs/GammaFAQ.pdf -http://www.poynton.com/PDFs/ColorFAQ.pdf -""" - -import logging - -import matplotlib.pyplot as plt -from matplotlib.ticker import FuncFormatter as FuncForm -from matplotlib.ticker import MaxNLocator - -from mpl_toolkits.mplot3d import Axes3D -from mpl_toolkits.axes_grid1 import ImageGrid -from matplotlib import gridspec -from matplotlib.patches import Circle - -import numpy as np -from PIL import Image -from matplotlib import colors - -from skimage import color as skcolor - -import colorsys - -import abc - -__all__ = ['Colormap3D', 'ColormapCubehelix', 'ColormapPerception', 'ColormapHLS', - 'ColormapClassic', 'ColormapTransparent', 'cmaps', 'CMAP_CIRCULAR_DEFAULT', - 'ColorspaceCIELab', 'ColorspaceCIELuv', 'ColorspaceCIExyY', 'ColorspaceYPbPr', - 'interpolate_color', 'rgb_to_brightness', 'colormap_brightness_comparison'] -_log = logging.getLogger(__name__) - - -# TODO: DOCSTRINGS!!! - - -class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta): - """Colormap subclass for encoding directions with colors. - - This abstract class is used as a superclass/interface for 3D vector plotting capabilities. - In general, a circular colormap should be used to encode the in-plane angle (hue). The - perpendicular angle is encoded via luminance variation (up: white, down: black). Finally, - the length of a vector is encoded via saturation. Decreasing vector length causes a desaturated - color. Subclassing colormaps get access to routines to plot a colorwheel (which should - ideally be located in the 50% luminance plane, which depends strongly on the underlying map), - a convenience function to interpolate color tuples and a function to return rgb triples for a - given vector. The :class:`~.Colormap3D` class itself subclasses the matplotlib base colormap. - - """ - - _log = logging.getLogger(__name__ + '.Colormap3D') - - def rgb_from_vector(self, vector): - """Construct a hls tuple from three coordinates representing a 3D direction. - - Parameters - ---------- - vector: tuple (N=3) or :class:`~numpy.ndarray` - Vector containing the x, y and z component, or a numpy array encompassing the - components as three lists.z-coordinate of the desired direction to encode. - - Returns - ------- - rgb: :class:`~numpy.ndarray` - Numpy array containing the calculated color tuples. - - """ - self._log.debug('Calling rgb_from_vector') - x, y, z = np.asarray(vector) - # Calculate spherical coordinates: - r = np.sqrt(x ** 2 + y ** 2 + z ** 2) - phi = np.asarray(np.arctan2(y, x)) - phi[phi < 0] += 2 * np.pi - theta = np.arccos(z / (r + 1E-30)) - # Calculate color deterministics: - hue = phi / (2 * np.pi) - lum = 1 - theta / np.pi - sat = r / r.max() - # Calculate RGB from hue with colormap: - rgba = np.asarray(self(hue)) - r, g, b = rgba[..., 0], rgba[..., 1], rgba[..., 2] - # Interpolate saturation: - r, g, b = interpolate_color(sat, (0.5, 0.5, 0.5), np.stack((r, g, b), axis=-1)) - # Interpolate luminance: - lum_target = np.where(lum < 0.5, 0, 1) - lum_target = np.stack([lum_target] * 3, axis=-1) - fraction = np.where(lum < 0.5, 1 - 2 * lum, 2 * (lum - 0.5)) - r, g, b = interpolate_color(fraction, np.stack((r, g, b), axis=-1), lum_target) - # Return RGB: - return np.asarray(255 * np.stack((r, g, b), axis=-1), dtype=np.uint8) - - def make_colorwheel(self, size=256, alpha=1): - self._log.debug('Calling make_colorwheel') - # Construct the colorwheel: - yy, xx = (np.indices((size, size)) - size/2 + 0.5) - rr = np.hypot(xx, yy) - xx = np.where(rr <= size/2-2, xx, 0) - yy = np.where(rr <= size/2-2, yy, 0) - zz = np.where(rr <= size/2-2, 0, -1) # color inside, black outside - aa = np.where(rr >= size/2-2, 255*alpha, 255).astype(dtype=np.uint8) - rgba = np.dstack((self.rgb_from_vector(np.asarray((xx, yy, zz))), aa)) - # Create color wheel: - return Image.fromarray(rgba) - - def plot_colorwheel(self, axis=None, size=512, alpha=1, arrows=False, figsize=(4, 4), - **kwargs): - """Display a color wheel to illustrate the color coding of vector gradient directions. - - Parameters - ---------- - figsize : tuple of floats (N=2) - Size of the plot figure. - - Returns - ------- - None - - """ - self._log.debug('Calling plot_colorwheel') - # Construct the colorwheel: - color_wheel = self.make_colorwheel(size=size, alpha=alpha) - # Plot the color wheel: - if axis is None: - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1, aspect='equal') - axis.imshow(color_wheel, origin='lower', **kwargs) - axis.add_artist(Circle(xy=(size/2-0.5, size/2-0.5), radius=size/2-2, linewidth=2, - edgecolor='k', facecolor='none')) - if arrows: - plt.tick_params(axis='both', which='both', labelleft='off', labelbottom='off', - left='off', right='off', top='off', bottom='off') - axis.arrow(size/2, size/2, 0, 0.15*size, head_width=9, head_length=20, - fc='k', ec='k', lw=1, width=2) - axis.arrow(size/2, size/2, 0, -0.15*size, head_width=9, head_length=20, - fc='k', ec='k', lw=1, width=2) - axis.arrow(size/2, size/2, 0.15*size, 0, head_width=9, head_length=20, - fc='k', ec='k', lw=1, width=2) - axis.arrow(size/2, size/2, -0.15*size, 0, head_width=9, head_length=20, - fc='k', ec='k', lw=1, width=2) - # Return axis: - axis.xaxis.set_visible(False) - axis.yaxis.set_visible(False) - return axis - - -class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D): - """A full implementation of Dave Green's "cubehelix" for Matplotlib. - - Based on the FORTRAN 77 code provided in D.A. Green, 2011, BASI, 39, 289. - http://adsabs.harvard.edu/abs/2011arXiv1108.5083G - Also see: - http://www.mrao.cam.ac.uk/~dag/CUBEHELIX/ - http://davidjohnstone.net/pages/cubehelix-gradient-picker - User can adjust all parameters of the cubehelix algorithm. This enables much greater - flexibility in choosing color maps. Default color map settings produce the standard cubehelix. - Create color map in only blues by setting rot=0 and start=0. Create reverse (white to black) - backwards through the rainbow once by setting rot=1 and reverse=True, etc. Furthermore, the - algorithm was tuned, so that constant luminance values can be used (e.g. to create a truly - isoluminant colorwheel). The `rot` parameter is also tuned to hold true for these cases. - Of the here presented colorwheels, only this one manages to solely navigate through the L*=50 - plane, which can be seen here: - https://upload.wikimedia.org/wikipedia/commons/2/21/Lab_color_space.png - - Parameters - ---------- - start : scalar, optional - Sets the starting position in the color space. 0=blue, 1=red, - 2=green. Defaults to 0.5. - rot : scalar, optional - The number of rotations through the rainbow. Can be positive - or negative, indicating direction of rainbow. Negative values - correspond to Blue->Red direction. Defaults to -1.5. - gamma : scalar, optional - The gamma correction for intensity. Defaults to 1.0. - reverse : boolean, optional - Set to True to reverse the color map. Will go from black to - white. Good for density plots where shade~density. Defaults to False. - nlev : scalar, optional - Defines the number of discrete levels to render colors at. - Defaults to 256. - sat : scalar, optional - The saturation intensity factor. Defaults to 1.2 - NOTE: this was formerly known as `hue` parameter - minSat : scalar, optional - Sets the minimum-level saturation. Defaults to 1.2. - maxSat : scalar, optional - Sets the maximum-level saturation. Defaults to 1.2. - startHue : scalar, optional - Sets the starting color, ranging from [0, 360], as in - D3 version by @mbostock. - NOTE: overrides values in start parameter. - endHue : scalar, optional - Sets the ending color, ranging from [0, 360], as in - D3 version by @mbostock - NOTE: overrides values in rot parameter. - minLight : scalar, optional - Sets the minimum lightness value. Defaults to 0. - maxLight : scalar, optional - Sets the maximum lightness value. Defaults to 1. - - Returns - ------- - matplotlib.colors.LinearSegmentedColormap object - - Revisions - --------- - 2014-04 (@jradavenport) Ported from IDL version - 2014-04 (@jradavenport) Added kwargs to enable similar to D3 version, - changed name of `hue` parameter to `sat`. - 2016-11 (@jan.caron) Added support for isoluminant cubehelices while making sure - `rot` works as intended. Decoded the plane-vectors a bit. - """ - - _log = logging.getLogger(__name__ + '.ColormapCubehelix') - - def __init__(self, start=0.5, rot=-1.5, gamma=1.0, reverse=False, nlev=256, - minSat=1.2, maxSat=1.2, minLight=0., maxLight=1., **kwargs): - self._log.debug('Calling __init__') - # Override start and rot if startHue and endHue are set: - if kwargs is not None: - if 'startHue' in kwargs: - start = (kwargs.get('startHue') / 360. - 1.) * 3. - if 'endHue' in kwargs: - rot = kwargs.get('endHue') / 360. - start / 3. - 1. - if 'sat' in kwargs: - minSat = kwargs.get('sat') - maxSat = kwargs.get('sat') - self.nlev = nlev - # Set up the parameters: - self.fract = np.linspace(minLight, maxLight, nlev) - angle = 2.0 * np.pi * (start / 3.0 + rot * np.linspace(0, 1, nlev)) - self.fract = self.fract**gamma - satar = np.linspace(minSat, maxSat, nlev) - amp = np.asarray(satar * self.fract * (1. - self.fract) / 2) - # Set RGB color coefficients (Luma is calculated in RGB Rec.601, so choose those), - # the original version of Dave green used (0.30, 0.59, 0.11) and REc.709 is - # c709 = (0.2126, 0.7152, 0.0722) but would not produce correct YPbPr Luma. - c601 = (0.299, 0.587, 0.114) - cr, cg, cb = c601 - cw = -0.90649 # Chosen to comply with Dave Greens implementation. - k = -1.6158 / cr / cw # k has to balance out cw so nothing gets out of RGB gamut (> 1). - # Calculate the vectors v and w spanning the plane of constant perceived intensity. - # v and w have to solve v x w = k(cr, cg, cb) (normal vector of the described plane) and - # v * w = 0 (scalar product, v and w have to be perpendicular). - # 6 unknown and 4 equations --> Chose wb = 0 and wg = cw (constant). - v = np.array((k * cr ** 2 * cb / (cw * (cr ** 2 + cg ** 2)), - k * cr * cg * cb / (cw * (cr ** 2 + cg ** 2)), -k * cr / cw)) - w = np.array((-cw * cg / cr, cw, 0)) - # Calculate components: - self.red = self.fract + amp * (v[0] * np.cos(angle) + w[0] * np.sin(angle)) - self.grn = self.fract + amp * (v[1] * np.cos(angle) + w[1] * np.sin(angle)) - self.blu = self.fract + amp * (v[2] * np.cos(angle) + w[2] * np.sin(angle)) - # Original formulas with original v and w: - # self.red = self.fract + amp * (-0.14861 * np.cos(angle) + 1.78277 * np.sin(angle)) - # self.grn = self.fract + amp * (-0.29227 * np.cos(angle) - 0.90649 * np.sin(angle)) - # self.blu = self.fract + amp * (1.97294 * np.cos(angle)) - # Find where RBG are outside the range [0,1], clip: - self.red = np.clip(self.red, 0, 1) - self.grn = np.clip(self.grn, 0, 1) - self.blu = np.clip(self.blu, 0, 1) - # Optional color reverse: - if reverse is True: - self.red = self.red[::-1] - self.blu = self.blu[::-1] - self.grn = self.grn[::-1] - # Put in to tuple & dictionary structures needed: - rr, bb, gg = [], [], [] - for k in range(0, int(nlev)): - rr.append((float(k) / (nlev - 1), self.red[k], self.red[k])) - bb.append((float(k) / (nlev - 1), self.blu[k], self.blu[k])) - gg.append((float(k) / (nlev - 1), self.grn[k], self.grn[k])) - cdict = {'red': rr, 'blue': bb, 'green': gg} - super().__init__('cubehelix', cdict, N=256) - self._log.debug('Created ' + str(self)) - - def plot_helix(self, figsize=(8, 8)): - """Display the RGB and luminance plots for the chosen cubehelix. - - Parameters - ---------- - figsize : tuple of floats (N=2) - Size of the plot figure. - - Returns - ------- - None - - """ - self._log.debug('Calling plot_helix') - plt.figure(figsize=figsize) - gs = gridspec.GridSpec(2, 1, height_ratios=[8, 1]) - # Main plot: - axis = plt.subplot(gs[0]) - axis.plot(self.fract, 'k', linewidth=2) - axis.plot(self.red, 'r', linewidth=2) - axis.plot(self.grn, 'g', linewidth=2) - axis.plot(self.blu, 'b', linewidth=2) - axis.set_xlim(0, self.nlev) - axis.set_ylim(0, 1) - axis.set_title('Cubehelix', fontsize=18) - axis.set_xlabel('color index', fontsize=15) - axis.set_ylabel('lightness / rgb', fontsize=15) - # Colorbar horizontal: - caxis = plt.subplot(gs[1], sharex=axis) - rgb = self(np.linspace(0, 1, 256))[None, ...] - rgb = np.asarray(255.9999 * rgb, dtype=np.uint8) - rgb = np.repeat(rgb, 30, axis=0) - im = Image.fromarray(rgb) - caxis.imshow(im) - plt.tick_params(axis='both', which='both', labelleft='off', labelbottom='off', - left='off', right='off', top='on', bottom='on') - - -class ColormapPerception(colors.LinearSegmentedColormap, Colormap3D): - """A perceptual colormap based on face-based luminance matching. - - Based on a publication by Kindlmann et. al. - http://www.cs.utah.edu/~gk/papers/vis02/FaceLumin.pdf - This colormap tries to achieve an isoluminant perception by using a list of colors acquired - through face recognition studies. It is a lot better than the HLS colormap, but still not - completely isoluminant (despite its name). Also it appears a bit dark. - - """ - - _log = logging.getLogger(__name__ + '.ColormapPerception') - - CDICT = {'red': [(0/6, 0.847, 0.847), - (1/6, 0.527, 0.527), - (2/6, 0.000, 0.000), - (3/6, 0.000, 0.000), - (4/6, 0.316, 0.316), - (5/6, 0.718, 0.718), - (6/6, 0.847, 0.847)], - - 'green': [(0/6, 0.057, 0.057), - (1/6, 0.527, 0.527), - (2/6, 0.592, 0.592), - (3/6, 0.559, 0.559), - (4/6, 0.316, 0.316), - (5/6, 0.000, 0.000), - (6/6, 0.057, 0.057)], - - 'blue': [(0/6, 0.057, 0.057), - (1/6, 0.000, 0.000), - (2/6, 0.000, 0.000), - (3/6, 0.559, 0.559), - (4/6, 0.991, 0.991), - (5/6, 0.718, 0.718), - (6/6, 0.057, 0.057)]} - - def __init__(self): - self._log.debug('Calling __init__') - super().__init__('perception', self.CDICT, N=256) - self._log.debug('Created ' + str(self)) - - -class ColormapHLS(colors.ListedColormap, Colormap3D): - """Colormap subclass for encoding directions with colors. - - This class is a subclass of the :class:`~matplotlib.pyplot.colors.ListedColormap` - class. The class follows the HSL ('hue', 'saturation', 'lightness') 'Double Hexcone' Model - with the saturation always set to 1 (moving on the surface of the color - cylinder) with a lightness of 0.5 (full color). The three prime colors (`rgb`) are spaced - equidistant with 120° space in between, according to a triadic arrangement. - Even though the lightness is constant in the plane, the luminance (which is a weighted sum - of the RGB components which encompasses human perception) is not, which can lead to - artifacts like reliefs. Converting the map to a grayscale show spokes at the secondary colors. - For more information see: - https://vis4.net/blog/posts/avoid-equidistant-hsv-colors/ - http://www.workwithcolor.com/color-luminance-2233.htm - http://blog.asmartbear.com/color-wheels.html - - """ - - _log = logging.getLogger(__name__ + '.ColormapHLS') - - def __init__(self): - self._log.debug('Calling __init__') - h = np.linspace(0, 1, 256) - l = 0.5 * np.ones_like(h) - s = np.ones_like(h) - r, g, b = np.vectorize(colorsys.hls_to_rgb)(h, l, s) - colors = [(r[i], g[i], b[i]) for i in range(len(r))] - super().__init__(colors, 'hls', N=256) - self._log.debug('Created ' + str(self)) - - -class ColormapClassic(colors.LinearSegmentedColormap, Colormap3D): - """Colormap subclass for encoding directions with colors. - - This class is a subclass of the :class:`~matplotlib.pyplot.colors.LinearSegmentedColormap` - class. The class follows the HSL ('hue', 'saturation', 'lightness') 'Double - Hexcone' Model with the saturation always set to 1 (moving on the surface of the color - cylinder) with a luminance of 0.5 (full color). The colors follow a tetradic arrangement with - four colors (red, green, blue and yellow) arranged with 90° spacing in between. - - """ - - _log = logging.getLogger(__name__ + '.ColormapClassic') - - CDICT = {'red': [(0.00, 1.0, 1.0), - (0.25, 0.0, 0.0), - (0.50, 0.0, 0.0), - (0.75, 1.0, 1.0), - (1.00, 1.0, 1.0)], - - 'green': [(0.00, 0.0, 0.0), - (0.25, 0.0, 0.0), - (0.50, 1.0, 1.0), - (0.75, 1.0, 1.0), - (1.00, 0.0, 0.0)], - - 'blue': [(0.00, 0.0, 0.0), - (0.25, 1.0, 1.0), - (0.50, 0.0, 0.0), - (0.75, 0.0, 0.0), - (1.00, 0.0, 0.0)]} - - def __init__(self): - self._log.debug('Calling __init__') - super().__init__('classic', self.CDICT, N=256) - self._log.debug('Created ' + str(self)) - - -class ColormapTransparent(colors.LinearSegmentedColormap): - """Colormap subclass for including transparency. - - This class is a subclass of the :class:`~matplotlib.pyplot.colors.LinearSegmentedColormap` - class with integrated support for transparency. The colormap is unicolor and varies only in - transparency. - - Attributes - ---------- - r: float, optional - Intensity of red in the colormap. Has to be between 0. and 1. - g: float, optional - Intensity of green in the colormap. Has to be between 0. and 1. - b: float, optional - Intensity of blue in the colormap. Has to be between 0. and 1. - alpha_range : list (N=2) of float, optional - Start and end alpha value. Has to be between 0. and 1. - - """ - - _log = logging.getLogger(__name__ + '.ColormapTransparent') - - def __init__(self, r=0., g=0., b=0., alpha_range=None): - self._log.debug('Calling __init__') - if alpha_range is None: - alpha_range = [0., 1.] - red = [(0., 0., r), (1., r, 1.)] - green = [(0., 0., g), (1., g, 1.)] - blue = [(0., 0., b), (1., b, 1.)] - alpha = [(0., 0., alpha_range[0]), (1., alpha_range[1], 1.)] - cdict = {'red': red, 'green': green, 'blue': blue, 'alpha': alpha} - super().__init__('transparent', cdict, N=256) - self._log.debug('Created ' + str(self)) - - -class ColorspaceCIELab(object): # TODO: Superclass? - """Class representing the CIELab colorspace.""" - - _log = logging.getLogger(__name__ + '.ColorspaceCIELab') - - def __init__(self, dim=(500, 500), extent=(-100, 100, -100, 100), cut_gamut=False, clip=True): - self._log.debug('Calling __init__') - self.dim = dim - self.extent = extent - self.cut_out_gamut = cut_gamut - self.clip = clip - self._log.debug('Created ' + str(self)) - - def plot(self, L=53.4, axis=None, figsize=(8, 8)): - self._log.debug('Calling plot') - dim, ext = self.dim, self.extent - # Create Lab colorspace: - a = np.linspace(ext[0], ext[1], dim[1]) - b = np.linspace(ext[2], ext[3], dim[0]) - aa, bb = np.meshgrid(a, b) - LL = np.full(dim, L, dtype=int) - Lab = np.stack((LL, aa, bb), axis=-1) - # Convert to XYZ colorspace: - XYZ = skcolor.lab2xyz(Lab) - # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: - rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) - # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: - mask = rgb > 0.0031308 - rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 - rgb[~mask] *= 12.92 - # Determine gamut: - gamut = np.logical_or(rgb < 0, rgb > 1) - gamut = np.sum(gamut, axis=-1, dtype=bool) - gamut_mask = np.stack((gamut, gamut, gamut), axis=-1) - # Cut out gamut (set out of bound colors to gray) if necessary: - if self.cut_out_gamut: - rgb[gamut_mask] = 0.5 - # Clip out of gamut colors: - if self.clip: - rgb[rgb < 0] = 0 - rgb[rgb > 1] = 1 - # Plot colorspace: - if axis is None: - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1, aspect='equal') - axis.imshow(rgb, origin='lower', interpolation='none', extent=(0, dim[0], 0, dim[1])) - axis.contour(gamut, levels=[0], colors='k', linewidths=1.5) - axis.set_xlabel('a', fontsize=15) - axis.set_ylabel('b', fontsize=15) - axis.set_title('CIELab (L = {:g})'.format(L), fontsize=18) - fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) - axis.xaxis.set_major_formatter(fx) - fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) - axis.yaxis.set_major_formatter(fy) - axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - - def plot_colormap(self, cmap, N=256, L='auto', figsize=None, cbar_lim=None, brightness=True, - input_rec=None): - self._log.debug('Calling plot_colormap') - dim, ext = self.dim, self.extent - # Calculate rgb values: - rgb = cmap(np.linspace(0, 1, N))[None, :, :3] # These are R'G'B' values! - if input_rec == 601: - rgb = RGBConverter('Rec601', 'Rec709')(rgb) - # Convert to Lab space: - Lab = np.squeeze(skcolor.rgb2lab(rgb)) - LL, aa, bb = Lab.T - aa = (aa - ext[0]) / (ext[1] - ext[0]) * dim[1] - bb = (bb - ext[2]) / (ext[3] - ext[2]) * dim[0] - # Determine number of images / luma levels: - LL_min, LL_max = np.round(np.min(LL), 1), np.round(np.max(LL), 1) - if L == 'auto': - if LL_max - LL_min < 0.1: # Just one image: - L = LL_min - else: # Two images: - L = np.asarray((LL_max, np.mean(LL), LL_min)) - L_list = np.atleast_1d(L) - # Determine colorbar limits: - if cbar_lim is not None: # Overwrite limits! - LL_min, LL_max = cbar_lim - elif not brightness or LL_max - LL_min < 0.1: # Just one value, full range for colormap: - LL_min, LL_max = 0, 1 - # Creat grid: - if figsize is None: - figsize = (len(L_list) * 5 + 2, 7) - fig = plt.figure(figsize=figsize) - grid = ImageGrid(fig, 111, nrows_ncols=(1, len(L_list)), axes_pad=0.4, share_all=False, - cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25) - # Plot: - if brightness: - c = LL - cmap = 'gray' - else: - c = np.linspace(0, 1, N) - for i, axis in enumerate(grid): - self.plot(L=L_list[i], axis=axis) - im = axis.scatter(aa, bb, c=c, cmap=cmap, edgecolors='none', - vmin=LL_min, vmax=LL_max) - axis.set_xlim(0, self.dim[1]) - axis.set_ylim(0, self.dim[0]) - axis.cax.colorbar(im, ticks=np.linspace(LL_min, LL_max, 9)) - - def plot3d(self, N=9): - self._log.debug('Calling plot3d') - dim, ext = self.dim, self.extent - # Create Lab colorspace: - a = np.linspace(ext[0], ext[1], dim[1]) - b = np.linspace(ext[2], ext[3], dim[0]) - aa, bb = np.meshgrid(a, b) - import visvis # TODO: If VisPy is ever ready, switch every plot to that! - for i in range(1, N): - LL = np.full(dim, i * 100 / N, dtype=int) - Lab = np.stack((LL, aa, bb), axis=-1) - # Convert to XYZ colorspace: - XYZ = skcolor.lab2xyz(Lab) - # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: - rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) - # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: - mask = rgb > 0.0031308 - rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 - rgb[~mask] *= 12.92 - # Determine gamut: - gamut = np.logical_or(rgb < 0, rgb > 1) - gamut = np.sum(gamut, axis=-1, dtype=bool) - # Alpha: - alpha = 1. - a = np.full(dim + (1,), alpha) - a *= np.logical_not(gamut[..., None]) - rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8) - # Visvis plot: - obj = visvis.functions.surf(aa, bb, i * 100. / N * np.ones_like(aa), rgba, aa=0) - obj.parent.light0.ambient = 1. - obj.parent.light0.diffuse = 0. - - -class ColorspaceCIELuv(object): - """Class representing the CIELuv colorspace.""" - - _log = logging.getLogger(__name__ + '.ColorspaceCIELuv') - - def __init__(self, dim=(500, 500), extent=(-100, 100, -100, 100), cut_gamut=False, clip=True): - self._log.debug('Calling __init__') - self.dim = dim - self.extent = extent - self.cut_out_gamut = cut_gamut - self.clip = clip - self._log.debug('Created ' + str(self)) - - def plot(self, L=53.4, axis=None, figsize=(8, 8)): - self._log.debug('Calling plot') - dim, ext = self.dim, self.extent - # Create Lab colorspace: - u = np.linspace(ext[0], ext[1], dim[1]) - v = np.linspace(ext[2], ext[3], dim[0]) - uu, vv = np.meshgrid(u, v) - LL = np.full(dim, L, dtype=int) - Luv = np.stack((LL, uu, vv), axis=-1) - # Convert to XYZ colorspace: - XYZ = skcolor.luv2xyz(Luv) - # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: - rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) - # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: - mask = rgb > 0.0031308 - rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 - rgb[~mask] *= 12.92 - # Determine gamut: - gamut = np.logical_or(rgb < 0, rgb > 1) - gamut = np.sum(gamut, axis=-1, dtype=bool) - gamut_mask = np.stack((gamut, gamut, gamut), axis=-1) - # Cut out gamut (set out of bound colors to gray) if necessary: - if self.cut_out_gamut: - rgb[gamut_mask] = 0.5 - # Clip out of gamut colors: - if self.clip: - rgb[rgb < 0] = 0 - rgb[rgb > 1] = 1 - # Plot colorspace: - if axis is None: - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1, aspect='equal') - axis.imshow(rgb, origin='lower', interpolation='none', extent=(0, dim[0], 0, dim[1])) - axis.contour(gamut, levels=[0], colors='k', linewidths=1.5) - axis.set_xlabel('u', fontsize=15) - axis.set_ylabel('v', fontsize=15) - axis.set_title('CIELuv (L = {:g})'.format(L), fontsize=18) - fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) - axis.xaxis.set_major_formatter(fx) - fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) - axis.yaxis.set_major_formatter(fy) - axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - - def plot_colormap(self, cmap, N=256, L='auto', figsize=None, cbar_lim=None, brightness=True, - input_rec=None): - self._log.debug('Calling plot_colormap') - dim, ext = self.dim, self.extent - # Calculate rgb values: - rgb = cmap(np.linspace(0, 1, N))[None, :, :3] - if input_rec == 601: - rgb = RGBConverter('Rec601', 'Rec709')(rgb) - # Convert to Lab space: - Luv = np.squeeze(skcolor.rgb2luv(rgb)) - LL, uu, vv = Luv.T - uu = (uu - ext[0]) / (ext[1] - ext[0]) * dim[1] - vv = (vv - ext[2]) / (ext[3] - ext[2]) * dim[0] - # Determine number of images / luma levels: - LL_min, LL_max = np.round(np.min(LL), 1), np.round(np.max(LL), 1) - if L == 'auto': - if LL_max - LL_min < 0.1: # Just one image: - L = LL_min - else: # Two images: - L = np.asarray((LL_max, np.mean(LL), LL_min)) - L_list = np.atleast_1d(L) - # Determine colorbar limits: - if cbar_lim is not None: # Overwrite limits! - LL_min, LL_max = cbar_lim - elif not brightness or LL_max - LL_min < 0.1: # Just one value, full range for colormap: - LL_min, LL_max = 0, 1 - # Creat grid: - if figsize is None: - figsize = (len(L_list) * 5 + 2, 7) - fig = plt.figure(figsize=figsize) - grid = ImageGrid(fig, 111, nrows_ncols=(1, len(L_list)), axes_pad=0.4, share_all=False, - cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25) - # Plot: - if brightness: - c = LL - cmap = 'gray' - else: - c = np.linspace(0, 1, N) - for i, axis in enumerate(grid): - self.plot(L=L_list[i], axis=axis) - im = axis.scatter(uu, vv, c=c, cmap=cmap, edgecolors='none', - vmin=LL_min, vmax=LL_max) - axis.set_xlim(0, self.dim[1]) - axis.set_ylim(0, self.dim[0]) - axis.cax.colorbar(im, ticks=np.linspace(LL_min, LL_max, 9)) - - def plot3d(self, N=9): - self._log.debug('Calling plot3d') - dim, ext = self.dim, self.extent - # Create Lab colorspace: - u = np.linspace(ext[0], ext[1], dim[1]) - v = np.linspace(ext[2], ext[3], dim[0]) - uu, vv = np.meshgrid(u, v) - import visvis - for i in range(1, N): - LL = np.full(dim, i * 100 / N, dtype=int) - Luv = np.stack((LL, uu, vv), axis=-1) - # Convert to XYZ colorspace: - XYZ = skcolor.luv2xyz(Luv) - # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: - rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) - # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: - mask = rgb > 0.0031308 - rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 - rgb[~mask] *= 12.92 - # Determine gamut: - gamut = np.logical_or(rgb < 0, rgb > 1) - gamut = np.sum(gamut, axis=-1, dtype=bool) - # Alpha: - alpha = 1. - a = np.full(dim + (1,), alpha) - a *= np.logical_not(gamut[..., None]) - rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8) - # Visvis plot: - obj = visvis.functions.surf(uu, vv, i * 100. / N * np.ones_like(uu), rgba, aa=0) - obj.parent.light0.ambient = 1. - obj.parent.light0.diffuse = 0. - - -class ColorspaceCIExyY(object): - """Class representing the CIExyY colorspace.""" - - _log = logging.getLogger(__name__ + '.ColorspaceCIExyY') - - def __init__(self, dim=(500, 500), extent=(0, 0.8, 0, 0.8), cut_gamut=False, clip=True): - self._log.debug('Calling __init__') - self.dim = dim - self.extent = extent - self.cut_out_gamut = cut_gamut - self.clip = clip - self._log.debug('Created ' + str(self)) - - def plot(self, Y=0.214, axis=None, figsize=(8, 8)): - self._log.debug('Calling plot') - dim, ext = self.dim, self.extent - # Create Lab colorspace: - x = np.linspace(ext[0], ext[1], dim[1]) - y = np.linspace(ext[2], ext[3], dim[0]) - xx, yy = np.meshgrid(x, y) - YY = np.full(dim, Y) - # Convert to XYZ: - XX = YY / (yy + 1e-30) * xx - ZZ = YY / (yy + 1e-30) * (1 - xx - yy) - XYZ = np.stack((XX, YY, ZZ), axis=-1) - # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: - rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) - # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: - mask = rgb > 0.0031308 - rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 - rgb[~mask] *= 12.92 - # Determine gamut: - gamut = np.logical_or(rgb < 0, rgb > 1) - gamut = np.sum(gamut, axis=-1, dtype=bool) - gamut_mask = np.stack((gamut, gamut, gamut), axis=-1) - # Cut out gamut (set out of bound colors to gray) if necessary: - if self.cut_out_gamut: - rgb[gamut_mask] = 0.5 - # Clip out of gamut colors: - if self.clip: - rgb[rgb < 0] = 0 - rgb[rgb > 1] = 1 - # Plot colorspace: - if axis is None: - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1, aspect='equal') - axis.imshow(rgb, origin='lower', interpolation='none', extent=(0, dim[0], 0, dim[1])) - axis.contour(gamut, levels=[0], colors='k', linewidths=1.5) - axis.set_xlabel('x', fontsize=15) - axis.set_ylabel('y', fontsize=15) - axis.set_title('CIExyY (Y = {:g})'.format(Y), fontsize=18) - fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) - axis.xaxis.set_major_formatter(fx) - fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) - axis.yaxis.set_major_formatter(fy) - axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - - def plot_colormap(self, cmap, N=256, Y='auto', figsize=None, cbar_lim=None, brightness=True, - input_rec=None): - self._log.debug('Calling plot_colormap') - dim, ext = self.dim, self.extent - # Calculate rgb values: - rgb = cmap(np.linspace(0, 1, N))[None, :, :3] - if input_rec == 601: - rgb = RGBConverter('Rec601', 'Rec709')(rgb) - # Convert to XYZ space: - XYZ = np.squeeze(skcolor.rgb2xyz(rgb)) - XX, YY, ZZ = XYZ.T - # Convert to xyY space: - xx = XX / (XX + YY + ZZ) - yy = YY / (XX + YY + ZZ) - xx = (xx - ext[0]) / (ext[1] - ext[0]) * dim[1] - yy = (yy - ext[2]) / (ext[3] - ext[2]) * dim[0] - # Determine number of images / luma levels: - YY_min, YY_max = np.round(np.min(YY), 2), np.round(np.max(YY), 2) - if Y == 'auto': - if YY_max - YY_min < 0.01: # Just one image: - Y = YY_min - else: # Two images: - Y = np.asarray((YY_max, np.mean(YY), YY_min)) - Y_list = np.atleast_1d(Y) - # Determine colorbar limits: - if cbar_lim is not None: # Overwrite limits! - YY_min, YY_max = cbar_lim - elif not brightness or YY_max - YY_min < 0.01: # Just one value, full range for colormap: - YY_min, YY_max = 0, 1 - # Creat grid: - if figsize is None: - figsize = (len(Y_list) * 5 + 2, 7) - fig = plt.figure(figsize=figsize) - grid = ImageGrid(fig, 111, nrows_ncols=(1, len(Y_list)), axes_pad=0.4, share_all=False, - cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25) - # Plot: - if brightness: - c = YY - cmap = 'gray' - else: - c = np.linspace(0, 1, N) - for i, axis in enumerate(grid): - self.plot(Y=Y_list[i], axis=axis) - im = axis.scatter(xx, yy, c=c, cmap=cmap, edgecolors='none', - vmin=YY_min, vmax=YY_max) - axis.set_xlim(0, self.dim[1]) - axis.set_ylim(0, self.dim[0]) - axis.cax.colorbar(im, ticks=np.linspace(YY_min, YY_max, 9)) - - def plot3d(self, N=9): - self._log.debug('Calling plot3d') - dim, ext = self.dim, self.extent - # Create Lab colorspace: - x = np.linspace(ext[0], ext[1], dim[1]) - y = np.linspace(ext[2], ext[3], dim[0]) - xx, yy = np.meshgrid(x, y) - import visvis - for i in range(1, N): - YY = np.full(dim, i * 1. / N) - # Convert to XYZ: - XX = YY / (yy + 1e-30) * xx - ZZ = YY / (yy + 1e-30) * (1 - xx - yy) - XYZ = np.stack((XX, YY, ZZ), axis=-1) - # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: - rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) - # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: - mask = rgb > 0.0031308 - rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 - rgb[~mask] *= 12.92 - # Determine gamut: - gamut = np.logical_or(rgb < 0, rgb > 1) - gamut = np.sum(gamut, axis=-1, dtype=bool) - # Alpha: - alpha = 1. - a = np.full(dim + (1,), alpha) - a *= np.logical_not(gamut[..., None]) - rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8) - # Visvis plot: - obj = visvis.functions.surf(xx, yy, i / N * np.ones_like(xx), rgba, aa=0) - obj.parent.light0.ambient = 1. - obj.parent.light0.diffuse = 0. - - -class ColorspaceYPbPr(object): - """Class representing the YPbPr colorspace.""" - - _log = logging.getLogger(__name__ + '.ColorspaceYPbPr') - - def __init__(self, dim=(500, 500), extent=(-0.8, 0.8, -0.8, 0.8), cut_gamut=False, clip=True): - self._log.debug('Calling __init__') - self.dim = dim - self.extent = extent - self.cut_out_gamut = cut_gamut - self.clip = clip - self._log.debug('Created ' + str(self)) - - def plot(self, Y=0.5, axis=None, figsize=(8, 8)): - self._log.debug('Calling plot') - dim, ext = self.dim, self.extent - # Create YPbPr colorspace: - pb = np.linspace(ext[0], ext[1], dim[1]) - pr = np.linspace(ext[2], ext[3], dim[0]) - ppb, ppr = np.meshgrid(pb, pr) - YY = np.full(dim, Y) # This is luma, not relative luminance (Y', not Y)! - # Convert to RGB colorspace (this is the nonlinear R'G'B' space!): - rr = YY + 1.402 * ppr - gg = YY - 0.344136 * ppb - 0.7141136 * ppr - bb = YY + 1.772 * ppb - rgb = np.stack((rr, gg, bb), axis=-1) - # Determine gamut: - gamut = np.logical_or(rgb < 0, rgb > 1) - gamut = np.sum(gamut, axis=-1, dtype=bool) - gamut_mask = np.stack((gamut, gamut, gamut), axis=-1) - # Cut out gamut (set out of bound colors to gray) if necessary: - if self.cut_out_gamut: - rgb[gamut_mask] = 0.5 - # Clip out of gamut colors: - if self.clip: - rgb[rgb < 0] = 0 - rgb[rgb > 1] = 1 - # Plot colorspace: - if axis is None: - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1, aspect='equal') - axis.imshow(rgb, origin='lower', interpolation='none', - extent=(0, dim[0], 0, dim[1])) - axis.contour(gamut, levels=[0], colors='k', linewidths=1.5) - axis.set_xlabel('Pb', fontsize=15) - axis.set_ylabel('Pr', fontsize=15) - axis.set_title("Y'PbPr (Y' = {:g})".format(Y), fontsize=18) - fx = FuncForm( - lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) - axis.xaxis.set_major_formatter(fx) - fy = FuncForm( - lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) - axis.yaxis.set_major_formatter(fy) - axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - - def plot_colormap(self, cmap, N=256, Y='auto', figsize=None, cbar_lim=None, brightness=True, - input_rec=None): - self._log.debug('Calling plot_colormap') - dim, ext = self.dim, self.extent - # Calculate rgb values: - rgb = cmap(np.linspace(0, 1, N))[None, :, :3] - if input_rec == 709: - rgb = RGBConverter('Rec709', 'Rec601')(rgb) - rr, gg, bb = rgb.T - # Convert to YPbPr space: - k_r, k_g, k_b = 0.299, 0.587, 0.114 # Constants Rec.601! - YY = k_r * rr + k_g * gg + k_b * bb - ppb = (bb - YY) / (2 * (1 - k_b)) - ppr = (rr - YY) / (2 * (1 - k_r)) - ppb = (ppb - ext[0]) / (ext[1] - ext[0]) * dim[1] - ppr = (ppr - ext[2]) / (ext[3] - ext[2]) * dim[0] - # Determine number of images / luma levels: - YY_min, YY_max = np.round(np.min(YY), 2), np.round(np.max(YY), 2) - if Y == 'auto': - if YY_max - YY_min < 0.01: # Just one image: - Y = YY_min - else: # Two images: - Y = np.asarray((YY_max, np.mean(YY), YY_min)) - Y_list = np.atleast_1d(Y) - # Determine colorbar limits: - if cbar_lim is not None: # Overwrite limits! - YY_min, YY_max = cbar_lim - elif not brightness or YY_max - YY_min < 0.01: # Just one value, full range for colormap: - YY_min, YY_max = 0, 1 - # Creat grid: - if figsize is None: - figsize = (len(Y_list) * 5 + 2, 7) - fig = plt.figure(figsize=figsize) - grid = ImageGrid(fig, 111, nrows_ncols=(1, len(Y_list)), axes_pad=0.4, share_all=False, - cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25) - # Plot: - if brightness: - c = YY - cmap = 'gray' - else: - c = np.linspace(0, 1, N) - for i, axis in enumerate(grid): - self.plot(Y=Y_list[i], axis=axis) - im = axis.scatter(ppb, ppr, c=c, cmap=cmap, edgecolors='none', - vmin=YY_min, vmax=YY_max) - axis.set_xlim(0, self.dim[1]) - axis.set_ylim(0, self.dim[0]) - axis.cax.colorbar(im, ticks=np.linspace(YY_min, YY_max, 9)) - - def plot3d(self, N=9): - self._log.debug('Calling plot3d') - dim, ext = self.dim, self.extent - # Create YPbPr colorspace: - pb = np.linspace(ext[0], ext[1], dim[1]) - pr = np.linspace(ext[2], ext[3], dim[0]) - ppb, ppr = np.meshgrid(pb, pr) - import visvis - for i in range(1, N): - YY = np.full(dim, i * 1. / N) # This is luma, not relative luminance (Y', not Y)! - # Convert to RGB colorspace (this is the nonlinear R'G'B' space!): - rr = YY + 1.402 * ppr - gg = YY - 0.344136 * ppb - 0.7141136 * ppr - bb = YY + 1.772 * ppb - rgb = np.stack((rr, gg, bb), axis=-1) - # Determine gamut: - gamut = np.logical_or(rgb < 0, rgb > 1) - gamut = np.sum(gamut, axis=-1, dtype=bool) - # Alpha: - alpha = 1. - a = np.full(dim + (1,), alpha) - a *= np.logical_not(gamut[..., None]) - rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8) - # Visvis plot: - obj = visvis.functions.surf(ppb, ppr, i / N * np.ones_like(ppb), rgba, aa=0) - obj.parent.light0.ambient = 1. - obj.parent.light0.diffuse = 0. - - -class RGBConverter(object): - """Class for the conversion of RGB values from one RGB space to another. - - Notes - ----- - This operates only on NONLINEAR R'G'B' values, normalised to a range of [0, 1]! - Convert from linear RGB values beforehand, if necessary! - - """ - - rgb601_to_ypbpr = np.array([[+0.299000, +0.587000, +0.114000], - [-0.168736, -0.331264, +0.500000], - [+0.500000, -0.418688, -0.081312]]) - ypbr_to_rgb709 = np.array([[1, +0.0000, +1.5701], - [1, -0.1870, -0.4664], - [1, +1.8556, +0.0000]]) - rgb601_to_rgb709 = ypbr_to_rgb709.dot(rgb601_to_ypbpr) - rgb709_to_rgb601 = np.linalg.inv(rgb601_to_rgb709) - - def __init__(self, source='Rec601', target='Rec709'): - if source == 'Rec601' and target == 'Rec709': - self.convert_matrix = self.rgb601_to_rgb709 - elif source == 'Rec709' and target == 'Rec601': - self.convert_matrix = self.rgb709_to_rgb601 - else: - raise KeyError('Conversion from {} to {} not found!'.format(source, target)) - - def __call__(self, rgb): - """Convert from one RGB space to another. - - Parameters - ---------- - rgb: :class:`~numpy.ndarray` - Numpy array containing the RGB source values (last dimension: 3). - - Returns - ------- - rgb_result: :class:`~numpy.ndarray` - The resulting RGB values in the target space. - - """ - rgb_out = rgb.reshape((-1, 3)).T - rgb_out = self.convert_matrix.dot(rgb_out) - return rgb_out.T.reshape(rgb.shape) - - -def interpolate_color(fraction, start, end): - """Interpolate linearly between two color tuples (e.g. RGB). - - Parameters - ---------- - fraction: float or :class:`~numpy.ndarray` - Interpolation fraction between 0 and 1, which determines the position of the - interpolation between `start` and `end`. - start: tuple (N=3) or :class:`~numpy.ndarray` - Start of the interpolation as a tuple of three numbers or a numpy array, where the last - dimension should have length 3 and contain the color tuples. - end: tuple (N=3) or :class:`~numpy.ndarray` - End of the interpolation as a tuple of three numbers or a numpy array, where the last - dimension should have length 3 and contain the color tuples. - - Returns - ------- - result: tuple (N=3) or :class:`~numpy.ndarray` - Result of the interpolation as a tuple of three numbers or a numpy array, where the - last dimension should has length 3 and contains the color tuples. - - """ - _log.debug('Calling interpolate_color') - start, end = np.asarray(start), np.asarray(end) - r1 = start[..., 0] + (end[..., 0] - start[..., 0]) * fraction - r2 = start[..., 1] + (end[..., 1] - start[..., 1]) * fraction - r3 = start[..., 2] + (end[..., 2] - start[..., 2]) * fraction - return r1, r2, r3 - - -def rgb_to_brightness(rgb, mode="Y'", input_rec=None): - - import colorspacious # TODO: Use for everything! - c = {601: [0.299, 0.587, 0.114], 709: [0.2125, 0.7154, 0.0721]} # Image.convert('L') uses 601! - if input_rec is None: # Not specified, use in all cases: - rgbp601 = rgb - rgbp709 = rgb - elif input_rec == 601: - rgbp601 = rgb - rgbp709 = RGBConverter('Rec601', 'Rec709')(rgb) - elif input_rec == 709: - rgbp601 = RGBConverter('Rec601', 'Rec709')(rgb) - rgbp709 = rgb - else: - raise KeyError('Input RGB type {} not understood!'.format(input_rec)) - if mode in ("Y'", 'Luma'): - rp601, gp601, bp601 = rgbp601.T - brightness = c[601][0] * rp601 + c[601][1] * gp601 + c[601][2] * bp601 - elif mode in ('Y', 'Luminance'): - rgb709 = colorspacious.cspace_converter('sRGB1', 'sRGB1-linear')(rgbp709) - r709, g709, b709 = rgb709.T - brightness = c[709][0] * r709 + c[709][1] * g709 + c[709][2] * b709 - elif mode in ('L*', 'LightnessLab'): - lab = colorspacious.cspace_converter('sRGB1', 'CIELab')(rgbp709) - brightness = lab[0, :, 0] - elif mode in ('I', 'Intensity', 'Average'): - brightness = np.mean(rgb, axis=-1) - elif mode in ('V', 'Value', 'Maximum'): - brightness = np.max(rgb, axis=-1) - elif mode in ('L', 'LightnessHSL'): - brightness = (np.max(rgb, axis=-1) + np.min(rgb, axis=-1)) / 2 - else: - raise KeyError('Brightness request {} not understood!'.format(mode)) - return brightness - - -def colormap_brightness_comparison(cmap, input_rec=None, figsize=(18, 8)): - - # Create R'G'B' values from colormap: - x = np.linspace(0, 1, 1000) - rgbp = cmap(x)[None, :, :3] - # Calculate different brightness measures: - luma = rgb_to_brightness(rgbp, mode="Y'", input_rec=input_rec) - luminance = rgb_to_brightness(rgbp, mode='Y', input_rec=input_rec) - lightness_lab = rgb_to_brightness(rgbp, mode='L*', input_rec=input_rec) - intensity = rgb_to_brightness(rgbp, mode='I', input_rec=input_rec) - value = rgb_to_brightness(rgbp, mode='V', input_rec=input_rec) - lightness_hls = rgb_to_brightness(rgbp, mode='L', input_rec=input_rec) - # Plot: - fig, grid = plt.subplots(2, 3, figsize=figsize) - plt.title(cmap.name) - axis = grid[0, 0] - axis.scatter(x, luma, c=x, cmap=cmap, s=200, linewidths=0.) - axis.axhline(y=0.5, color='k', ls='--') - axis.set_xlim(0, 1) - axis.set_ylim(0, 1) - axis.set_title("Luma $Y$ '") - axis = grid[0, 1] - axis.scatter(x, luminance, c=x, cmap=cmap, s=200, linewidths=0.) - axis.axhline(y=0.214, color='k', ls='--') - axis.set_xlim(0, 1) - axis.set_ylim(0, 1) - axis.set_title('Relative Luminance $Y$') - axis = grid[0, 2] - axis.scatter(x, lightness_lab, c=x, cmap=cmap, s=200, linewidths=0.) - axis.axhline(y=53.39, color='k', ls='--') - axis.set_xlim(0, 1) - axis.set_ylim(0, 100) - axis.set_title('Lightness $L^*$ (CIELab)') - axis = grid[1, 0] - axis.scatter(x, intensity, c=x, cmap=cmap, s=200, linewidths=0.) - axis.axhline(y=53.39, color='k', ls='--') - axis.set_xlim(0, 1) - axis.set_ylim(0, 1) - axis.set_title('Intensity $I$ (HSI Component Average)') - axis = grid[1, 1] - axis.scatter(x, value, c=x, cmap=cmap, s=200, linewidths=0.) - axis.axhline(y=53.39, color='k', ls='--') - axis.set_xlim(0, 1) - axis.set_ylim(0, 1) - axis.set_title('Value $V$ (HSV Component Maximum)') - axis = grid[1, 2] - axis.scatter(x, lightness_hls, c=x, cmap=cmap, s=200, linewidths=0.) - axis.axhline(y=53.39, color='k', ls='--') - axis.set_xlim(0, 1) - axis.set_ylim(0, 1) - axis.set_title('Lightness $L$ (HSL Min-Max-Average)') - - -cmaps = {'cubehelix_standard': ColormapCubehelix(), - 'cubehelix_reverse': ColormapCubehelix(reverse=True), - 'cubehelix_circular': ColormapCubehelix(start=1, rot=1, - minLight=0.5, maxLight=0.5, sat=2), - 'perception_circular': ColormapPerception(), - 'hls_circular': ColormapHLS(), - 'classic_circular': ColormapClassic(), - 'transparent_black': ColormapTransparent(0, 0, 0, [0, 1.]), - 'transparent_white': ColormapTransparent(1, 1, 1, [0, 1.]), - 'transparent_confidence': ColormapTransparent(0.2, 0.3, 0.2, [0.75, 0.])} - -CMAP_CIRCULAR_DEFAULT = cmaps['cubehelix_circular'] +# -*- coding: utf-8 -*- +# Copyright 2014 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides a number of custom colormaps, which also have capabilities for 3D plotting. +If this is the case, the :class:`~.Colormap3D` colormap class is a parent class. In `cmaps`, a +number of specialised colormaps is available for convenience. If the default for circular colormaps +(used for 3D plotting) should be changed, set it via `CMAP_CMAP_ANGULAR_DEFAULT`. +For general questions about colors see: +http://www.poynton.com/PDFs/GammaFAQ.pdf +http://www.poynton.com/PDFs/ColorFAQ.pdf +""" + +import logging + +import matplotlib.pyplot as plt +from matplotlib.ticker import FuncFormatter as FuncForm +from matplotlib.ticker import MaxNLocator + +from mpl_toolkits.mplot3d import Axes3D +from mpl_toolkits.axes_grid1 import ImageGrid +from matplotlib import gridspec +from matplotlib.patches import Circle + +import numpy as np +from PIL import Image +from matplotlib import colors + +from skimage import color as skcolor + +import colorsys + +import abc + +__all__ = ['Colormap3D', 'ColormapCubehelix', 'ColormapPerception', 'ColormapHLS', + 'ColormapClassic', 'ColormapTransparent', 'cmaps', 'CMAP_CIRCULAR_DEFAULT', + 'ColorspaceCIELab', 'ColorspaceCIELuv', 'ColorspaceCIExyY', 'ColorspaceYPbPr', + 'interpolate_color', 'rgb_to_brightness', 'colormap_brightness_comparison'] +_log = logging.getLogger(__name__) + + +# TODO: DOCSTRINGS!!! + + +class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta): + """Colormap subclass for encoding directions with colors. + + This abstract class is used as a superclass/interface for 3D vector plotting capabilities. + In general, a circular colormap should be used to encode the in-plane angle (hue). The + perpendicular angle is encoded via luminance variation (up: white, down: black). Finally, + the length of a vector is encoded via saturation. Decreasing vector length causes a desaturated + color. Subclassing colormaps get access to routines to plot a colorwheel (which should + ideally be located in the 50% luminance plane, which depends strongly on the underlying map), + a convenience function to interpolate color tuples and a function to return rgb triples for a + given vector. The :class:`~.Colormap3D` class itself subclasses the matplotlib base colormap. + + """ + + _log = logging.getLogger(__name__ + '.Colormap3D') + + def rgb_from_vector(self, vector): + """Construct a hls tuple from three coordinates representing a 3D direction. + + Parameters + ---------- + vector: tuple (N=3) or :class:`~numpy.ndarray` + Vector containing the x, y and z component, or a numpy array encompassing the + components as three lists.z-coordinate of the desired direction to encode. + + Returns + ------- + rgb: :class:`~numpy.ndarray` + Numpy array containing the calculated color tuples. + + """ + self._log.debug('Calling rgb_from_vector') + x, y, z = np.asarray(vector) + # Calculate spherical coordinates: + r = np.sqrt(x ** 2 + y ** 2 + z ** 2) + phi = np.asarray(np.arctan2(y, x)) + phi[phi < 0] += 2 * np.pi + theta = np.arccos(z / (r + 1E-30)) + # Calculate color deterministics: + hue = phi / (2 * np.pi) + lum = 1 - theta / np.pi + sat = r / r.max() + # Calculate RGB from hue with colormap: + rgba = np.asarray(self(hue)) + r, g, b = rgba[..., 0], rgba[..., 1], rgba[..., 2] + # Interpolate saturation: + r, g, b = interpolate_color(sat, (0.5, 0.5, 0.5), np.stack((r, g, b), axis=-1)) + # Interpolate luminance: + lum_target = np.where(lum < 0.5, 0, 1) + lum_target = np.stack([lum_target] * 3, axis=-1) + fraction = np.where(lum < 0.5, 1 - 2 * lum, 2 * (lum - 0.5)) + r, g, b = interpolate_color(fraction, np.stack((r, g, b), axis=-1), lum_target) + # Return RGB: + return np.asarray(255 * np.stack((r, g, b), axis=-1), dtype=np.uint8) + + def make_colorwheel(self, size=256, alpha=1): + self._log.debug('Calling make_colorwheel') + # Construct the colorwheel: + yy, xx = (np.indices((size, size)) - size/2 + 0.5) + rr = np.hypot(xx, yy) + xx = np.where(rr <= size/2-2, xx, 0) + yy = np.where(rr <= size/2-2, yy, 0) + zz = np.where(rr <= size/2-2, 0, -1) # color inside, black outside + aa = np.where(rr >= size/2-2, 255*alpha, 255).astype(dtype=np.uint8) + rgba = np.dstack((self.rgb_from_vector(np.asarray((xx, yy, zz))), aa)) + # Create color wheel: + return Image.fromarray(rgba) + + def plot_colorwheel(self, axis=None, size=512, alpha=1, arrows=False, figsize=(4, 4), + **kwargs): + """Display a color wheel to illustrate the color coding of vector gradient directions. + + Parameters + ---------- + figsize : tuple of floats (N=2) + Size of the plot figure. + + Returns + ------- + None + + """ + self._log.debug('Calling plot_colorwheel') + # Construct the colorwheel: + color_wheel = self.make_colorwheel(size=size, alpha=alpha) + # Plot the color wheel: + if axis is None: + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1, aspect='equal') + axis.imshow(color_wheel, origin='lower', **kwargs) + axis.add_artist(Circle(xy=(size/2-0.5, size/2-0.5), radius=size/2-2, linewidth=2, + edgecolor='k', facecolor='none')) + if arrows: + plt.tick_params(axis='both', which='both', labelleft='off', labelbottom='off', + left='off', right='off', top='off', bottom='off') + axis.arrow(size/2, size/2, 0, 0.15*size, head_width=9, head_length=20, + fc='k', ec='k', lw=1, width=2) + axis.arrow(size/2, size/2, 0, -0.15*size, head_width=9, head_length=20, + fc='k', ec='k', lw=1, width=2) + axis.arrow(size/2, size/2, 0.15*size, 0, head_width=9, head_length=20, + fc='k', ec='k', lw=1, width=2) + axis.arrow(size/2, size/2, -0.15*size, 0, head_width=9, head_length=20, + fc='k', ec='k', lw=1, width=2) + # Return axis: + axis.xaxis.set_visible(False) + axis.yaxis.set_visible(False) + return axis + + +class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D): + """A full implementation of Dave Green's "cubehelix" for Matplotlib. + + Based on the FORTRAN 77 code provided in D.A. Green, 2011, BASI, 39, 289. + http://adsabs.harvard.edu/abs/2011arXiv1108.5083G + Also see: + http://www.mrao.cam.ac.uk/~dag/CUBEHELIX/ + http://davidjohnstone.net/pages/cubehelix-gradient-picker + User can adjust all parameters of the cubehelix algorithm. This enables much greater + flexibility in choosing color maps. Default color map settings produce the standard cubehelix. + Create color map in only blues by setting rot=0 and start=0. Create reverse (white to black) + backwards through the rainbow once by setting rot=1 and reverse=True, etc. Furthermore, the + algorithm was tuned, so that constant luminance values can be used (e.g. to create a truly + isoluminant colorwheel). The `rot` parameter is also tuned to hold true for these cases. + Of the here presented colorwheels, only this one manages to solely navigate through the L*=50 + plane, which can be seen here: + https://upload.wikimedia.org/wikipedia/commons/2/21/Lab_color_space.png + + Parameters + ---------- + start : scalar, optional + Sets the starting position in the color space. 0=blue, 1=red, + 2=green. Defaults to 0.5. + rot : scalar, optional + The number of rotations through the rainbow. Can be positive + or negative, indicating direction of rainbow. Negative values + correspond to Blue->Red direction. Defaults to -1.5. + gamma : scalar, optional + The gamma correction for intensity. Defaults to 1.0. + reverse : boolean, optional + Set to True to reverse the color map. Will go from black to + white. Good for density plots where shade~density. Defaults to False. + nlev : scalar, optional + Defines the number of discrete levels to render colors at. + Defaults to 256. + sat : scalar, optional + The saturation intensity factor. Defaults to 1.2 + NOTE: this was formerly known as `hue` parameter + minSat : scalar, optional + Sets the minimum-level saturation. Defaults to 1.2. + maxSat : scalar, optional + Sets the maximum-level saturation. Defaults to 1.2. + startHue : scalar, optional + Sets the starting color, ranging from [0, 360], as in + D3 version by @mbostock. + NOTE: overrides values in start parameter. + endHue : scalar, optional + Sets the ending color, ranging from [0, 360], as in + D3 version by @mbostock + NOTE: overrides values in rot parameter. + minLight : scalar, optional + Sets the minimum lightness value. Defaults to 0. + maxLight : scalar, optional + Sets the maximum lightness value. Defaults to 1. + + Returns + ------- + matplotlib.colors.LinearSegmentedColormap object + + Revisions + --------- + 2014-04 (@jradavenport) Ported from IDL version + 2014-04 (@jradavenport) Added kwargs to enable similar to D3 version, + changed name of `hue` parameter to `sat`. + 2016-11 (@jan.caron) Added support for isoluminant cubehelices while making sure + `rot` works as intended. Decoded the plane-vectors a bit. + """ + + _log = logging.getLogger(__name__ + '.ColormapCubehelix') + + def __init__(self, start=0.5, rot=-1.5, gamma=1.0, reverse=False, nlev=256, + minSat=1.2, maxSat=1.2, minLight=0., maxLight=1., **kwargs): + self._log.debug('Calling __init__') + # Override start and rot if startHue and endHue are set: + if kwargs is not None: + if 'startHue' in kwargs: + start = (kwargs.get('startHue') / 360. - 1.) * 3. + if 'endHue' in kwargs: + rot = kwargs.get('endHue') / 360. - start / 3. - 1. + if 'sat' in kwargs: + minSat = kwargs.get('sat') + maxSat = kwargs.get('sat') + self.nlev = nlev + # Set up the parameters: + self.fract = np.linspace(minLight, maxLight, nlev) + angle = 2.0 * np.pi * (start / 3.0 + rot * np.linspace(0, 1, nlev)) + self.fract = self.fract**gamma + satar = np.linspace(minSat, maxSat, nlev) + amp = np.asarray(satar * self.fract * (1. - self.fract) / 2) + # Set RGB color coefficients (Luma is calculated in RGB Rec.601, so choose those), + # the original version of Dave green used (0.30, 0.59, 0.11) and REc.709 is + # c709 = (0.2126, 0.7152, 0.0722) but would not produce correct YPbPr Luma. + c601 = (0.299, 0.587, 0.114) + cr, cg, cb = c601 + cw = -0.90649 # Chosen to comply with Dave Greens implementation. + k = -1.6158 / cr / cw # k has to balance out cw so nothing gets out of RGB gamut (> 1). + # Calculate the vectors v and w spanning the plane of constant perceived intensity. + # v and w have to solve v x w = k(cr, cg, cb) (normal vector of the described plane) and + # v * w = 0 (scalar product, v and w have to be perpendicular). + # 6 unknown and 4 equations --> Chose wb = 0 and wg = cw (constant). + v = np.array((k * cr ** 2 * cb / (cw * (cr ** 2 + cg ** 2)), + k * cr * cg * cb / (cw * (cr ** 2 + cg ** 2)), -k * cr / cw)) + w = np.array((-cw * cg / cr, cw, 0)) + # Calculate components: + self.red = self.fract + amp * (v[0] * np.cos(angle) + w[0] * np.sin(angle)) + self.grn = self.fract + amp * (v[1] * np.cos(angle) + w[1] * np.sin(angle)) + self.blu = self.fract + amp * (v[2] * np.cos(angle) + w[2] * np.sin(angle)) + # Original formulas with original v and w: + # self.red = self.fract + amp * (-0.14861 * np.cos(angle) + 1.78277 * np.sin(angle)) + # self.grn = self.fract + amp * (-0.29227 * np.cos(angle) - 0.90649 * np.sin(angle)) + # self.blu = self.fract + amp * (1.97294 * np.cos(angle)) + # Find where RBG are outside the range [0,1], clip: + self.red = np.clip(self.red, 0, 1) + self.grn = np.clip(self.grn, 0, 1) + self.blu = np.clip(self.blu, 0, 1) + # Optional color reverse: + if reverse is True: + self.red = self.red[::-1] + self.blu = self.blu[::-1] + self.grn = self.grn[::-1] + # Put in to tuple & dictionary structures needed: + rr, bb, gg = [], [], [] + for k in range(0, int(nlev)): + rr.append((float(k) / (nlev - 1), self.red[k], self.red[k])) + bb.append((float(k) / (nlev - 1), self.blu[k], self.blu[k])) + gg.append((float(k) / (nlev - 1), self.grn[k], self.grn[k])) + cdict = {'red': rr, 'blue': bb, 'green': gg} + super().__init__('cubehelix', cdict, N=256) + self._log.debug('Created ' + str(self)) + + def plot_helix(self, figsize=(8, 8)): + """Display the RGB and luminance plots for the chosen cubehelix. + + Parameters + ---------- + figsize : tuple of floats (N=2) + Size of the plot figure. + + Returns + ------- + None + + """ + self._log.debug('Calling plot_helix') + plt.figure(figsize=figsize) + gs = gridspec.GridSpec(2, 1, height_ratios=[8, 1]) + # Main plot: + axis = plt.subplot(gs[0]) + axis.plot(self.fract, 'k', linewidth=2) + axis.plot(self.red, 'r', linewidth=2) + axis.plot(self.grn, 'g', linewidth=2) + axis.plot(self.blu, 'b', linewidth=2) + axis.set_xlim(0, self.nlev) + axis.set_ylim(0, 1) + axis.set_title('Cubehelix', fontsize=18) + axis.set_xlabel('color index', fontsize=15) + axis.set_ylabel('lightness / rgb', fontsize=15) + # Colorbar horizontal: + caxis = plt.subplot(gs[1], sharex=axis) + rgb = self(np.linspace(0, 1, 256))[None, ...] + rgb = np.asarray(255.9999 * rgb, dtype=np.uint8) + rgb = np.repeat(rgb, 30, axis=0) + im = Image.fromarray(rgb) + caxis.imshow(im) + plt.tick_params(axis='both', which='both', labelleft='off', labelbottom='off', + left='off', right='off', top='on', bottom='on') + + +class ColormapPerception(colors.LinearSegmentedColormap, Colormap3D): + """A perceptual colormap based on face-based luminance matching. + + Based on a publication by Kindlmann et. al. + http://www.cs.utah.edu/~gk/papers/vis02/FaceLumin.pdf + This colormap tries to achieve an isoluminant perception by using a list of colors acquired + through face recognition studies. It is a lot better than the HLS colormap, but still not + completely isoluminant (despite its name). Also it appears a bit dark. + + """ + + _log = logging.getLogger(__name__ + '.ColormapPerception') + + CDICT = {'red': [(0/6, 0.847, 0.847), + (1/6, 0.527, 0.527), + (2/6, 0.000, 0.000), + (3/6, 0.000, 0.000), + (4/6, 0.316, 0.316), + (5/6, 0.718, 0.718), + (6/6, 0.847, 0.847)], + + 'green': [(0/6, 0.057, 0.057), + (1/6, 0.527, 0.527), + (2/6, 0.592, 0.592), + (3/6, 0.559, 0.559), + (4/6, 0.316, 0.316), + (5/6, 0.000, 0.000), + (6/6, 0.057, 0.057)], + + 'blue': [(0/6, 0.057, 0.057), + (1/6, 0.000, 0.000), + (2/6, 0.000, 0.000), + (3/6, 0.559, 0.559), + (4/6, 0.991, 0.991), + (5/6, 0.718, 0.718), + (6/6, 0.057, 0.057)]} + + def __init__(self): + self._log.debug('Calling __init__') + super().__init__('perception', self.CDICT, N=256) + self._log.debug('Created ' + str(self)) + + +class ColormapHLS(colors.ListedColormap, Colormap3D): + """Colormap subclass for encoding directions with colors. + + This class is a subclass of the :class:`~matplotlib.pyplot.colors.ListedColormap` + class. The class follows the HSL ('hue', 'saturation', 'lightness') 'Double Hexcone' Model + with the saturation always set to 1 (moving on the surface of the color + cylinder) with a lightness of 0.5 (full color). The three prime colors (`rgb`) are spaced + equidistant with 120° space in between, according to a triadic arrangement. + Even though the lightness is constant in the plane, the luminance (which is a weighted sum + of the RGB components which encompasses human perception) is not, which can lead to + artifacts like reliefs. Converting the map to a grayscale show spokes at the secondary colors. + For more information see: + https://vis4.net/blog/posts/avoid-equidistant-hsv-colors/ + http://www.workwithcolor.com/color-luminance-2233.htm + http://blog.asmartbear.com/color-wheels.html + + """ + + _log = logging.getLogger(__name__ + '.ColormapHLS') + + def __init__(self): + self._log.debug('Calling __init__') + h = np.linspace(0, 1, 256) + l = 0.5 * np.ones_like(h) + s = np.ones_like(h) + r, g, b = np.vectorize(colorsys.hls_to_rgb)(h, l, s) + colors = [(r[i], g[i], b[i]) for i in range(len(r))] + super().__init__(colors, 'hls', N=256) + self._log.debug('Created ' + str(self)) + + +class ColormapClassic(colors.LinearSegmentedColormap, Colormap3D): + """Colormap subclass for encoding directions with colors. + + This class is a subclass of the :class:`~matplotlib.pyplot.colors.LinearSegmentedColormap` + class. The class follows the HSL ('hue', 'saturation', 'lightness') 'Double + Hexcone' Model with the saturation always set to 1 (moving on the surface of the color + cylinder) with a luminance of 0.5 (full color). The colors follow a tetradic arrangement with + four colors (red, green, blue and yellow) arranged with 90° spacing in between. + + """ + + _log = logging.getLogger(__name__ + '.ColormapClassic') + + CDICT = {'red': [(0.00, 1.0, 1.0), + (0.25, 0.0, 0.0), + (0.50, 0.0, 0.0), + (0.75, 1.0, 1.0), + (1.00, 1.0, 1.0)], + + 'green': [(0.00, 0.0, 0.0), + (0.25, 0.0, 0.0), + (0.50, 1.0, 1.0), + (0.75, 1.0, 1.0), + (1.00, 0.0, 0.0)], + + 'blue': [(0.00, 0.0, 0.0), + (0.25, 1.0, 1.0), + (0.50, 0.0, 0.0), + (0.75, 0.0, 0.0), + (1.00, 0.0, 0.0)]} + + def __init__(self): + self._log.debug('Calling __init__') + super().__init__('classic', self.CDICT, N=256) + self._log.debug('Created ' + str(self)) + + +class ColormapTransparent(colors.LinearSegmentedColormap): + """Colormap subclass for including transparency. + + This class is a subclass of the :class:`~matplotlib.pyplot.colors.LinearSegmentedColormap` + class with integrated support for transparency. The colormap is unicolor and varies only in + transparency. + + Attributes + ---------- + r: float, optional + Intensity of red in the colormap. Has to be between 0. and 1. + g: float, optional + Intensity of green in the colormap. Has to be between 0. and 1. + b: float, optional + Intensity of blue in the colormap. Has to be between 0. and 1. + alpha_range : list (N=2) of float, optional + Start and end alpha value. Has to be between 0. and 1. + + """ + + _log = logging.getLogger(__name__ + '.ColormapTransparent') + + def __init__(self, r=0., g=0., b=0., alpha_range=None): + self._log.debug('Calling __init__') + if alpha_range is None: + alpha_range = [0., 1.] + red = [(0., 0., r), (1., r, 1.)] + green = [(0., 0., g), (1., g, 1.)] + blue = [(0., 0., b), (1., b, 1.)] + alpha = [(0., 0., alpha_range[0]), (1., alpha_range[1], 1.)] + cdict = {'red': red, 'green': green, 'blue': blue, 'alpha': alpha} + super().__init__('transparent', cdict, N=256) + self._log.debug('Created ' + str(self)) + + +class ColorspaceCIELab(object): # TODO: Superclass? + """Class representing the CIELab colorspace.""" + + _log = logging.getLogger(__name__ + '.ColorspaceCIELab') + + def __init__(self, dim=(500, 500), extent=(-100, 100, -100, 100), cut_gamut=False, clip=True): + self._log.debug('Calling __init__') + self.dim = dim + self.extent = extent + self.cut_out_gamut = cut_gamut + self.clip = clip + self._log.debug('Created ' + str(self)) + + def plot(self, L=53.4, axis=None, figsize=(8, 8)): + self._log.debug('Calling plot') + dim, ext = self.dim, self.extent + # Create Lab colorspace: + a = np.linspace(ext[0], ext[1], dim[1]) + b = np.linspace(ext[2], ext[3], dim[0]) + aa, bb = np.meshgrid(a, b) + LL = np.full(dim, L, dtype=int) + Lab = np.stack((LL, aa, bb), axis=-1) + # Convert to XYZ colorspace: + XYZ = skcolor.lab2xyz(Lab) + # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: + rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) + # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: + mask = rgb > 0.0031308 + rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 + rgb[~mask] *= 12.92 + # Determine gamut: + gamut = np.logical_or(rgb < 0, rgb > 1) + gamut = np.sum(gamut, axis=-1, dtype=bool) + gamut_mask = np.stack((gamut, gamut, gamut), axis=-1) + # Cut out gamut (set out of bound colors to gray) if necessary: + if self.cut_out_gamut: + rgb[gamut_mask] = 0.5 + # Clip out of gamut colors: + if self.clip: + rgb[rgb < 0] = 0 + rgb[rgb > 1] = 1 + # Plot colorspace: + if axis is None: + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1, aspect='equal') + axis.imshow(rgb, origin='lower', interpolation='none', extent=(0, dim[0], 0, dim[1])) + axis.contour(gamut, levels=[0], colors='k', linewidths=1.5) + axis.set_xlabel('a', fontsize=15) + axis.set_ylabel('b', fontsize=15) + axis.set_title('CIELab (L = {:g})'.format(L), fontsize=18) + fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) + axis.xaxis.set_major_formatter(fx) + fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) + axis.yaxis.set_major_formatter(fy) + axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + + def plot_colormap(self, cmap, N=256, L='auto', figsize=None, cbar_lim=None, brightness=True, + input_rec=None): + self._log.debug('Calling plot_colormap') + dim, ext = self.dim, self.extent + # Calculate rgb values: + rgb = cmap(np.linspace(0, 1, N))[None, :, :3] # These are R'G'B' values! + if input_rec == 601: + rgb = RGBConverter('Rec601', 'Rec709')(rgb) + # Convert to Lab space: + Lab = np.squeeze(skcolor.rgb2lab(rgb)) + LL, aa, bb = Lab.T + aa = (aa - ext[0]) / (ext[1] - ext[0]) * dim[1] + bb = (bb - ext[2]) / (ext[3] - ext[2]) * dim[0] + # Determine number of images / luma levels: + LL_min, LL_max = np.round(np.min(LL), 1), np.round(np.max(LL), 1) + if L == 'auto': + if LL_max - LL_min < 0.1: # Just one image: + L = LL_min + else: # Two images: + L = np.asarray((LL_max, np.mean(LL), LL_min)) + L_list = np.atleast_1d(L) + # Determine colorbar limits: + if cbar_lim is not None: # Overwrite limits! + LL_min, LL_max = cbar_lim + elif not brightness or LL_max - LL_min < 0.1: # Just one value, full range for colormap: + LL_min, LL_max = 0, 1 + # Creat grid: + if figsize is None: + figsize = (len(L_list) * 5 + 2, 7) + fig = plt.figure(figsize=figsize) + grid = ImageGrid(fig, 111, nrows_ncols=(1, len(L_list)), axes_pad=0.4, share_all=False, + cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25) + # Plot: + if brightness: + c = LL + cmap = 'gray' + else: + c = np.linspace(0, 1, N) + for i, axis in enumerate(grid): + self.plot(L=L_list[i], axis=axis) + im = axis.scatter(aa, bb, c=c, cmap=cmap, edgecolors='none', + vmin=LL_min, vmax=LL_max) + axis.set_xlim(0, self.dim[1]) + axis.set_ylim(0, self.dim[0]) + axis.cax.colorbar(im, ticks=np.linspace(LL_min, LL_max, 9)) + + def plot3d(self, N=9): + self._log.debug('Calling plot3d') + dim, ext = self.dim, self.extent + # Create Lab colorspace: + a = np.linspace(ext[0], ext[1], dim[1]) + b = np.linspace(ext[2], ext[3], dim[0]) + aa, bb = np.meshgrid(a, b) + import visvis # TODO: If VisPy is ever ready, switch every plot to that! + for i in range(1, N): + LL = np.full(dim, i * 100 / N, dtype=int) + Lab = np.stack((LL, aa, bb), axis=-1) + # Convert to XYZ colorspace: + XYZ = skcolor.lab2xyz(Lab) + # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: + rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) + # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: + mask = rgb > 0.0031308 + rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 + rgb[~mask] *= 12.92 + # Determine gamut: + gamut = np.logical_or(rgb < 0, rgb > 1) + gamut = np.sum(gamut, axis=-1, dtype=bool) + # Alpha: + alpha = 1. + a = np.full(dim + (1,), alpha) + a *= np.logical_not(gamut[..., None]) + rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8) + # Visvis plot: + obj = visvis.functions.surf(aa, bb, i * 100. / N * np.ones_like(aa), rgba, aa=0) + obj.parent.light0.ambient = 1. + obj.parent.light0.diffuse = 0. + + +class ColorspaceCIELuv(object): + """Class representing the CIELuv colorspace.""" + + _log = logging.getLogger(__name__ + '.ColorspaceCIELuv') + + def __init__(self, dim=(500, 500), extent=(-100, 100, -100, 100), cut_gamut=False, clip=True): + self._log.debug('Calling __init__') + self.dim = dim + self.extent = extent + self.cut_out_gamut = cut_gamut + self.clip = clip + self._log.debug('Created ' + str(self)) + + def plot(self, L=53.4, axis=None, figsize=(8, 8)): + self._log.debug('Calling plot') + dim, ext = self.dim, self.extent + # Create Lab colorspace: + u = np.linspace(ext[0], ext[1], dim[1]) + v = np.linspace(ext[2], ext[3], dim[0]) + uu, vv = np.meshgrid(u, v) + LL = np.full(dim, L, dtype=int) + Luv = np.stack((LL, uu, vv), axis=-1) + # Convert to XYZ colorspace: + XYZ = skcolor.luv2xyz(Luv) + # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: + rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) + # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: + mask = rgb > 0.0031308 + rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 + rgb[~mask] *= 12.92 + # Determine gamut: + gamut = np.logical_or(rgb < 0, rgb > 1) + gamut = np.sum(gamut, axis=-1, dtype=bool) + gamut_mask = np.stack((gamut, gamut, gamut), axis=-1) + # Cut out gamut (set out of bound colors to gray) if necessary: + if self.cut_out_gamut: + rgb[gamut_mask] = 0.5 + # Clip out of gamut colors: + if self.clip: + rgb[rgb < 0] = 0 + rgb[rgb > 1] = 1 + # Plot colorspace: + if axis is None: + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1, aspect='equal') + axis.imshow(rgb, origin='lower', interpolation='none', extent=(0, dim[0], 0, dim[1])) + axis.contour(gamut, levels=[0], colors='k', linewidths=1.5) + axis.set_xlabel('u', fontsize=15) + axis.set_ylabel('v', fontsize=15) + axis.set_title('CIELuv (L = {:g})'.format(L), fontsize=18) + fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) + axis.xaxis.set_major_formatter(fx) + fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) + axis.yaxis.set_major_formatter(fy) + axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + + def plot_colormap(self, cmap, N=256, L='auto', figsize=None, cbar_lim=None, brightness=True, + input_rec=None): + self._log.debug('Calling plot_colormap') + dim, ext = self.dim, self.extent + # Calculate rgb values: + rgb = cmap(np.linspace(0, 1, N))[None, :, :3] + if input_rec == 601: + rgb = RGBConverter('Rec601', 'Rec709')(rgb) + # Convert to Lab space: + Luv = np.squeeze(skcolor.rgb2luv(rgb)) + LL, uu, vv = Luv.T + uu = (uu - ext[0]) / (ext[1] - ext[0]) * dim[1] + vv = (vv - ext[2]) / (ext[3] - ext[2]) * dim[0] + # Determine number of images / luma levels: + LL_min, LL_max = np.round(np.min(LL), 1), np.round(np.max(LL), 1) + if L == 'auto': + if LL_max - LL_min < 0.1: # Just one image: + L = LL_min + else: # Two images: + L = np.asarray((LL_max, np.mean(LL), LL_min)) + L_list = np.atleast_1d(L) + # Determine colorbar limits: + if cbar_lim is not None: # Overwrite limits! + LL_min, LL_max = cbar_lim + elif not brightness or LL_max - LL_min < 0.1: # Just one value, full range for colormap: + LL_min, LL_max = 0, 1 + # Creat grid: + if figsize is None: + figsize = (len(L_list) * 5 + 2, 7) + fig = plt.figure(figsize=figsize) + grid = ImageGrid(fig, 111, nrows_ncols=(1, len(L_list)), axes_pad=0.4, share_all=False, + cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25) + # Plot: + if brightness: + c = LL + cmap = 'gray' + else: + c = np.linspace(0, 1, N) + for i, axis in enumerate(grid): + self.plot(L=L_list[i], axis=axis) + im = axis.scatter(uu, vv, c=c, cmap=cmap, edgecolors='none', + vmin=LL_min, vmax=LL_max) + axis.set_xlim(0, self.dim[1]) + axis.set_ylim(0, self.dim[0]) + axis.cax.colorbar(im, ticks=np.linspace(LL_min, LL_max, 9)) + + def plot3d(self, N=9): + self._log.debug('Calling plot3d') + dim, ext = self.dim, self.extent + # Create Lab colorspace: + u = np.linspace(ext[0], ext[1], dim[1]) + v = np.linspace(ext[2], ext[3], dim[0]) + uu, vv = np.meshgrid(u, v) + import visvis + for i in range(1, N): + LL = np.full(dim, i * 100 / N, dtype=int) + Luv = np.stack((LL, uu, vv), axis=-1) + # Convert to XYZ colorspace: + XYZ = skcolor.luv2xyz(Luv) + # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: + rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) + # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: + mask = rgb > 0.0031308 + rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 + rgb[~mask] *= 12.92 + # Determine gamut: + gamut = np.logical_or(rgb < 0, rgb > 1) + gamut = np.sum(gamut, axis=-1, dtype=bool) + # Alpha: + alpha = 1. + a = np.full(dim + (1,), alpha) + a *= np.logical_not(gamut[..., None]) + rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8) + # Visvis plot: + obj = visvis.functions.surf(uu, vv, i * 100. / N * np.ones_like(uu), rgba, aa=0) + obj.parent.light0.ambient = 1. + obj.parent.light0.diffuse = 0. + + +class ColorspaceCIExyY(object): + """Class representing the CIExyY colorspace.""" + + _log = logging.getLogger(__name__ + '.ColorspaceCIExyY') + + def __init__(self, dim=(500, 500), extent=(0, 0.8, 0, 0.8), cut_gamut=False, clip=True): + self._log.debug('Calling __init__') + self.dim = dim + self.extent = extent + self.cut_out_gamut = cut_gamut + self.clip = clip + self._log.debug('Created ' + str(self)) + + def plot(self, Y=0.214, axis=None, figsize=(8, 8)): + self._log.debug('Calling plot') + dim, ext = self.dim, self.extent + # Create Lab colorspace: + x = np.linspace(ext[0], ext[1], dim[1]) + y = np.linspace(ext[2], ext[3], dim[0]) + xx, yy = np.meshgrid(x, y) + YY = np.full(dim, Y) + # Convert to XYZ: + XX = YY / (yy + 1e-30) * xx + ZZ = YY / (yy + 1e-30) * (1 - xx - yy) + XYZ = np.stack((XX, YY, ZZ), axis=-1) + # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: + rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) + # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: + mask = rgb > 0.0031308 + rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 + rgb[~mask] *= 12.92 + # Determine gamut: + gamut = np.logical_or(rgb < 0, rgb > 1) + gamut = np.sum(gamut, axis=-1, dtype=bool) + gamut_mask = np.stack((gamut, gamut, gamut), axis=-1) + # Cut out gamut (set out of bound colors to gray) if necessary: + if self.cut_out_gamut: + rgb[gamut_mask] = 0.5 + # Clip out of gamut colors: + if self.clip: + rgb[rgb < 0] = 0 + rgb[rgb > 1] = 1 + # Plot colorspace: + if axis is None: + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1, aspect='equal') + axis.imshow(rgb, origin='lower', interpolation='none', extent=(0, dim[0], 0, dim[1])) + axis.contour(gamut, levels=[0], colors='k', linewidths=1.5) + axis.set_xlabel('x', fontsize=15) + axis.set_ylabel('y', fontsize=15) + axis.set_title('CIExyY (Y = {:g})'.format(Y), fontsize=18) + fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) + axis.xaxis.set_major_formatter(fx) + fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) + axis.yaxis.set_major_formatter(fy) + axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + + def plot_colormap(self, cmap, N=256, Y='auto', figsize=None, cbar_lim=None, brightness=True, + input_rec=None): + self._log.debug('Calling plot_colormap') + dim, ext = self.dim, self.extent + # Calculate rgb values: + rgb = cmap(np.linspace(0, 1, N))[None, :, :3] + if input_rec == 601: + rgb = RGBConverter('Rec601', 'Rec709')(rgb) + # Convert to XYZ space: + XYZ = np.squeeze(skcolor.rgb2xyz(rgb)) + XX, YY, ZZ = XYZ.T + # Convert to xyY space: + xx = XX / (XX + YY + ZZ) + yy = YY / (XX + YY + ZZ) + xx = (xx - ext[0]) / (ext[1] - ext[0]) * dim[1] + yy = (yy - ext[2]) / (ext[3] - ext[2]) * dim[0] + # Determine number of images / luma levels: + YY_min, YY_max = np.round(np.min(YY), 2), np.round(np.max(YY), 2) + if Y == 'auto': + if YY_max - YY_min < 0.01: # Just one image: + Y = YY_min + else: # Two images: + Y = np.asarray((YY_max, np.mean(YY), YY_min)) + Y_list = np.atleast_1d(Y) + # Determine colorbar limits: + if cbar_lim is not None: # Overwrite limits! + YY_min, YY_max = cbar_lim + elif not brightness or YY_max - YY_min < 0.01: # Just one value, full range for colormap: + YY_min, YY_max = 0, 1 + # Creat grid: + if figsize is None: + figsize = (len(Y_list) * 5 + 2, 7) + fig = plt.figure(figsize=figsize) + grid = ImageGrid(fig, 111, nrows_ncols=(1, len(Y_list)), axes_pad=0.4, share_all=False, + cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25) + # Plot: + if brightness: + c = YY + cmap = 'gray' + else: + c = np.linspace(0, 1, N) + for i, axis in enumerate(grid): + self.plot(Y=Y_list[i], axis=axis) + im = axis.scatter(xx, yy, c=c, cmap=cmap, edgecolors='none', + vmin=YY_min, vmax=YY_max) + axis.set_xlim(0, self.dim[1]) + axis.set_ylim(0, self.dim[0]) + axis.cax.colorbar(im, ticks=np.linspace(YY_min, YY_max, 9)) + + def plot3d(self, N=9): + self._log.debug('Calling plot3d') + dim, ext = self.dim, self.extent + # Create Lab colorspace: + x = np.linspace(ext[0], ext[1], dim[1]) + y = np.linspace(ext[2], ext[3], dim[0]) + xx, yy = np.meshgrid(x, y) + import visvis + for i in range(1, N): + YY = np.full(dim, i * 1. / N) + # Convert to XYZ: + XX = YY / (yy + 1e-30) * xx + ZZ = YY / (yy + 1e-30) * (1 - xx - yy) + XYZ = np.stack((XX, YY, ZZ), axis=-1) + # Convert to RGB colorspace following algorithm from http://www.easyrgb.com/index.php: + rgb = skcolor.colorconv._convert(skcolor.colorconv.rgb_from_xyz, XYZ) + # Gamma correction (gamma encoding) rgb are now nonlinear (R'G'B')!: + mask = rgb > 0.0031308 + rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055 + rgb[~mask] *= 12.92 + # Determine gamut: + gamut = np.logical_or(rgb < 0, rgb > 1) + gamut = np.sum(gamut, axis=-1, dtype=bool) + # Alpha: + alpha = 1. + a = np.full(dim + (1,), alpha) + a *= np.logical_not(gamut[..., None]) + rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8) + # Visvis plot: + obj = visvis.functions.surf(xx, yy, i / N * np.ones_like(xx), rgba, aa=0) + obj.parent.light0.ambient = 1. + obj.parent.light0.diffuse = 0. + + +class ColorspaceYPbPr(object): + """Class representing the YPbPr colorspace.""" + + _log = logging.getLogger(__name__ + '.ColorspaceYPbPr') + + def __init__(self, dim=(500, 500), extent=(-0.8, 0.8, -0.8, 0.8), cut_gamut=False, clip=True): + self._log.debug('Calling __init__') + self.dim = dim + self.extent = extent + self.cut_out_gamut = cut_gamut + self.clip = clip + self._log.debug('Created ' + str(self)) + + def plot(self, Y=0.5, axis=None, figsize=(8, 8)): + self._log.debug('Calling plot') + dim, ext = self.dim, self.extent + # Create YPbPr colorspace: + pb = np.linspace(ext[0], ext[1], dim[1]) + pr = np.linspace(ext[2], ext[3], dim[0]) + ppb, ppr = np.meshgrid(pb, pr) + YY = np.full(dim, Y) # This is luma, not relative luminance (Y', not Y)! + # Convert to RGB colorspace (this is the nonlinear R'G'B' space!): + rr = YY + 1.402 * ppr + gg = YY - 0.344136 * ppb - 0.7141136 * ppr + bb = YY + 1.772 * ppb + rgb = np.stack((rr, gg, bb), axis=-1) + # Determine gamut: + gamut = np.logical_or(rgb < 0, rgb > 1) + gamut = np.sum(gamut, axis=-1, dtype=bool) + gamut_mask = np.stack((gamut, gamut, gamut), axis=-1) + # Cut out gamut (set out of bound colors to gray) if necessary: + if self.cut_out_gamut: + rgb[gamut_mask] = 0.5 + # Clip out of gamut colors: + if self.clip: + rgb[rgb < 0] = 0 + rgb[rgb > 1] = 1 + # Plot colorspace: + if axis is None: + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1, aspect='equal') + axis.imshow(rgb, origin='lower', interpolation='none', + extent=(0, dim[0], 0, dim[1])) + axis.contour(gamut, levels=[0], colors='k', linewidths=1.5) + axis.set_xlabel('Pb', fontsize=15) + axis.set_ylabel('Pr', fontsize=15) + axis.set_title("Y'PbPr (Y' = {:g})".format(Y), fontsize=18) + fx = FuncForm( + lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) + axis.xaxis.set_major_formatter(fx) + fy = FuncForm( + lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) + axis.yaxis.set_major_formatter(fy) + axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + + def plot_colormap(self, cmap, N=256, Y='auto', figsize=None, cbar_lim=None, brightness=True, + input_rec=None): + self._log.debug('Calling plot_colormap') + dim, ext = self.dim, self.extent + # Calculate rgb values: + rgb = cmap(np.linspace(0, 1, N))[None, :, :3] + if input_rec == 709: + rgb = RGBConverter('Rec709', 'Rec601')(rgb) + rr, gg, bb = rgb.T + # Convert to YPbPr space: + k_r, k_g, k_b = 0.299, 0.587, 0.114 # Constants Rec.601! + YY = k_r * rr + k_g * gg + k_b * bb + ppb = (bb - YY) / (2 * (1 - k_b)) + ppr = (rr - YY) / (2 * (1 - k_r)) + ppb = (ppb - ext[0]) / (ext[1] - ext[0]) * dim[1] + ppr = (ppr - ext[2]) / (ext[3] - ext[2]) * dim[0] + # Determine number of images / luma levels: + YY_min, YY_max = np.round(np.min(YY), 2), np.round(np.max(YY), 2) + if Y == 'auto': + if YY_max - YY_min < 0.01: # Just one image: + Y = YY_min + else: # Two images: + Y = np.asarray((YY_max, np.mean(YY), YY_min)) + Y_list = np.atleast_1d(Y) + # Determine colorbar limits: + if cbar_lim is not None: # Overwrite limits! + YY_min, YY_max = cbar_lim + elif not brightness or YY_max - YY_min < 0.01: # Just one value, full range for colormap: + YY_min, YY_max = 0, 1 + # Creat grid: + if figsize is None: + figsize = (len(Y_list) * 5 + 2, 7) + fig = plt.figure(figsize=figsize) + grid = ImageGrid(fig, 111, nrows_ncols=(1, len(Y_list)), axes_pad=0.4, share_all=False, + cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.25) + # Plot: + if brightness: + c = YY + cmap = 'gray' + else: + c = np.linspace(0, 1, N) + for i, axis in enumerate(grid): + self.plot(Y=Y_list[i], axis=axis) + im = axis.scatter(ppb, ppr, c=c, cmap=cmap, edgecolors='none', + vmin=YY_min, vmax=YY_max) + axis.set_xlim(0, self.dim[1]) + axis.set_ylim(0, self.dim[0]) + axis.cax.colorbar(im, ticks=np.linspace(YY_min, YY_max, 9)) + + def plot3d(self, N=9): + self._log.debug('Calling plot3d') + dim, ext = self.dim, self.extent + # Create YPbPr colorspace: + pb = np.linspace(ext[0], ext[1], dim[1]) + pr = np.linspace(ext[2], ext[3], dim[0]) + ppb, ppr = np.meshgrid(pb, pr) + import visvis + for i in range(1, N): + YY = np.full(dim, i * 1. / N) # This is luma, not relative luminance (Y', not Y)! + # Convert to RGB colorspace (this is the nonlinear R'G'B' space!): + rr = YY + 1.402 * ppr + gg = YY - 0.344136 * ppb - 0.7141136 * ppr + bb = YY + 1.772 * ppb + rgb = np.stack((rr, gg, bb), axis=-1) + # Determine gamut: + gamut = np.logical_or(rgb < 0, rgb > 1) + gamut = np.sum(gamut, axis=-1, dtype=bool) + # Alpha: + alpha = 1. + a = np.full(dim + (1,), alpha) + a *= np.logical_not(gamut[..., None]) + rgba = np.asarray(255 * np.dstack((rgb, a)), dtype=np.uint8) + # Visvis plot: + obj = visvis.functions.surf(ppb, ppr, i / N * np.ones_like(ppb), rgba, aa=0) + obj.parent.light0.ambient = 1. + obj.parent.light0.diffuse = 0. + + +class RGBConverter(object): + """Class for the conversion of RGB values from one RGB space to another. + + Notes + ----- + This operates only on NONLINEAR R'G'B' values, normalised to a range of [0, 1]! + Convert from linear RGB values beforehand, if necessary! + + """ + + rgb601_to_ypbpr = np.array([[+0.299000, +0.587000, +0.114000], + [-0.168736, -0.331264, +0.500000], + [+0.500000, -0.418688, -0.081312]]) + ypbr_to_rgb709 = np.array([[1, +0.0000, +1.5701], + [1, -0.1870, -0.4664], + [1, +1.8556, +0.0000]]) + rgb601_to_rgb709 = ypbr_to_rgb709.dot(rgb601_to_ypbpr) + rgb709_to_rgb601 = np.linalg.inv(rgb601_to_rgb709) + + def __init__(self, source='Rec601', target='Rec709'): + if source == 'Rec601' and target == 'Rec709': + self.convert_matrix = self.rgb601_to_rgb709 + elif source == 'Rec709' and target == 'Rec601': + self.convert_matrix = self.rgb709_to_rgb601 + else: + raise KeyError('Conversion from {} to {} not found!'.format(source, target)) + + def __call__(self, rgb): + """Convert from one RGB space to another. + + Parameters + ---------- + rgb: :class:`~numpy.ndarray` + Numpy array containing the RGB source values (last dimension: 3). + + Returns + ------- + rgb_result: :class:`~numpy.ndarray` + The resulting RGB values in the target space. + + """ + rgb_out = rgb.reshape((-1, 3)).T + rgb_out = self.convert_matrix.dot(rgb_out) + return rgb_out.T.reshape(rgb.shape) + + +def interpolate_color(fraction, start, end): + """Interpolate linearly between two color tuples (e.g. RGB). + + Parameters + ---------- + fraction: float or :class:`~numpy.ndarray` + Interpolation fraction between 0 and 1, which determines the position of the + interpolation between `start` and `end`. + start: tuple (N=3) or :class:`~numpy.ndarray` + Start of the interpolation as a tuple of three numbers or a numpy array, where the last + dimension should have length 3 and contain the color tuples. + end: tuple (N=3) or :class:`~numpy.ndarray` + End of the interpolation as a tuple of three numbers or a numpy array, where the last + dimension should have length 3 and contain the color tuples. + + Returns + ------- + result: tuple (N=3) or :class:`~numpy.ndarray` + Result of the interpolation as a tuple of three numbers or a numpy array, where the + last dimension should has length 3 and contains the color tuples. + + """ + _log.debug('Calling interpolate_color') + start, end = np.asarray(start), np.asarray(end) + r1 = start[..., 0] + (end[..., 0] - start[..., 0]) * fraction + r2 = start[..., 1] + (end[..., 1] - start[..., 1]) * fraction + r3 = start[..., 2] + (end[..., 2] - start[..., 2]) * fraction + return r1, r2, r3 + + +def rgb_to_brightness(rgb, mode="Y'", input_rec=None): + + import colorspacious # TODO: Use for everything! + c = {601: [0.299, 0.587, 0.114], 709: [0.2125, 0.7154, 0.0721]} # Image.convert('L') uses 601! + if input_rec is None: # Not specified, use in all cases: + rgbp601 = rgb + rgbp709 = rgb + elif input_rec == 601: + rgbp601 = rgb + rgbp709 = RGBConverter('Rec601', 'Rec709')(rgb) + elif input_rec == 709: + rgbp601 = RGBConverter('Rec601', 'Rec709')(rgb) + rgbp709 = rgb + else: + raise KeyError('Input RGB type {} not understood!'.format(input_rec)) + if mode in ("Y'", 'Luma'): + rp601, gp601, bp601 = rgbp601.T + brightness = c[601][0] * rp601 + c[601][1] * gp601 + c[601][2] * bp601 + elif mode in ('Y', 'Luminance'): + rgb709 = colorspacious.cspace_converter('sRGB1', 'sRGB1-linear')(rgbp709) + r709, g709, b709 = rgb709.T + brightness = c[709][0] * r709 + c[709][1] * g709 + c[709][2] * b709 + elif mode in ('L*', 'LightnessLab'): + lab = colorspacious.cspace_converter('sRGB1', 'CIELab')(rgbp709) + brightness = lab[0, :, 0] + elif mode in ('I', 'Intensity', 'Average'): + brightness = np.mean(rgb, axis=-1) + elif mode in ('V', 'Value', 'Maximum'): + brightness = np.max(rgb, axis=-1) + elif mode in ('L', 'LightnessHSL'): + brightness = (np.max(rgb, axis=-1) + np.min(rgb, axis=-1)) / 2 + else: + raise KeyError('Brightness request {} not understood!'.format(mode)) + return brightness + + +def colormap_brightness_comparison(cmap, input_rec=None, figsize=(18, 8)): + + # Create R'G'B' values from colormap: + x = np.linspace(0, 1, 1000) + rgbp = cmap(x)[None, :, :3] + # Calculate different brightness measures: + luma = rgb_to_brightness(rgbp, mode="Y'", input_rec=input_rec) + luminance = rgb_to_brightness(rgbp, mode='Y', input_rec=input_rec) + lightness_lab = rgb_to_brightness(rgbp, mode='L*', input_rec=input_rec) + intensity = rgb_to_brightness(rgbp, mode='I', input_rec=input_rec) + value = rgb_to_brightness(rgbp, mode='V', input_rec=input_rec) + lightness_hls = rgb_to_brightness(rgbp, mode='L', input_rec=input_rec) + # Plot: + fig, grid = plt.subplots(2, 3, figsize=figsize) + plt.title(cmap.name) + axis = grid[0, 0] + axis.scatter(x, luma, c=x, cmap=cmap, s=200, linewidths=0.) + axis.axhline(y=0.5, color='k', ls='--') + axis.set_xlim(0, 1) + axis.set_ylim(0, 1) + axis.set_title("Luma $Y$ '") + axis = grid[0, 1] + axis.scatter(x, luminance, c=x, cmap=cmap, s=200, linewidths=0.) + axis.axhline(y=0.214, color='k', ls='--') + axis.set_xlim(0, 1) + axis.set_ylim(0, 1) + axis.set_title('Relative Luminance $Y$') + axis = grid[0, 2] + axis.scatter(x, lightness_lab, c=x, cmap=cmap, s=200, linewidths=0.) + axis.axhline(y=53.39, color='k', ls='--') + axis.set_xlim(0, 1) + axis.set_ylim(0, 100) + axis.set_title('Lightness $L^*$ (CIELab)') + axis = grid[1, 0] + axis.scatter(x, intensity, c=x, cmap=cmap, s=200, linewidths=0.) + axis.axhline(y=53.39, color='k', ls='--') + axis.set_xlim(0, 1) + axis.set_ylim(0, 1) + axis.set_title('Intensity $I$ (HSI Component Average)') + axis = grid[1, 1] + axis.scatter(x, value, c=x, cmap=cmap, s=200, linewidths=0.) + axis.axhline(y=53.39, color='k', ls='--') + axis.set_xlim(0, 1) + axis.set_ylim(0, 1) + axis.set_title('Value $V$ (HSV Component Maximum)') + axis = grid[1, 2] + axis.scatter(x, lightness_hls, c=x, cmap=cmap, s=200, linewidths=0.) + axis.axhline(y=53.39, color='k', ls='--') + axis.set_xlim(0, 1) + axis.set_ylim(0, 1) + axis.set_title('Lightness $L$ (HSL Min-Max-Average)') + + +cmaps = {'cubehelix_standard': ColormapCubehelix(), + 'cubehelix_reverse': ColormapCubehelix(reverse=True), + 'cubehelix_circular': ColormapCubehelix(start=1, rot=1, + minLight=0.5, maxLight=0.5, sat=2), + 'perception_circular': ColormapPerception(), + 'hls_circular': ColormapHLS(), + 'classic_circular': ColormapClassic(), + 'transparent_black': ColormapTransparent(0, 0, 0, [0, 1.]), + 'transparent_white': ColormapTransparent(1, 1, 1, [0, 1.]), + 'transparent_confidence': ColormapTransparent(0.2, 0.3, 0.2, [0.75, 0.])} + +CMAP_CIRCULAR_DEFAULT = cmaps['cubehelix_circular'] diff --git a/pyramid/diagnostics.py b/pyramid/diagnostics.py index 5039c314f529b87de984c289e639f81b55cb77e4..c714042635905a4380202548865382242e3c9bfa 100644 --- a/pyramid/diagnostics.py +++ b/pyramid/diagnostics.py @@ -1,398 +1,398 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""This module provides the :class:`~.Diagnostics` class for the calculation of diagnostics of a -specified costfunction for a fixed magnetization distribution.""" - -import logging - -from pyramid.fielddata import VectorData -from pyramid.phasemap import PhaseMap - -import matplotlib.pyplot as plt -from matplotlib import patches -from matplotlib import patheffects -from matplotlib.ticker import FuncFormatter -import numpy as np - -import jutil - -__all__ = ['Diagnostics', 'get_vector_field_errors'] - - -class Diagnostics(object): - """Class for calculating diagnostic properties of a specified costfunction. - - For the calculation of diagnostic properties, a costfunction and a magnetization distribution - are specified at construction. With the :func:`~.set_position`, a position in 3D space can be - set at which all properties will be calculated. Properties are saved via boolean flags and - thus, calculation is only done if the position has changed in between. The standard deviation - and the measurement contribution require the execution of a conjugate gradient solver and can - take a while for larger problems. - - Attributes - ---------- - x_rec: :class:`~numpy.ndarray` - Vectorized magnetization distribution at which the costfunction is evaluated. - cost: :class:`~.pyramid.costfunction.Costfunction` - Costfunction for which the diagnostics are calculated. - max_iter: int, optional - Maximum number of iterations. Default is 1000. - fwd_model: :class:`~pyramid.forwardmodel.ForwardModel` - Forward model used in the costfunction. - Se_inv : :class:`~numpy.ndarray` (N=2), optional - Inverted covariance matrix of the measurement errors. The matrix has size `NxN` with N - being the length of the targetvector y (vectorized phase map information). - dim: tuple (N=3) - Dimensions of the 3D magnetic distribution. - row_idx: int - Row index of the system matrix corresponding to the current position in 3D space. - - Notes - ----- - Some properties depend on others, which may require recalculations of these prior - properties if necessary. The dependencies are ('-->' == 'requires'): - avrg_kern_row --> gain_row --> std --> m_inv_row - measure_contribution is independant - - """ - - _log = logging.getLogger(__name__ + '.Diagnostics') - - @property - def cov_row(self): - """Row of the covariance matrix (``S_a^-1+F'(x_f)^T S_e^-1 F'(x_f)``) which is needed for - the calculation of the gain and averaging kernel matrizes and which ideally contains the - variance at position `row_idx` for the current component and position in 3D. - Note that the covariance matrix of the solution is symmetric (like all covariance - matrices) and thusly this property could also be called cov_col for column.""" - if not self._updated_cov_row: - e_i = np.zeros(self.cost.n, dtype=self.x_rec.dtype) - e_i[self.row_idx] = 1 - row = 2 * jutil.cg.conj_grad_solve(self._A, e_i, P=self._P, max_iter=self.max_iter, - verbose=self.verbose) - self._std_row = np.asarray(row) - self._updated_cov_row = True - return self._std_row - - @property - def std(self): - """Standard deviation of the chosen component at the current position (calculated when - needed).""" - return np.sqrt(self.cov_row[self.row_idx]) - - @property - def gain_row(self): - """Row of the gain matrix, which maps differences of phase measurements onto differences in - the retrieval result of the magnetization distribution(calculated when needed).""" - if not self._updated_gain_row: - self._gain_row = self.Se_inv.dot(self.fwd_model.jac_dot(self.x_rec, self.cov_row)) - self._updated_gain_row = True - return self._gain_row - - @property - def avrg_kern_row(self): - """Row of the averaging kernel matrix (which is ideally the identity matrix), which - describes the smoothing introduced by the regularization (calculated when needed).""" - if not self._updated_avrg_kern_row: - self._avrg_kern_row = self.fwd_model.jac_T_dot(self.x_rec, self.gain_row) - self._updated_avrg_kern_row = True - return self._avrg_kern_row - - @property - def measure_contribution(self): - """The sum over an averaging kernel matrix row, which is an indicator for wheter a point of - the solution is determined by the measurement (close to `1`) or by a priori information - (close to `0`).""" - if not self._updated_measure_contribution: - cache = self.fwd_model.jac_dot(self.x_rec, np.ones(self.cost.n, self.x_rec.dtype)) - cache = self.fwd_model.jac_T_dot(self.x_rec, self.Se_inv.dot(cache)) - mc = 2 * jutil.cg.conj_grad_solve(self._A, cache, P=self._P, max_iter=self.max_iter) - self._measure_contribution = mc - self._updated_measure_contribution = True - return self._measure_contribution - - @property - def pos(self): - """The current solution position, which specifies the 3D-point (and the component) of the - magnetization, for which diagnostics are calculated.""" - return self._pos - - @pos.setter - def pos(self, pos): - c, z, y, x = pos - assert self.mask[z, y, x], 'Position is outside of the provided mask!' - mask_vec = self.mask.ravel() - idx_3d = z * self.dim[1] * self.dim[2] + y * self.dim[2] + x - row_idx = c * np.prod(mask_vec.sum()) + mask_vec[:idx_3d].sum() - if row_idx != self.row_idx: - self._pos = pos - self.row_idx = row_idx - self._updated_cov_row = False - self._updated_gain_row = False - self._updated_avrg_kern_row = False - self._updated_measure_contribution = False - - def __init__(self, magdata, cost, max_iter=1000, verbose=False): - self._log.debug('Calling __init__') - self.magdata = magdata - self.cost = cost - self.max_iter = max_iter - self.verbose = verbose - self.fwd_model = cost.fwd_model - self.Se_inv = cost.Se_inv - self.dim = cost.fwd_model.data_set.dim - self.mask = cost.fwd_model.data_set.mask - self.x_rec = np.empty(cost.n) - self.x_rec[:self.fwd_model.data_set.n] = self.magdata.get_vector(mask=self.mask) - self.x_rec[self.fwd_model.data_set.n:] = self.fwd_model.ramp.param_cache.ravel() - self.row_idx = None - self.pos = (0,) + tuple(np.array(np.where(self.mask))[:, 0]) # first True mask entry - self._updated_cov_row = False - self._updated_gain_row = False - self._updated_avrg_kern_row = False - self._updated_measure_contribution = False - self._A = jutil.operator.CostFunctionOperator(self.cost, self.x_rec) - self._P = jutil.preconditioner.CostFunctionPreconditioner(self.cost, self.x_rec) - self._log.debug('Creating ' + str(self)) - - def get_avrg_kern_field(self, pos=None): - """Get the averaging kernel matrix row represented as a 3D magnetization distribution. - - Parameters - ---------- - pos: tuple (N=4) - Position in 3D plus component `(c, z, y, x)` - - Returns - ------- - magdata_avrg_kern: :class:`~pyramid.fielddata.VectorData` - Averaging kernel matrix row represented as a 3D magnetization distribution - - """ - self._log.debug('Calling get_avrg_kern_field') - if pos is not None: - self.pos = pos - magdata_avrg_kern = VectorData(self.cost.fwd_model.data_set.a, np.zeros((3,) + self.dim)) - vector = self.avrg_kern_row[:-self.fwd_model.ramp.n] # Only take vector field, not ramp! - magdata_avrg_kern.set_vector(vector, mask=self.mask) - return magdata_avrg_kern - - def calculate_fwhm(self, pos=None, plot=False): - """Calculate and plot the averaging pixel number at a specified position for x, y or z. - - Parameters - ---------- - pos: tuple (N=4) - Position in 3D plus component `(c, z, y, x)` - plot : bool, optional - If True, a FWHM linescan plot is shown. Default is False. - - Returns - ------- - fwhm : float - The FWHM in x, y and z direction. The inverse corresponds to the number of pixels over - which is approximately averaged. - lr : 3 tuples of 2 floats - The left and right borders in x, y and z direction from which the FWHM is calculated. - Given in pixel coordinates and relative to the current position! - cxyz_dat : 4 lists of floats - The slices through the current position in the 4D volume (including the component), - which were used for FWHM calculations. Denotes information content in %! - - Notes - ----- - Uses the :func:`~.get_avrg_kern_field` function - - """ - self._log.debug('Calling calculate_fwhm') - a = self.magdata.a - magdata_avrg_kern = self.get_avrg_kern_field(pos) - x = np.arange(0, self.dim[2]) - self.pos[3] - y = np.arange(0, self.dim[1]) - self.pos[2] - z = np.arange(0, self.dim[0]) - self.pos[1] - c_dat = magdata_avrg_kern.field[:, self.pos[1], self.pos[2], self.pos[3]] - x_dat = magdata_avrg_kern.field[self.pos[0], self.pos[1], self.pos[2], :] - y_dat = magdata_avrg_kern.field[self.pos[0], self.pos[1], :, self.pos[3]] - z_dat = magdata_avrg_kern.field[self.pos[0], :, self.pos[2], self.pos[3]] - c_dat = np.asarray(c_dat * 100) # in % - x_dat = np.asarray(x_dat * 100) # in % - y_dat = np.asarray(y_dat * 100) # in % - z_dat = np.asarray(z_dat * 100) # in % - - def _calc_lr(c): - data = [x_dat, y_dat, z_dat][c] - i_m = np.argmax(data) # Index of the maximum - # Left side: - l = i_m - for i in np.arange(i_m - 1, -1, -1): - if data[i] < data[i_m] / 2: - # Linear interpolation between i and i + 1 to find left fractional index pos: - l = (data[i_m] / 2 - data[i]) / (data[i + 1] - data[i]) + i - break - # Right side: - r = i_m - for i in np.arange(i_m + 1, data.size): - if data[i] < data[i_m] / 2: - # Linear interpolation between i and i - 1 to find right fractional index pos: - r = (data[i_m] / 2 - data[i - 1]) / (data[i] - data[i - 1]) + i - 1 - break - # Transform from index to coordinates: - l = (l - self.pos[3-c]) - r = (r - self.pos[3-c]) - return l, r - - # Calculate FWHM: - lx, rx = _calc_lr(0) - ly, ry = _calc_lr(1) - lz, rz = _calc_lr(2) - fwhm_x = (rx - lx) * a - fwhm_y = (ry - ly) * a - fwhm_z = (rz - lz) * a - # Plot helpful stuff: - if plot: - fig, axis = plt.subplots(1, 1) - axis.axvline(x=0, ls='-', color='k', linewidth=2) - axis.axhline(y=0, ls='-', color='k', linewidth=2) - axis.axhline(y=x_dat.max(), ls='-', color='k', linewidth=2) - axis.axhline(y=x_dat.max() / 2, ls='--', color='k', linewidth=2) - axis.vlines(x=[lx, rx], ymin=0, ymax=x_dat.max() / 2, linestyles='--', - color='r', linewidth=2, alpha=0.5) - axis.vlines(x=[ly, ry], ymin=0, ymax=y_dat.max() / 2, linestyles='--', - color='g', linewidth=2, alpha=0.5) - axis.vlines(x=[lz, rz], ymin=0, ymax=z_dat.max() / 2, linestyles='--', - color='b', linewidth=2, alpha=0.5) - l = [] - l.extend(axis.plot(x, x_dat, label='x-dim.', color='r', marker='o', linewidth=2)) - l.extend(axis.plot(y, y_dat, label='y-dim.', color='g', marker='o', linewidth=2)) - l.extend(axis.plot(z, z_dat, label='z-dim.', color='b', marker='o', linewidth=2)) - cx = axis.scatter(0, c_dat[0], marker='o', s=200, edgecolor='r', label='x-comp.', - facecolor='r', alpha=0.75) - cy = axis.scatter(0, c_dat[1], marker='d', s=200, edgecolor='g', label='y-comp.', - facecolor='g', alpha=0.75) - cz = axis.scatter(0, c_dat[2], marker='*', s=200, edgecolor='b', label='z-comp.', - facecolor='b', alpha=0.75) - lim_min = np.min(np.concatenate((x, y, z))) - 0.5 - lim_max = np.max(np.concatenate((x, y, z))) + 0.5 - axis.set_xlim(lim_min, lim_max) - axis.set_title('Avrg. kern. FWHM', fontsize=18) - axis.set_xlabel('x/y/z-slice [nm]', fontsize=15) - axis.set_ylabel('information content [%]', fontsize=15) - axis.tick_params(axis='both', which='major', labelsize=14) - axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * a))) - comp_legend = axis.legend([cx, cy, cz], [c.get_label() for c in [cx, cy, cz]], loc=2, - scatterpoints=1, prop={'size': 14}) - axis.legend(l, [i.get_label() for i in l], loc=1, numpoints=1, prop={'size': 14}) - axis.add_artist(comp_legend) - fwhm = fwhm_x, fwhm_y, fwhm_z - lr = (lx, rx), (ly, ry), (lz, rz) - cxyz_dat = c_dat, x_dat, y_dat, z_dat - return fwhm, lr, cxyz_dat - - def get_gain_maps(self, pos=None): - """Get the gain matrix row represented as a list of 2D (inverse) phase maps. - - Parameters - ---------- - pos: tuple (N=4) - Position in 3D plus component `(c, z, y, x)` - - Returns - ------- - gain_map_list: list of :class:`~pyramid.phasemap.PhaseMap` - Gain matrix row represented as a list of 2D phase maps - - Notes - ----- - Note that the produced gain maps define the magnetization change at the current position - in 3d per phase change at the position of the . Take this into account when plotting the - maps (1/rad instead of rad). - - """ - self._log.debug('Calling get_gain_maps') - if pos is not None: - self.pos = pos - hp = self.cost.fwd_model.data_set.hook_points - gain_map_list = [] - for i, projector in enumerate(self.cost.fwd_model.data_set.projectors): - gain = self.gain_row[hp[i]:hp[i + 1]].reshape(projector.dim_uv) - gain_map_list.append(PhaseMap(self.cost.fwd_model.data_set.a, gain)) - return gain_map_list - - def plot_position(self, **kwargs): - proj_axis = kwargs.get('proj_axis', 'z') - if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice - pos_2d = (self.pos[2], self.pos[3]) - ax_slice = self.pos[1] - elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice - pos_2d = (self.pos[1], self.pos[3]) - ax_slice = self.pos[2] - elif proj_axis == 'x': # Slice of the zy-plane with x = ax_slice - pos_2d = (self.pos[2], self.pos[1]) - ax_slice = self.pos[3] - else: - raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) - note = kwargs.pop('note', None) - if note is None: - comp = {0: 'x', 1: 'y', 2: 'z'}[self.pos[0]] - note = '{}-comp., pos.: {}'.format(comp, self.pos[1:]) - # Plots: - axis = self.magdata.plot_quiver_field(note=note, ax_slice=ax_slice, **kwargs) - rect = axis.add_patch(patches.Rectangle((pos_2d[1], pos_2d[0]), 1, 1, fill=False, - edgecolor='w', linewidth=2, alpha=0.5)) - rect.set_path_effects([patheffects.withStroke(linewidth=4, foreground='k', alpha=0.5)]) - - def plot_position3d(self, **kwargs): - pass - - def plot_avrg_kern_field(self, pos=None, **kwargs): - a = self.magdata.a - avrg_kern_field = self.get_avrg_kern_field(pos) - fwhms, lr = self.calculate_fwhm(pos)[:2] - proj_axis = kwargs.get('proj_axis', 'z') - if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice - pos_2d = (self.pos[2], self.pos[3]) - ax_slice = self.pos[1] - width, height = fwhms[0] / a, fwhms[1] / a - elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice - pos_2d = (self.pos[1], self.pos[3]) - ax_slice = self.pos[2] - width, height = fwhms[0] / a, fwhms[2] / a - elif proj_axis == 'x': # Slice of the zy-plane with x = ax_slice - pos_2d = (self.pos[2], self.pos[1]) - ax_slice = self.pos[3] - width, height = fwhms[2] / a, fwhms[1] / a - else: - raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) - note = kwargs.pop('note', None) - if note is None: - comp = {0: 'x', 1: 'y', 2: 'z'}[self.pos[0]] - note = '{}-comp., pos.: {}'.format(comp, self.pos[1:]) - # Plots: - axis = avrg_kern_field.plot_quiver_field(note=note, ax_slice=ax_slice, **kwargs) - xy = (pos_2d[1], pos_2d[0]) - rect = axis.add_patch(patches.Rectangle(xy, 1, 1, fill=False, edgecolor='w', - linewidth=2, alpha=0.5)) - rect.set_path_effects([patheffects.withStroke(linewidth=4, foreground='k', alpha=0.5)]) - xy = (xy[0] + 0.5, xy[1] + 0.5) - artist = axis.add_patch(patches.Ellipse(xy, width, height, fill=False, edgecolor='w', - linewidth=2, alpha=0.5)) - artist.set_path_effects([patheffects.withStroke(linewidth=4, foreground='k', alpha=0.5)]) - - -def get_vector_field_errors(vector_data, vector_data_ref): - """After Kemp et. al.: Analysis of noise-induced errors in vector-field electron tomography""" - v, vr = vector_data.field, vector_data_ref.field - va, vra = vector_data.field_amp, vector_data_ref.field_amp - volume = np.prod(vector_data.dim) - # Total error: - amp_sum_sqr = np.nansum((v - vr)**2) - rms_tot = np.sqrt(amp_sum_sqr / np.nansum(vra**2)) - # Directional error: - scal_prod = np.clip(np.nansum(vr * v, axis=0) / (vra * va), -1, 1) # arccos float pt. inacc.! - rms_dir = np.sqrt(np.nansum(np.arccos(scal_prod)**2) / volume) / np.pi - # Magnitude error: - rms_mag = np.sqrt(np.nansum((va - vra)**2) / np.nansum(vra**2)) - # Return results as tuple: - return rms_tot, rms_dir, rms_mag +# -*- coding: utf-8 -*- +# Copyright 2014 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides the :class:`~.Diagnostics` class for the calculation of diagnostics of a +specified costfunction for a fixed magnetization distribution.""" + +import logging + +from pyramid.fielddata import VectorData +from pyramid.phasemap import PhaseMap + +import matplotlib.pyplot as plt +from matplotlib import patches +from matplotlib import patheffects +from matplotlib.ticker import FuncFormatter +import numpy as np + +import jutil + +__all__ = ['Diagnostics', 'get_vector_field_errors'] + + +class Diagnostics(object): + """Class for calculating diagnostic properties of a specified costfunction. + + For the calculation of diagnostic properties, a costfunction and a magnetization distribution + are specified at construction. With the :func:`~.set_position`, a position in 3D space can be + set at which all properties will be calculated. Properties are saved via boolean flags and + thus, calculation is only done if the position has changed in between. The standard deviation + and the measurement contribution require the execution of a conjugate gradient solver and can + take a while for larger problems. + + Attributes + ---------- + x_rec: :class:`~numpy.ndarray` + Vectorized magnetization distribution at which the costfunction is evaluated. + cost: :class:`~.pyramid.costfunction.Costfunction` + Costfunction for which the diagnostics are calculated. + max_iter: int, optional + Maximum number of iterations. Default is 1000. + fwd_model: :class:`~pyramid.forwardmodel.ForwardModel` + Forward model used in the costfunction. + Se_inv : :class:`~numpy.ndarray` (N=2), optional + Inverted covariance matrix of the measurement errors. The matrix has size `NxN` with N + being the length of the targetvector y (vectorized phase map information). + dim: tuple (N=3) + Dimensions of the 3D magnetic distribution. + row_idx: int + Row index of the system matrix corresponding to the current position in 3D space. + + Notes + ----- + Some properties depend on others, which may require recalculations of these prior + properties if necessary. The dependencies are ('-->' == 'requires'): + avrg_kern_row --> gain_row --> std --> m_inv_row + measure_contribution is independant + + """ + + _log = logging.getLogger(__name__ + '.Diagnostics') + + @property + def cov_row(self): + """Row of the covariance matrix (``S_a^-1+F'(x_f)^T S_e^-1 F'(x_f)``) which is needed for + the calculation of the gain and averaging kernel matrizes and which ideally contains the + variance at position `row_idx` for the current component and position in 3D. + Note that the covariance matrix of the solution is symmetric (like all covariance + matrices) and thusly this property could also be called cov_col for column.""" + if not self._updated_cov_row: + e_i = np.zeros(self.cost.n, dtype=self.x_rec.dtype) + e_i[self.row_idx] = 1 + row = 2 * jutil.cg.conj_grad_solve(self._A, e_i, P=self._P, max_iter=self.max_iter, + verbose=self.verbose) + self._std_row = np.asarray(row) + self._updated_cov_row = True + return self._std_row + + @property + def std(self): + """Standard deviation of the chosen component at the current position (calculated when + needed).""" + return np.sqrt(self.cov_row[self.row_idx]) + + @property + def gain_row(self): + """Row of the gain matrix, which maps differences of phase measurements onto differences in + the retrieval result of the magnetization distribution(calculated when needed).""" + if not self._updated_gain_row: + self._gain_row = self.Se_inv.dot(self.fwd_model.jac_dot(self.x_rec, self.cov_row)) + self._updated_gain_row = True + return self._gain_row + + @property + def avrg_kern_row(self): + """Row of the averaging kernel matrix (which is ideally the identity matrix), which + describes the smoothing introduced by the regularization (calculated when needed).""" + if not self._updated_avrg_kern_row: + self._avrg_kern_row = self.fwd_model.jac_T_dot(self.x_rec, self.gain_row) + self._updated_avrg_kern_row = True + return self._avrg_kern_row + + @property + def measure_contribution(self): + """The sum over an averaging kernel matrix row, which is an indicator for wheter a point of + the solution is determined by the measurement (close to `1`) or by a priori information + (close to `0`).""" + if not self._updated_measure_contribution: + cache = self.fwd_model.jac_dot(self.x_rec, np.ones(self.cost.n, self.x_rec.dtype)) + cache = self.fwd_model.jac_T_dot(self.x_rec, self.Se_inv.dot(cache)) + mc = 2 * jutil.cg.conj_grad_solve(self._A, cache, P=self._P, max_iter=self.max_iter) + self._measure_contribution = mc + self._updated_measure_contribution = True + return self._measure_contribution + + @property + def pos(self): + """The current solution position, which specifies the 3D-point (and the component) of the + magnetization, for which diagnostics are calculated.""" + return self._pos + + @pos.setter + def pos(self, pos): + c, z, y, x = pos + assert self.mask[z, y, x], 'Position is outside of the provided mask!' + mask_vec = self.mask.ravel() + idx_3d = z * self.dim[1] * self.dim[2] + y * self.dim[2] + x + row_idx = c * np.prod(mask_vec.sum()) + mask_vec[:idx_3d].sum() + if row_idx != self.row_idx: + self._pos = pos + self.row_idx = row_idx + self._updated_cov_row = False + self._updated_gain_row = False + self._updated_avrg_kern_row = False + self._updated_measure_contribution = False + + def __init__(self, magdata, cost, max_iter=1000, verbose=False): + self._log.debug('Calling __init__') + self.magdata = magdata + self.cost = cost + self.max_iter = max_iter + self.verbose = verbose + self.fwd_model = cost.fwd_model + self.Se_inv = cost.Se_inv + self.dim = cost.fwd_model.data_set.dim + self.mask = cost.fwd_model.data_set.mask + self.x_rec = np.empty(cost.n) + self.x_rec[:self.fwd_model.data_set.n] = self.magdata.get_vector(mask=self.mask) + self.x_rec[self.fwd_model.data_set.n:] = self.fwd_model.ramp.param_cache.ravel() + self.row_idx = None + self.pos = (0,) + tuple(np.array(np.where(self.mask))[:, 0]) # first True mask entry + self._updated_cov_row = False + self._updated_gain_row = False + self._updated_avrg_kern_row = False + self._updated_measure_contribution = False + self._A = jutil.operator.CostFunctionOperator(self.cost, self.x_rec) + self._P = jutil.preconditioner.CostFunctionPreconditioner(self.cost, self.x_rec) + self._log.debug('Creating ' + str(self)) + + def get_avrg_kern_field(self, pos=None): + """Get the averaging kernel matrix row represented as a 3D magnetization distribution. + + Parameters + ---------- + pos: tuple (N=4) + Position in 3D plus component `(c, z, y, x)` + + Returns + ------- + magdata_avrg_kern: :class:`~pyramid.fielddata.VectorData` + Averaging kernel matrix row represented as a 3D magnetization distribution + + """ + self._log.debug('Calling get_avrg_kern_field') + if pos is not None: + self.pos = pos + magdata_avrg_kern = VectorData(self.cost.fwd_model.data_set.a, np.zeros((3,) + self.dim)) + vector = self.avrg_kern_row[:-self.fwd_model.ramp.n] # Only take vector field, not ramp! + magdata_avrg_kern.set_vector(vector, mask=self.mask) + return magdata_avrg_kern + + def calculate_fwhm(self, pos=None, plot=False): + """Calculate and plot the averaging pixel number at a specified position for x, y or z. + + Parameters + ---------- + pos: tuple (N=4) + Position in 3D plus component `(c, z, y, x)` + plot : bool, optional + If True, a FWHM linescan plot is shown. Default is False. + + Returns + ------- + fwhm : float + The FWHM in x, y and z direction. The inverse corresponds to the number of pixels over + which is approximately averaged. + lr : 3 tuples of 2 floats + The left and right borders in x, y and z direction from which the FWHM is calculated. + Given in pixel coordinates and relative to the current position! + cxyz_dat : 4 lists of floats + The slices through the current position in the 4D volume (including the component), + which were used for FWHM calculations. Denotes information content in %! + + Notes + ----- + Uses the :func:`~.get_avrg_kern_field` function + + """ + self._log.debug('Calling calculate_fwhm') + a = self.magdata.a + magdata_avrg_kern = self.get_avrg_kern_field(pos) + x = np.arange(0, self.dim[2]) - self.pos[3] + y = np.arange(0, self.dim[1]) - self.pos[2] + z = np.arange(0, self.dim[0]) - self.pos[1] + c_dat = magdata_avrg_kern.field[:, self.pos[1], self.pos[2], self.pos[3]] + x_dat = magdata_avrg_kern.field[self.pos[0], self.pos[1], self.pos[2], :] + y_dat = magdata_avrg_kern.field[self.pos[0], self.pos[1], :, self.pos[3]] + z_dat = magdata_avrg_kern.field[self.pos[0], :, self.pos[2], self.pos[3]] + c_dat = np.asarray(c_dat * 100) # in % + x_dat = np.asarray(x_dat * 100) # in % + y_dat = np.asarray(y_dat * 100) # in % + z_dat = np.asarray(z_dat * 100) # in % + + def _calc_lr(c): + data = [x_dat, y_dat, z_dat][c] + i_m = np.argmax(data) # Index of the maximum + # Left side: + l = i_m + for i in np.arange(i_m - 1, -1, -1): + if data[i] < data[i_m] / 2: + # Linear interpolation between i and i + 1 to find left fractional index pos: + l = (data[i_m] / 2 - data[i]) / (data[i + 1] - data[i]) + i + break + # Right side: + r = i_m + for i in np.arange(i_m + 1, data.size): + if data[i] < data[i_m] / 2: + # Linear interpolation between i and i - 1 to find right fractional index pos: + r = (data[i_m] / 2 - data[i - 1]) / (data[i] - data[i - 1]) + i - 1 + break + # Transform from index to coordinates: + l = (l - self.pos[3-c]) + r = (r - self.pos[3-c]) + return l, r + + # Calculate FWHM: + lx, rx = _calc_lr(0) + ly, ry = _calc_lr(1) + lz, rz = _calc_lr(2) + fwhm_x = (rx - lx) * a + fwhm_y = (ry - ly) * a + fwhm_z = (rz - lz) * a + # Plot helpful stuff: + if plot: + fig, axis = plt.subplots(1, 1) + axis.axvline(x=0, ls='-', color='k', linewidth=2) + axis.axhline(y=0, ls='-', color='k', linewidth=2) + axis.axhline(y=x_dat.max(), ls='-', color='k', linewidth=2) + axis.axhline(y=x_dat.max() / 2, ls='--', color='k', linewidth=2) + axis.vlines(x=[lx, rx], ymin=0, ymax=x_dat.max() / 2, linestyles='--', + color='r', linewidth=2, alpha=0.5) + axis.vlines(x=[ly, ry], ymin=0, ymax=y_dat.max() / 2, linestyles='--', + color='g', linewidth=2, alpha=0.5) + axis.vlines(x=[lz, rz], ymin=0, ymax=z_dat.max() / 2, linestyles='--', + color='b', linewidth=2, alpha=0.5) + l = [] + l.extend(axis.plot(x, x_dat, label='x-dim.', color='r', marker='o', linewidth=2)) + l.extend(axis.plot(y, y_dat, label='y-dim.', color='g', marker='o', linewidth=2)) + l.extend(axis.plot(z, z_dat, label='z-dim.', color='b', marker='o', linewidth=2)) + cx = axis.scatter(0, c_dat[0], marker='o', s=200, edgecolor='r', label='x-comp.', + facecolor='r', alpha=0.75) + cy = axis.scatter(0, c_dat[1], marker='d', s=200, edgecolor='g', label='y-comp.', + facecolor='g', alpha=0.75) + cz = axis.scatter(0, c_dat[2], marker='*', s=200, edgecolor='b', label='z-comp.', + facecolor='b', alpha=0.75) + lim_min = np.min(np.concatenate((x, y, z))) - 0.5 + lim_max = np.max(np.concatenate((x, y, z))) + 0.5 + axis.set_xlim(lim_min, lim_max) + axis.set_title('Avrg. kern. FWHM', fontsize=18) + axis.set_xlabel('x/y/z-slice [nm]', fontsize=15) + axis.set_ylabel('information content [%]', fontsize=15) + axis.tick_params(axis='both', which='major', labelsize=14) + axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * a))) + comp_legend = axis.legend([cx, cy, cz], [c.get_label() for c in [cx, cy, cz]], loc=2, + scatterpoints=1, prop={'size': 14}) + axis.legend(l, [i.get_label() for i in l], loc=1, numpoints=1, prop={'size': 14}) + axis.add_artist(comp_legend) + fwhm = fwhm_x, fwhm_y, fwhm_z + lr = (lx, rx), (ly, ry), (lz, rz) + cxyz_dat = c_dat, x_dat, y_dat, z_dat + return fwhm, lr, cxyz_dat + + def get_gain_maps(self, pos=None): + """Get the gain matrix row represented as a list of 2D (inverse) phase maps. + + Parameters + ---------- + pos: tuple (N=4) + Position in 3D plus component `(c, z, y, x)` + + Returns + ------- + gain_map_list: list of :class:`~pyramid.phasemap.PhaseMap` + Gain matrix row represented as a list of 2D phase maps + + Notes + ----- + Note that the produced gain maps define the magnetization change at the current position + in 3d per phase change at the position of the . Take this into account when plotting the + maps (1/rad instead of rad). + + """ + self._log.debug('Calling get_gain_maps') + if pos is not None: + self.pos = pos + hp = self.cost.fwd_model.data_set.hook_points + gain_map_list = [] + for i, projector in enumerate(self.cost.fwd_model.data_set.projectors): + gain = self.gain_row[hp[i]:hp[i + 1]].reshape(projector.dim_uv) + gain_map_list.append(PhaseMap(self.cost.fwd_model.data_set.a, gain)) + return gain_map_list + + def plot_position(self, **kwargs): + proj_axis = kwargs.get('proj_axis', 'z') + if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice + pos_2d = (self.pos[2], self.pos[3]) + ax_slice = self.pos[1] + elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice + pos_2d = (self.pos[1], self.pos[3]) + ax_slice = self.pos[2] + elif proj_axis == 'x': # Slice of the zy-plane with x = ax_slice + pos_2d = (self.pos[2], self.pos[1]) + ax_slice = self.pos[3] + else: + raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) + note = kwargs.pop('note', None) + if note is None: + comp = {0: 'x', 1: 'y', 2: 'z'}[self.pos[0]] + note = '{}-comp., pos.: {}'.format(comp, self.pos[1:]) + # Plots: + axis = self.magdata.plot_quiver_field(note=note, ax_slice=ax_slice, **kwargs) + rect = axis.add_patch(patches.Rectangle((pos_2d[1], pos_2d[0]), 1, 1, fill=False, + edgecolor='w', linewidth=2, alpha=0.5)) + rect.set_path_effects([patheffects.withStroke(linewidth=4, foreground='k', alpha=0.5)]) + + def plot_position3d(self, **kwargs): + pass + + def plot_avrg_kern_field(self, pos=None, **kwargs): + a = self.magdata.a + avrg_kern_field = self.get_avrg_kern_field(pos) + fwhms, lr = self.calculate_fwhm(pos)[:2] + proj_axis = kwargs.get('proj_axis', 'z') + if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice + pos_2d = (self.pos[2], self.pos[3]) + ax_slice = self.pos[1] + width, height = fwhms[0] / a, fwhms[1] / a + elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice + pos_2d = (self.pos[1], self.pos[3]) + ax_slice = self.pos[2] + width, height = fwhms[0] / a, fwhms[2] / a + elif proj_axis == 'x': # Slice of the zy-plane with x = ax_slice + pos_2d = (self.pos[2], self.pos[1]) + ax_slice = self.pos[3] + width, height = fwhms[2] / a, fwhms[1] / a + else: + raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) + note = kwargs.pop('note', None) + if note is None: + comp = {0: 'x', 1: 'y', 2: 'z'}[self.pos[0]] + note = '{}-comp., pos.: {}'.format(comp, self.pos[1:]) + # Plots: + axis = avrg_kern_field.plot_quiver_field(note=note, ax_slice=ax_slice, **kwargs) + xy = (pos_2d[1], pos_2d[0]) + rect = axis.add_patch(patches.Rectangle(xy, 1, 1, fill=False, edgecolor='w', + linewidth=2, alpha=0.5)) + rect.set_path_effects([patheffects.withStroke(linewidth=4, foreground='k', alpha=0.5)]) + xy = (xy[0] + 0.5, xy[1] + 0.5) + artist = axis.add_patch(patches.Ellipse(xy, width, height, fill=False, edgecolor='w', + linewidth=2, alpha=0.5)) + artist.set_path_effects([patheffects.withStroke(linewidth=4, foreground='k', alpha=0.5)]) + + +def get_vector_field_errors(vector_data, vector_data_ref): + """After Kemp et. al.: Analysis of noise-induced errors in vector-field electron tomography""" + v, vr = vector_data.field, vector_data_ref.field + va, vra = vector_data.field_amp, vector_data_ref.field_amp + volume = np.prod(vector_data.dim) + # Total error: + amp_sum_sqr = np.nansum((v - vr)**2) + rms_tot = np.sqrt(amp_sum_sqr / np.nansum(vra**2)) + # Directional error: + scal_prod = np.clip(np.nansum(vr * v, axis=0) / (vra * va), -1, 1) # arccos float pt. inacc.! + rms_dir = np.sqrt(np.nansum(np.arccos(scal_prod)**2) / volume) / np.pi + # Magnitude error: + rms_mag = np.sqrt(np.nansum((va - vra)**2) / np.nansum(vra**2)) + # Return results as tuple: + return rms_tot, rms_dir, rms_mag diff --git a/pyramid/fft.py b/pyramid/fft.py index a9f0864ca46ab8a4e706added26bd9a86e64086a..f87fa4403e935c332299e9332573a08093fcf190 100644 --- a/pyramid/fft.py +++ b/pyramid/fft.py @@ -1,384 +1,384 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Custom FFT module with numpy and FFTW support. - -This module provides custom methods for FFTs including inverse, adjoint and real variants. The -FFTW library is supported and is used as a default if the import succeeds. Otherwise the numpy.fft -pack will be used. FFTW objects are saved in a cache after creation which speeds up further similar -FFT operations. - -""" - - -import pickle -import logging -import os - -import numpy as np - -_log = logging.getLogger(__name__) - -try: - import pyfftw - BACKEND = 'fftw' -except ImportError: - pyfftw = None - BACKEND = 'numpy' - _log.info('pyFFTW module not found. Using numpy implementation.') - -try: - import multiprocessing - NTHREADS = multiprocessing.cpu_count() - del multiprocessing -except ImportError: - NTHREADS = 1 - _log.info('multiprocessing module not found. Using single core.') - - -__all__ = ['plans', 'FLOAT', 'COMPLEX', 'dump_wisdom', 'load_wisdom', - 'zeros', 'empty', 'ones', 'configure_backend', - 'fftn', 'ifftn', 'rfftn', 'irfftn', 'rfftn_adj', 'irfftn_adj'] - - -class FFTWCache(object): - """Class for adding FFTW Plans and on-demand lookups. - - This class is instantiated in this module to store FFTW plans and for the lookup of the former. - - Attributes - ---------- - cache: dict - Cache for storing the FFTW plans. - - Notes - ----- - This class is used internally and is not normally not intended to be used directly by the user. - - """ - - _log = logging.getLogger(__name__ + '.FFTWCache') - - def __init__(self): - self._log.debug('Calling __init__') - self.cache = dict() - self._log.debug('Created ' + str(self)) - - def add_fftw(self, fft_type, fftw_obj, s, axes, nthreads): - """Add an FFTW object to the cache. - - Parameters - ---------- - fft_type: basestring - Identifier sting for the FFT type ('fftn', 'ifftn', 'rfftn', 'irfftn'). - fftw_obj: :class:`~pyfftw.FFTW` object - The FFTW object which should be added to the cache. - s: tuple of ints - Shape of the output array. - axes: tuple of ints - The axes along which the FFTW should be executed. - nthreads: int - Number of threads which should be used. - - """ - self._log.debug('Calling add_fftw') - in_arr = fftw_obj.get_input_array() - key = (fft_type, in_arr.shape, in_arr.dtype, s, axes, nthreads) - self.cache[key] = fftw_obj - - def lookup_fftw(self, fft_type, in_arr, s, axes, nthreads): - """ - - Parameters - ---------- - fft_type: basestring - Identifier sting for the FFT type ('fftn', 'ifftn', 'rfftn', 'irfftn'). - in_arr: - Input array, internally, just the `dtype` and the `shape` are used to identify the FFT. - s: tuple of ints - Shape of the output array. - axes: tuple of ints - The axes along which the FFTW should be executed. - nthreads: int - Number of threads which should be used. - - Returns - ------- - fftw_obj: :class:`~pyfftw.FFTW` object - The requested FFTW object. - - """ - self._log.debug('Calling lookup_fftw') - key = (fft_type, in_arr.shape, in_arr.dtype, s, axes, nthreads) - return self.cache.get(key, None) - - def clear_cache(self): - """Clear the cache.""" - self._log.debug('Calling clear_cache') - self.cache = dict() - - -plans = FFTWCache() -FLOAT = np.float32 # One convenient place to -COMPLEX = np.complex64 # change from 32 to 64 bit - - -# Numpy functions: - -def _fftn_numpy(a, s=None, axes=None): - return np.fft.fftn(a, s, axes) - - -def _ifftn_numpy(a, s=None, axes=None): - return np.fft.ifftn(a, s, axes) - - -def _rfftn_numpy(a, s=None, axes=None): - return np.fft.rfftn(a, s, axes) - - -def _irfftn_numpy(a, s=None, axes=None): - return np.fft.irfftn(a, s, axes) - - -def _rfftn_adj_numpy(a): - n = 2 * (a.shape[-1] - 1) - out_shape = a.shape[:-1] + (n,) - out_arr = zeros(out_shape, dtype=a.dtype) - out_arr[:, :a.shape[1]] = a - return _ifftn_numpy(out_arr).real * np.prod(out_shape) - - -def _irfftn_adj_numpy(a): - n = a.shape[-1] // 2 + 1 - out_arr = _fftn_numpy(a, axes=(-1,)) / a.shape[-1] - if a.shape[-1] % 2 == 0: # even - out_arr[:, 1:n - 1] += np.conj(out_arr[:, :n - 1:-1]) - else: # odd - out_arr[:, 1:n] += np.conj(out_arr[:, :n - 1:-1]) - axes = tuple(range(len(out_arr.shape[:-1]))) - return _fftn_numpy(out_arr[:, :n], axes=axes) / np.prod(out_arr.shape[:-1]) - - -# FFTW functions: - -def _fftn_fftw(a, s=None, axes=None): - fftw = plans.lookup_fftw('fftn', a, s, axes, NTHREADS) - if fftw is None: - fftw = pyfftw.builders.fftn(a, s, axes, threads=NTHREADS) - plans.add_fftw('fftn', fftw, s, axes, NTHREADS) - return fftw(a).copy() - - -def _ifftn_fftw(a, s=None, axes=None): - fftw = plans.lookup_fftw('ifftn', a, s, axes, NTHREADS) - if fftw is None: - fftw = pyfftw.builders.ifftn(a, s, axes, threads=NTHREADS) - plans.add_fftw('ifftn', fftw, s, axes, NTHREADS) - return fftw(a).copy() - - -def _rfftn_fftw(a, s=None, axes=None): - fftw = plans.lookup_fftw('rfftn', a, s, axes, NTHREADS) - if fftw is None: - fftw = pyfftw.builders.rfftn(a, s, axes, threads=NTHREADS) - plans.add_fftw('rfftn', fftw, s, axes, NTHREADS) - return fftw(a).copy() - - -def _irfftn_fftw(a, s=None, axes=None): - fftw = plans.lookup_fftw('irfftn', a, s, axes, NTHREADS) - if fftw is None: - fftw = pyfftw.builders.irfftn(a, s, axes, threads=NTHREADS) - plans.add_fftw('irfftn', fftw, s, axes, NTHREADS) - return fftw(a).copy() - - -def _rfftn_adj_fftw(a): - # Careful: just works for even a (which is guaranteed by the kernel!) - n = 2 * (a.shape[-1] - 1) - out_shape = a.shape[:-1] + (n,) - out_arr = zeros(out_shape, dtype=a.dtype) - out_arr[:, :a.shape[-1]] = a - return _ifftn_fftw(out_arr).real * np.prod(out_shape) - - -def _irfftn_adj_fftw(a): - out_arr = _fftn_fftw(a, axes=(-1,)) / a.shape[-1] # FFT of last axis - n = a.shape[-1] // 2 + 1 - if a.shape[-1] % 2 == 0: # even - out_arr[:, 1:n - 1] += np.conj(out_arr[:, :n - 1:-1]) - else: # odd - out_arr[:, 1:n] += np.conj(out_arr[:, :n - 1:-1]) - axes = tuple(range(len(out_arr.shape[:-1]))) - return _fftn_fftw(out_arr[:, :n], axes=axes) / np.prod(out_arr.shape[:-1]) - - -# These wisdom functions do nothing if pyFFTW is not available: - -def dump_wisdom(fname): - """Wrapper function for the pyfftw.export_wisdom(), which uses a pickle dump. - - Parameters - ---------- - fname: string - Name of the file in which the wisdom is saved. - - Returns - ------- - None - - """ - _log.debug('Calling dump_wisdom') - if pyfftw is not None: - with open(fname, 'wb') as fp: - pickle.dump(pyfftw.export_wisdom(), fp, pickle.HIGHEST_PROTOCOL) - - -def load_wisdom(fname): - """Wrapper function for the pyfftw.import_wisdom(), which uses a pickle to load a file. - - Parameters - ---------- - fname: string - Name of the file from which the wisdom is loaded. - - Returns - ------- - None - - """ - _log.debug('Calling load_wisdom') - if pyfftw is not None: - if not os.path.exists(fname): - print("Warning: Wisdom file does not exist. First time use?") - else: - with open(fname, 'rb') as fp: - pyfftw.import_wisdom(pickle.load(fp)) - - -# Array setups: -def empty(shape, dtype=FLOAT): - """Return a new array of given shape and type without initializing entries. - - Parameters - ---------- - shape: int or tuple of int - Shape of the array. - dtype: data-type, optional - Desired output data-type. - - Returns - ------- - out: :class:`~numpy.ndarray` - The created array. - - """ - _log.debug('Calling empty') - result = np.empty(shape, dtype) - if pyfftw is not None: - result = pyfftw.n_byte_align(result, pyfftw.simd_alignment) - return result - - -def zeros(shape, dtype=FLOAT): - """Return a new array of given shape and type, filled with zeros. - - Parameters - ---------- - shape: int or tuple of int - Shape of the array. - dtype: data-type, optional - Desired output data-type. - - Returns - ------- - out: :class:`~numpy.ndarray` - The created array. - - """ - _log.debug('Calling zeros') - result = np.zeros(shape, dtype) - if pyfftw is not None: - result = pyfftw.n_byte_align(result, pyfftw.simd_alignment) - return result - - -def ones(shape, dtype=FLOAT): - """Return a new array of given shape and type, filled with ones. - - Parameters - ---------- - shape: int or tuple of int - Shape of the array. - dtype: data-type, optional - Desired output data-type. - - Returns - ------- - out: :class:`~numpy.ndarray` - The created array. - - """ - _log.debug('Calling ones') - result = np.ones(shape, dtype) - if pyfftw is not None: - result = pyfftw.n_byte_align(result, pyfftw.simd_alignment) - return result - - -# Configure backend: -def configure_backend(backend): - """Change FFT backend. - - Parameters - ---------- - backend: string - Backend to use. Supported values are "numpy" and "fftw". - - Returns - ------- - None - - """ - _log.debug('Calling configure_backend') - global fftn - global ifftn - global rfftn - global irfftn - global rfftn_adj - global irfftn_adj - global BACKEND - if backend == 'numpy': - fftn = _fftn_numpy - ifftn = _ifftn_numpy - rfftn = _rfftn_numpy - irfftn = _irfftn_numpy - rfftn_adj = _rfftn_adj_numpy - irfftn_adj = _irfftn_adj_numpy - BACKEND = 'numpy' - elif backend == 'fftw': - if pyfftw is not None: - fftn = _fftn_fftw - ifftn = _ifftn_fftw - rfftn = _rfftn_fftw - irfftn = _irfftn_fftw - rfftn_adj = _rfftn_adj_fftw - irfftn_adj = _irfftn_adj_fftw - BACKEND = 'pyfftw' - else: - print('Error: FFTW requested but not available') - - -# On import: -ifftn = None -fftn = None -rfftn = None -irfftn = None -rfftn_adj = None -irfftn_adj = None -if pyfftw is not None: - configure_backend('fftw') -else: - configure_backend('numpy') +# -*- coding: utf-8 -*- +# Copyright 2014 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Custom FFT module with numpy and FFTW support. + +This module provides custom methods for FFTs including inverse, adjoint and real variants. The +FFTW library is supported and is used as a default if the import succeeds. Otherwise the numpy.fft +pack will be used. FFTW objects are saved in a cache after creation which speeds up further similar +FFT operations. + +""" + + +import pickle +import logging +import os + +import numpy as np + +_log = logging.getLogger(__name__) + +try: + import pyfftw + BACKEND = 'fftw' +except ImportError: + pyfftw = None + BACKEND = 'numpy' + _log.info('pyFFTW module not found. Using numpy implementation.') + +try: + import multiprocessing + NTHREADS = multiprocessing.cpu_count() + del multiprocessing +except ImportError: + NTHREADS = 1 + _log.info('multiprocessing module not found. Using single core.') + + +__all__ = ['plans', 'FLOAT', 'COMPLEX', 'dump_wisdom', 'load_wisdom', + 'zeros', 'empty', 'ones', 'configure_backend', + 'fftn', 'ifftn', 'rfftn', 'irfftn', 'rfftn_adj', 'irfftn_adj'] + + +class FFTWCache(object): + """Class for adding FFTW Plans and on-demand lookups. + + This class is instantiated in this module to store FFTW plans and for the lookup of the former. + + Attributes + ---------- + cache: dict + Cache for storing the FFTW plans. + + Notes + ----- + This class is used internally and is not normally not intended to be used directly by the user. + + """ + + _log = logging.getLogger(__name__ + '.FFTWCache') + + def __init__(self): + self._log.debug('Calling __init__') + self.cache = dict() + self._log.debug('Created ' + str(self)) + + def add_fftw(self, fft_type, fftw_obj, s, axes, nthreads): + """Add an FFTW object to the cache. + + Parameters + ---------- + fft_type: basestring + Identifier sting for the FFT type ('fftn', 'ifftn', 'rfftn', 'irfftn'). + fftw_obj: :class:`~pyfftw.FFTW` object + The FFTW object which should be added to the cache. + s: tuple of ints + Shape of the output array. + axes: tuple of ints + The axes along which the FFTW should be executed. + nthreads: int + Number of threads which should be used. + + """ + self._log.debug('Calling add_fftw') + in_arr = fftw_obj.get_input_array() + key = (fft_type, in_arr.shape, in_arr.dtype, s, axes, nthreads) + self.cache[key] = fftw_obj + + def lookup_fftw(self, fft_type, in_arr, s, axes, nthreads): + """ + + Parameters + ---------- + fft_type: basestring + Identifier sting for the FFT type ('fftn', 'ifftn', 'rfftn', 'irfftn'). + in_arr: + Input array, internally, just the `dtype` and the `shape` are used to identify the FFT. + s: tuple of ints + Shape of the output array. + axes: tuple of ints + The axes along which the FFTW should be executed. + nthreads: int + Number of threads which should be used. + + Returns + ------- + fftw_obj: :class:`~pyfftw.FFTW` object + The requested FFTW object. + + """ + self._log.debug('Calling lookup_fftw') + key = (fft_type, in_arr.shape, in_arr.dtype, s, axes, nthreads) + return self.cache.get(key, None) + + def clear_cache(self): + """Clear the cache.""" + self._log.debug('Calling clear_cache') + self.cache = dict() + + +plans = FFTWCache() +FLOAT = np.float32 # One convenient place to +COMPLEX = np.complex64 # change from 32 to 64 bit + + +# Numpy functions: + +def _fftn_numpy(a, s=None, axes=None): + return np.fft.fftn(a, s, axes) + + +def _ifftn_numpy(a, s=None, axes=None): + return np.fft.ifftn(a, s, axes) + + +def _rfftn_numpy(a, s=None, axes=None): + return np.fft.rfftn(a, s, axes) + + +def _irfftn_numpy(a, s=None, axes=None): + return np.fft.irfftn(a, s, axes) + + +def _rfftn_adj_numpy(a): + n = 2 * (a.shape[-1] - 1) + out_shape = a.shape[:-1] + (n,) + out_arr = zeros(out_shape, dtype=a.dtype) + out_arr[:, :a.shape[1]] = a + return _ifftn_numpy(out_arr).real * np.prod(out_shape) + + +def _irfftn_adj_numpy(a): + n = a.shape[-1] // 2 + 1 + out_arr = _fftn_numpy(a, axes=(-1,)) / a.shape[-1] + if a.shape[-1] % 2 == 0: # even + out_arr[:, 1:n - 1] += np.conj(out_arr[:, :n - 1:-1]) + else: # odd + out_arr[:, 1:n] += np.conj(out_arr[:, :n - 1:-1]) + axes = tuple(range(len(out_arr.shape[:-1]))) + return _fftn_numpy(out_arr[:, :n], axes=axes) / np.prod(out_arr.shape[:-1]) + + +# FFTW functions: + +def _fftn_fftw(a, s=None, axes=None): + fftw = plans.lookup_fftw('fftn', a, s, axes, NTHREADS) + if fftw is None: + fftw = pyfftw.builders.fftn(a, s, axes, threads=NTHREADS) + plans.add_fftw('fftn', fftw, s, axes, NTHREADS) + return fftw(a).copy() + + +def _ifftn_fftw(a, s=None, axes=None): + fftw = plans.lookup_fftw('ifftn', a, s, axes, NTHREADS) + if fftw is None: + fftw = pyfftw.builders.ifftn(a, s, axes, threads=NTHREADS) + plans.add_fftw('ifftn', fftw, s, axes, NTHREADS) + return fftw(a).copy() + + +def _rfftn_fftw(a, s=None, axes=None): + fftw = plans.lookup_fftw('rfftn', a, s, axes, NTHREADS) + if fftw is None: + fftw = pyfftw.builders.rfftn(a, s, axes, threads=NTHREADS) + plans.add_fftw('rfftn', fftw, s, axes, NTHREADS) + return fftw(a).copy() + + +def _irfftn_fftw(a, s=None, axes=None): + fftw = plans.lookup_fftw('irfftn', a, s, axes, NTHREADS) + if fftw is None: + fftw = pyfftw.builders.irfftn(a, s, axes, threads=NTHREADS) + plans.add_fftw('irfftn', fftw, s, axes, NTHREADS) + return fftw(a).copy() + + +def _rfftn_adj_fftw(a): + # Careful: just works for even a (which is guaranteed by the kernel!) + n = 2 * (a.shape[-1] - 1) + out_shape = a.shape[:-1] + (n,) + out_arr = zeros(out_shape, dtype=a.dtype) + out_arr[:, :a.shape[-1]] = a + return _ifftn_fftw(out_arr).real * np.prod(out_shape) + + +def _irfftn_adj_fftw(a): + out_arr = _fftn_fftw(a, axes=(-1,)) / a.shape[-1] # FFT of last axis + n = a.shape[-1] // 2 + 1 + if a.shape[-1] % 2 == 0: # even + out_arr[:, 1:n - 1] += np.conj(out_arr[:, :n - 1:-1]) + else: # odd + out_arr[:, 1:n] += np.conj(out_arr[:, :n - 1:-1]) + axes = tuple(range(len(out_arr.shape[:-1]))) + return _fftn_fftw(out_arr[:, :n], axes=axes) / np.prod(out_arr.shape[:-1]) + + +# These wisdom functions do nothing if pyFFTW is not available: + +def dump_wisdom(fname): + """Wrapper function for the pyfftw.export_wisdom(), which uses a pickle dump. + + Parameters + ---------- + fname: string + Name of the file in which the wisdom is saved. + + Returns + ------- + None + + """ + _log.debug('Calling dump_wisdom') + if pyfftw is not None: + with open(fname, 'wb') as fp: + pickle.dump(pyfftw.export_wisdom(), fp, pickle.HIGHEST_PROTOCOL) + + +def load_wisdom(fname): + """Wrapper function for the pyfftw.import_wisdom(), which uses a pickle to load a file. + + Parameters + ---------- + fname: string + Name of the file from which the wisdom is loaded. + + Returns + ------- + None + + """ + _log.debug('Calling load_wisdom') + if pyfftw is not None: + if not os.path.exists(fname): + print("Warning: Wisdom file does not exist. First time use?") + else: + with open(fname, 'rb') as fp: + pyfftw.import_wisdom(pickle.load(fp)) + + +# Array setups: +def empty(shape, dtype=FLOAT): + """Return a new array of given shape and type without initializing entries. + + Parameters + ---------- + shape: int or tuple of int + Shape of the array. + dtype: data-type, optional + Desired output data-type. + + Returns + ------- + out: :class:`~numpy.ndarray` + The created array. + + """ + _log.debug('Calling empty') + result = np.empty(shape, dtype) + if pyfftw is not None: + result = pyfftw.n_byte_align(result, pyfftw.simd_alignment) + return result + + +def zeros(shape, dtype=FLOAT): + """Return a new array of given shape and type, filled with zeros. + + Parameters + ---------- + shape: int or tuple of int + Shape of the array. + dtype: data-type, optional + Desired output data-type. + + Returns + ------- + out: :class:`~numpy.ndarray` + The created array. + + """ + _log.debug('Calling zeros') + result = np.zeros(shape, dtype) + if pyfftw is not None: + result = pyfftw.n_byte_align(result, pyfftw.simd_alignment) + return result + + +def ones(shape, dtype=FLOAT): + """Return a new array of given shape and type, filled with ones. + + Parameters + ---------- + shape: int or tuple of int + Shape of the array. + dtype: data-type, optional + Desired output data-type. + + Returns + ------- + out: :class:`~numpy.ndarray` + The created array. + + """ + _log.debug('Calling ones') + result = np.ones(shape, dtype) + if pyfftw is not None: + result = pyfftw.n_byte_align(result, pyfftw.simd_alignment) + return result + + +# Configure backend: +def configure_backend(backend): + """Change FFT backend. + + Parameters + ---------- + backend: string + Backend to use. Supported values are "numpy" and "fftw". + + Returns + ------- + None + + """ + _log.debug('Calling configure_backend') + global fftn + global ifftn + global rfftn + global irfftn + global rfftn_adj + global irfftn_adj + global BACKEND + if backend == 'numpy': + fftn = _fftn_numpy + ifftn = _ifftn_numpy + rfftn = _rfftn_numpy + irfftn = _irfftn_numpy + rfftn_adj = _rfftn_adj_numpy + irfftn_adj = _irfftn_adj_numpy + BACKEND = 'numpy' + elif backend == 'fftw': + if pyfftw is not None: + fftn = _fftn_fftw + ifftn = _ifftn_fftw + rfftn = _rfftn_fftw + irfftn = _irfftn_fftw + rfftn_adj = _rfftn_adj_fftw + irfftn_adj = _irfftn_adj_fftw + BACKEND = 'pyfftw' + else: + print('Error: FFTW requested but not available') + + +# On import: +ifftn = None +fftn = None +rfftn = None +irfftn = None +rfftn_adj = None +irfftn_adj = None +if pyfftw is not None: + configure_backend('fftw') +else: + configure_backend('numpy') diff --git a/pyramid/fieldconverter.py b/pyramid/fieldconverter.py index 9a213f8218b6f4e6e6211312d4187de41e86ca0e..aff8e2e569e77e4fcd4258f04c3a964cac532b00 100644 --- a/pyramid/fieldconverter.py +++ b/pyramid/fieldconverter.py @@ -1,167 +1,167 @@ -# coding=utf-8 -"""Convert vector fields. - -The :mod:`~.fieldconverter` provides methods for converting a magnetization distribution `M` into -a vector potential `A` and convert this in turn into a magnetic field `B`. The direct way is also -possible. - -""" - -import logging - -import numpy as np - -from jutil import fft - -from pyramid.fielddata import VectorData - -__all__ = ['convert_M_to_A', 'convert_A_to_B', 'convert_M_to_B'] -_log = logging.getLogger(__name__) - - -def convert_M_to_A(magdata, b_0=1.0): - """Convert a magnetic vector distribution into a vector potential `A`. - - Parameters - ---------- - magdata: :class:`~pyramid.magdata.VectorData` object - The magnetic vector field from which the A-field is calculated. - b_0: float, optional - The saturation magnetization which is used in the calculation. - - Returns - ------- - b_data: :class:`~pyramid.magdata.VectorData` object - The calculated B-field. - - """ - _log.debug('Calling convert_M_to_A') - # Preparations of variables: - assert isinstance(magdata, VectorData), 'Only VectorData objects can be mapped!' - dim = magdata.dim - dim_kern = tuple(2 * np.array(dim) - 1) # Dimensions of the kernel - if fft.HAVE_FFTW: - dim_pad = tuple(2 * np.array(dim)) # is at least even (not neccessary a power of 2) - else: - dim_pad = tuple(2 ** np.ceil(np.log2(2 * np.array(dim))).astype(int)) # pow(2) - slice_B = (slice(dim[0] - 1, dim_kern[0]), # Shift because kernel center - slice(dim[1] - 1, dim_kern[1]), # is not at (0, 0, 0)! - slice(dim[2] - 1, dim_kern[2])) - slice_M = (slice(0, dim[0]), # Magnetization is padded on the far end! - slice(0, dim[1]), # B-field cutout is shifted as listed above - slice(0, dim[2])) # because of the kernel center! - # Set up kernels - coeff = magdata.a * b_0 / (4 * np.pi) - zzz, yyy, xxx = np.indices(dim_kern) - xxx -= dim[2] - 1 - yyy -= dim[1] - 1 - zzz -= dim[0] - 1 - k_x = np.empty(dim_kern, dtype=magdata.field.dtype) - k_y = np.empty(dim_kern, dtype=magdata.field.dtype) - k_z = np.empty(dim_kern, dtype=magdata.field.dtype) - k_x[...] = coeff * xxx / np.abs(xxx ** 2 + yyy ** 2 + zzz ** 2 + 1E-30) ** 3 - k_y[...] = coeff * yyy / np.abs(xxx ** 2 + yyy ** 2 + zzz ** 2 + 1E-30) ** 3 - k_z[...] = coeff * zzz / np.abs(xxx ** 2 + yyy ** 2 + zzz ** 2 + 1E-30) ** 3 - # Calculate Fourier trafo of kernel components: - k_x_fft = fft.rfftn(k_x, dim_pad) - k_y_fft = fft.rfftn(k_y, dim_pad) - k_z_fft = fft.rfftn(k_z, dim_pad) - # Prepare magnetization: - x_mag = np.zeros(dim_pad, dtype=magdata.field.dtype) - y_mag = np.zeros(dim_pad, dtype=magdata.field.dtype) - z_mag = np.zeros(dim_pad, dtype=magdata.field.dtype) - x_mag[slice_M] = magdata.field[0, ...] - y_mag[slice_M] = magdata.field[1, ...] - z_mag[slice_M] = magdata.field[2, ...] - # Calculate Fourier trafo of magnetization components: - x_mag_fft = fft.rfftn(x_mag) - y_mag_fft = fft.rfftn(y_mag) - z_mag_fft = fft.rfftn(z_mag) - # Convolve: - a_x_fft = y_mag_fft * k_z_fft - z_mag_fft * k_y_fft - a_y_fft = z_mag_fft * k_x_fft - x_mag_fft * k_z_fft - a_z_fft = x_mag_fft * k_y_fft - y_mag_fft * k_x_fft - a_x = fft.irfftn(a_x_fft)[slice_B] - a_y = fft.irfftn(a_y_fft)[slice_B] - a_z = fft.irfftn(a_z_fft)[slice_B] - # Return A-field: - return VectorData(magdata.a, np.asarray((a_x, a_y, a_z))) - - -def convert_A_to_B(a_data): - """Convert a vector potential `A` into a B-field distribution. - - Parameters - ---------- - a_data: :class:`~pyramid.magdata.VectorData` object - The vector potential field from which the A-field is calculated. - - Returns - ------- - b_data: :class:`~pyramid.magdata.VectorData` object - The calculated B-field. - - """ - _log.debug('Calling convert_A_to_B') - assert isinstance(a_data, VectorData), 'Only VectorData objects can be mapped!' - # - axis = tuple([i for i in range(3) if a_data.dim[i] > 1]) - # - x_grads = np.gradient(a_data.field[0, ...], axis=axis) #/ a_data.a - y_grads = np.gradient(a_data.field[1, ...], axis=axis) #/ a_data.a - z_grads = np.gradient(a_data.field[2, ...], axis=axis) #/ a_data.a - # - x_gradii = np.zeros(a_data.shape) - y_gradii = np.zeros(a_data.shape) - z_gradii = np.zeros(a_data.shape) - # - for i, axis in enumerate(axis): - x_gradii[axis] = x_grads[i] - y_gradii[axis] = y_grads[i] - z_gradii[axis] = z_grads[i] - # - x_grad_z, x_grad_y, x_grad_x = x_gradii - y_grad_z, y_grad_y, y_grad_x = y_gradii - z_grad_z, z_grad_y, z_grad_x = z_gradii - # Calculate cross product: - b_x = (z_grad_y - y_grad_z) - b_y = (x_grad_z - z_grad_x) - b_z = (y_grad_x - x_grad_y) - # Return B-field: - return VectorData(a_data.a, np.asarray((b_x, b_y, b_z))) - - - - - # Calculate gradients: - x_mag, y_mag, z_mag = a_data.field - x_grad_z, x_grad_y, x_grad_x = np.gradient(x_mag) - y_grad_z, y_grad_y, y_grad_x = np.gradient(y_mag) - z_grad_z, z_grad_y, z_grad_x = np.gradient(z_mag) - # Calculate cross product: - b_x = (z_grad_y - y_grad_z) - b_y = (x_grad_z - z_grad_x) - b_z = (y_grad_x - x_grad_y) - # Return B-field: - return VectorData(a_data.a, np.asarray((b_x, b_y, b_z))) - - -def convert_M_to_B(magdata, b_0=1.0): - """Convert a magnetic vector distribution into a B-field distribution. - - Parameters - ---------- - magdata: :class:`~pyramid.magdata.VectorData` object - The magnetic vector field from which the B-field is calculated. - b_0: float, optional - The saturation magnetization which is used in the calculation. - - Returns - ------- - b_data: :class:`~pyramid.magdata.VectorData` object - The calculated B-field. - - """ - _log.debug('Calling convert_M_to_B') - assert isinstance(magdata, VectorData), 'Only VectorData objects can be mapped!' - return convert_A_to_B(convert_M_to_A(magdata, b_0=b_0)) +# coding=utf-8 +"""Convert vector fields. + +The :mod:`~.fieldconverter` provides methods for converting a magnetization distribution `M` into +a vector potential `A` and convert this in turn into a magnetic field `B`. The direct way is also +possible. + +""" + +import logging + +import numpy as np + +from jutil import fft + +from pyramid.fielddata import VectorData + +__all__ = ['convert_M_to_A', 'convert_A_to_B', 'convert_M_to_B'] +_log = logging.getLogger(__name__) + + +def convert_M_to_A(magdata, b_0=1.0): + """Convert a magnetic vector distribution into a vector potential `A`. + + Parameters + ---------- + magdata: :class:`~pyramid.magdata.VectorData` object + The magnetic vector field from which the A-field is calculated. + b_0: float, optional + The saturation magnetization which is used in the calculation. + + Returns + ------- + b_data: :class:`~pyramid.magdata.VectorData` object + The calculated B-field. + + """ + _log.debug('Calling convert_M_to_A') + # Preparations of variables: + assert isinstance(magdata, VectorData), 'Only VectorData objects can be mapped!' + dim = magdata.dim + dim_kern = tuple(2 * np.array(dim) - 1) # Dimensions of the kernel + if fft.HAVE_FFTW: + dim_pad = tuple(2 * np.array(dim)) # is at least even (not neccessary a power of 2) + else: + dim_pad = tuple(2 ** np.ceil(np.log2(2 * np.array(dim))).astype(int)) # pow(2) + slice_B = (slice(dim[0] - 1, dim_kern[0]), # Shift because kernel center + slice(dim[1] - 1, dim_kern[1]), # is not at (0, 0, 0)! + slice(dim[2] - 1, dim_kern[2])) + slice_M = (slice(0, dim[0]), # Magnetization is padded on the far end! + slice(0, dim[1]), # B-field cutout is shifted as listed above + slice(0, dim[2])) # because of the kernel center! + # Set up kernels + coeff = magdata.a * b_0 / (4 * np.pi) + zzz, yyy, xxx = np.indices(dim_kern) + xxx -= dim[2] - 1 + yyy -= dim[1] - 1 + zzz -= dim[0] - 1 + k_x = np.empty(dim_kern, dtype=magdata.field.dtype) + k_y = np.empty(dim_kern, dtype=magdata.field.dtype) + k_z = np.empty(dim_kern, dtype=magdata.field.dtype) + k_x[...] = coeff * xxx / np.abs(xxx ** 2 + yyy ** 2 + zzz ** 2 + 1E-30) ** 3 + k_y[...] = coeff * yyy / np.abs(xxx ** 2 + yyy ** 2 + zzz ** 2 + 1E-30) ** 3 + k_z[...] = coeff * zzz / np.abs(xxx ** 2 + yyy ** 2 + zzz ** 2 + 1E-30) ** 3 + # Calculate Fourier trafo of kernel components: + k_x_fft = fft.rfftn(k_x, dim_pad) + k_y_fft = fft.rfftn(k_y, dim_pad) + k_z_fft = fft.rfftn(k_z, dim_pad) + # Prepare magnetization: + x_mag = np.zeros(dim_pad, dtype=magdata.field.dtype) + y_mag = np.zeros(dim_pad, dtype=magdata.field.dtype) + z_mag = np.zeros(dim_pad, dtype=magdata.field.dtype) + x_mag[slice_M] = magdata.field[0, ...] + y_mag[slice_M] = magdata.field[1, ...] + z_mag[slice_M] = magdata.field[2, ...] + # Calculate Fourier trafo of magnetization components: + x_mag_fft = fft.rfftn(x_mag) + y_mag_fft = fft.rfftn(y_mag) + z_mag_fft = fft.rfftn(z_mag) + # Convolve: + a_x_fft = y_mag_fft * k_z_fft - z_mag_fft * k_y_fft + a_y_fft = z_mag_fft * k_x_fft - x_mag_fft * k_z_fft + a_z_fft = x_mag_fft * k_y_fft - y_mag_fft * k_x_fft + a_x = fft.irfftn(a_x_fft)[slice_B] + a_y = fft.irfftn(a_y_fft)[slice_B] + a_z = fft.irfftn(a_z_fft)[slice_B] + # Return A-field: + return VectorData(magdata.a, np.asarray((a_x, a_y, a_z))) + + +def convert_A_to_B(a_data): + """Convert a vector potential `A` into a B-field distribution. + + Parameters + ---------- + a_data: :class:`~pyramid.magdata.VectorData` object + The vector potential field from which the A-field is calculated. + + Returns + ------- + b_data: :class:`~pyramid.magdata.VectorData` object + The calculated B-field. + + """ + _log.debug('Calling convert_A_to_B') + assert isinstance(a_data, VectorData), 'Only VectorData objects can be mapped!' + # + axis = tuple([i for i in range(3) if a_data.dim[i] > 1]) + # + x_grads = np.gradient(a_data.field[0, ...], axis=axis) #/ a_data.a + y_grads = np.gradient(a_data.field[1, ...], axis=axis) #/ a_data.a + z_grads = np.gradient(a_data.field[2, ...], axis=axis) #/ a_data.a + # + x_gradii = np.zeros(a_data.shape) + y_gradii = np.zeros(a_data.shape) + z_gradii = np.zeros(a_data.shape) + # + for i, axis in enumerate(axis): + x_gradii[axis] = x_grads[i] + y_gradii[axis] = y_grads[i] + z_gradii[axis] = z_grads[i] + # + x_grad_z, x_grad_y, x_grad_x = x_gradii + y_grad_z, y_grad_y, y_grad_x = y_gradii + z_grad_z, z_grad_y, z_grad_x = z_gradii + # Calculate cross product: + b_x = (z_grad_y - y_grad_z) + b_y = (x_grad_z - z_grad_x) + b_z = (y_grad_x - x_grad_y) + # Return B-field: + return VectorData(a_data.a, np.asarray((b_x, b_y, b_z))) + + + + + # Calculate gradients: + x_mag, y_mag, z_mag = a_data.field + x_grad_z, x_grad_y, x_grad_x = np.gradient(x_mag) + y_grad_z, y_grad_y, y_grad_x = np.gradient(y_mag) + z_grad_z, z_grad_y, z_grad_x = np.gradient(z_mag) + # Calculate cross product: + b_x = (z_grad_y - y_grad_z) + b_y = (x_grad_z - z_grad_x) + b_z = (y_grad_x - x_grad_y) + # Return B-field: + return VectorData(a_data.a, np.asarray((b_x, b_y, b_z))) + + +def convert_M_to_B(magdata, b_0=1.0): + """Convert a magnetic vector distribution into a B-field distribution. + + Parameters + ---------- + magdata: :class:`~pyramid.magdata.VectorData` object + The magnetic vector field from which the B-field is calculated. + b_0: float, optional + The saturation magnetization which is used in the calculation. + + Returns + ------- + b_data: :class:`~pyramid.magdata.VectorData` object + The calculated B-field. + + """ + _log.debug('Calling convert_M_to_B') + assert isinstance(magdata, VectorData), 'Only VectorData objects can be mapped!' + return convert_A_to_B(convert_M_to_A(magdata, b_0=b_0)) diff --git a/pyramid/fielddata.py b/pyramid/fielddata.py index c9ee0127413c2a7605a3366cfa87f80c45cdb285..d4a57ce1765b7df1adf82f461856d476fb0bf971 100644 --- a/pyramid/fielddata.py +++ b/pyramid/fielddata.py @@ -1,1442 +1,1473 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""This module provides classes for storing vector and scalar 3D-field.""" - -import logging - -import os - -import tempfile - -import abc -from numbers import Number - -import numpy as np - -from matplotlib import pyplot as plt -from matplotlib.colors import ListedColormap -from matplotlib import patheffects - -from PIL import Image - -from scipy.ndimage.interpolation import zoom - -from . import colors -from . import plottools - -__all__ = ['VectorData', 'ScalarData'] - - -class FieldData(object, metaclass=abc.ABCMeta): - """Class for storing field data. - - Abstract base class for the representatio of magnetic or electric fields (see subclasses). - Fields can be accessed as 3D numpy arrays via the `field` property or as a vector via - `field_vec`. :class:`~.FieldData` objects support negation, arithmetic operators - (``+``, ``-``, ``*``) and their augmented counterparts (``+=``, ``-=``, ``*=``), with numbers - and other :class:`~.FieldData` objects of the same subclass, if their dimensions and grid - spacings match. It is possible to load data from HDF5 or LLG (.txt) files or to save the data - in these formats. Specialised plotting methods are also provided. - - Attributes - ---------- - a: float - The grid spacing in nm. - field: :class:`~numpy.ndarray` (N=4) - The field distribution for every 3D-gridpoint. - - """ - - _log = logging.getLogger(__name__ + '.FieldData') - - @property - def a(self): - """The grid spacing in nm.""" - return self._a - - @a.setter - def a(self, a): - assert isinstance(a, Number), 'Grid spacing has to be a number!' - assert a >= 0, 'Grid spacing has to be a positive number!' - self._a = float(a) - - @property - def shape(self): - """The shape of the `field` (3D for scalar, 4D vor vector field).""" - return self.field.shape - - @property - def dim(self): - """Dimensions (z, y, x) of the grid, only 3D coordinates, without components if present.""" - return self.shape[-3:] - - @property - def field(self): - """The field strength for every 3D-gridpoint (scalar: 3D, vector: 4D).""" - return self._field - - @field.setter - def field(self, field): - assert isinstance(field, np.ndarray), 'Field has to be a numpy array!' - assert 3 <= len(field.shape) <= 4, 'Field has to be 3- or 4-dimensional (scalar / vector)!' - if len(field.shape) == 4: - assert field.shape[0] == 3, 'A vector field has to have exactly 3 components!' - self._field = field - - @property - def field_amp(self): - """The field amplitude (returns the field itself for scalar and the vector amplitude - calculated via a square sum for a vector field.""" - if len(self.shape) == 4: - return np.sqrt(np.sum(self.field ** 2, axis=0)) - else: - return self.field - - @property - def field_vec(self): - """Vector containing the vector field distribution.""" - return np.reshape(self.field, -1) - - @field_vec.setter - def field_vec(self, mag_vec): - assert np.size(mag_vec) == np.prod(self.shape), \ - 'Vector has to match field shape! {} {}'.format(mag_vec.shape, np.prod(self.shape)) - self.field = mag_vec.reshape((3,) + self.dim) - - def __init__(self, a, field): - self._log.debug('Calling __init__') - self.a = a - self.field = field - self._log.debug('Created ' + str(self)) - - def __repr__(self): - self._log.debug('Calling __repr__') - return '%s(a=%r, field=%r)' % (self.__class__, self.a, self.field) - - def __str__(self): - self._log.debug('Calling __str__') - return '%s(a=%s, dim=%s)' % (self.__class__, self.a, self.dim) - - def __neg__(self): # -self - self._log.debug('Calling __neg__') - return self.__class__(self.a, -self.field) - - def __add__(self, other): # self + other - self._log.debug('Calling __add__') - assert isinstance(other, (FieldData, Number)), \ - 'Only FieldData objects and scalar numbers (as offsets) can be added/subtracted!' - if isinstance(other, Number): # other is a Number - self._log.debug('Adding an offset') - return self.__class__(self.a, self.field + other) - elif isinstance(other, FieldData): - self._log.debug('Adding two FieldData objects') - assert other.a == self.a, 'Added phase has to have the same grid spacing!' - assert other.shape == self.shape, 'Added field has to have the same dimensions!' - return self.__class__(self.a, self.field + other.field) - - def __sub__(self, other): # self - other - self._log.debug('Calling __sub__') - return self.__add__(-other) - - def __mul__(self, other): # self * other - self._log.debug('Calling __mul__') - assert isinstance(other, Number), 'FieldData objects can only be multiplied by numbers!' - return self.__class__(self.a, self.field * other) - - def __truediv__(self, other): # self / other - self._log.debug('Calling __truediv__') - assert isinstance(other, Number), 'FieldData objects can only be divided by numbers!' - return self.__class__(self.a, self.field / other) - - def __floordiv__(self, other): # self // other - self._log.debug('Calling __floordiv__') - assert isinstance(other, Number), 'FieldData objects can only be divided by numbers!' - return self.__class__(self.a, self.field // other) - - def __radd__(self, other): # other + self - self._log.debug('Calling __radd__') - return self.__add__(other) - - def __rsub__(self, other): # other - self - self._log.debug('Calling __rsub__') - return -self.__sub__(other) - - def __rmul__(self, other): # other * self - self._log.debug('Calling __rmul__') - return self.__mul__(other) - - def __iadd__(self, other): # self += other - self._log.debug('Calling __iadd__') - return self.__add__(other) - - def __isub__(self, other): # self -= other - self._log.debug('Calling __isub__') - return self.__sub__(other) - - def __imul__(self, other): # self *= other - self._log.debug('Calling __imul__') - return self.__mul__(other) - - def __itruediv__(self, other): # self /= other - self._log.debug('Calling __itruediv__') - return self.__truediv__(other) - - def __ifloordiv__(self, other): # self //= other - self._log.debug('Calling __ifloordiv__') - return self.__floordiv__(other) - - def __getitem__(self, item): - return self.__class__(self.a, self.field[item]) - - def __array__(self, dtype=None): # Used for numpy ufuncs, together with __array_wrap__! - if dtype: - return self.field.astype(dtype) - else: - return self.field - - def __array_wrap__(self, array, _=None): # _ catches the context, which is not used. - return type(self)(self.a, array) - - def copy(self): - """Returns a copy of the :class:`~.FieldData` object - - Returns - ------- - field_data: :class:`~.FieldData` - A copy of the :class:`~.FieldData`. - - """ - self._log.debug('Calling copy') - return self.__class__(self.a, self.field.copy()) - - def get_mask(self, threshold=0): - """Mask all pixels where the amplitude of the field lies above `threshold`. - - Parameters - ---------- - threshold : float, optional - A pixel only gets masked, if it lies above this threshold . The default is 0. - - Returns - ------- - mask : :class:`~numpy.ndarray` (N=3, boolean) - Mask of the pixels where the amplitude of the field lies above `threshold`. - - """ - self._log.debug('Calling get_mask') - return np.where(self.field_amp > threshold, True, False) - - def plot_mask(self, title='Mask', threshold=0, **kwargs): - """Plot the mask as a 3D-contour plot. - - Parameters - ---------- - title: string, optional - The title for the plot. - threshold : float, optional - A pixel only gets masked, if it lies above this threshold . The default is 0. - - Returns - ------- - plot : :class:`mayavi.modules.vectors.Vectors` - The plot object. - - """ - self._log.debug('Calling plot_mask') - from mayavi import mlab - mlab.figure(size=(750, 700)) - zzz, yyy, xxx = (np.indices(self.dim) + self.a / 2) - zzz, yyy, xxx = zzz.T, yyy.T, xxx.T - mask = self.get_mask(threshold=threshold).astype(int).T # Transpose because of VTK order! - extent = np.ravel(list(zip((0, 0, 0), mask.shape))) - cont = mlab.contour3d(xxx, yyy, zzz, mask, contours=[1], **kwargs) - mlab.outline(cont, extent=extent) - mlab.axes(cont, extent=extent) - mlab.title(title, height=0.95, size=0.35) - mlab.orientation_axes() - cont.scene.isometric_view() - return cont - - def plot_contour3d(self, title='Field Distribution', contours=10, opacity=0.25, **kwargs): - """Plot the field as a 3D-contour plot. - - Parameters - ---------- - title: string, optional - The title for the plot. - contours: int, optional - Number of contours which should be plotted. - opacity: float, optional - Defines the opacity of the contours. Default is 0.25. - - Returns - ------- - plot : :class:`mayavi.modules.vectors.Vectors` - The plot object. - - """ - self._log.debug('Calling plot_contour3d') - from mayavi import mlab - mlab.figure(size=(750, 700)) - zzz, yyy, xxx = (np.indices(self.dim) + self.a / 2) - zzz, yyy, xxx = zzz.T, yyy.T, xxx.T - field_amp = self.field_amp.T # Transpose because of VTK order! - if not isinstance(contours, (list, tuple, np.ndarray)): # Calculate the contours: - contours = list(np.linspace(field_amp.min(), field_amp.max(), contours)) - extent = np.ravel(list(zip((0, 0, 0), field_amp.shape))) - cont = mlab.contour3d(xxx, yyy, zzz, field_amp, contours=contours, - opacity=opacity, **kwargs) - mlab.outline(cont, extent=extent) - mlab.axes(cont, extent=extent) - mlab.title(title, height=0.95, size=0.35) - mlab.orientation_axes() - cont.scene.isometric_view() - return cont - - @abc.abstractmethod - def scale_down(self, n): - """Scale down the field distribution by averaging over two pixels along each axis. - - Parameters - ---------- - n : int, optional - Number of times the field distribution is scaled down. The default is 1. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions and grid spacing accordingly. - Only possible, if each axis length is a power of 2! - - """ - pass - - @abc.abstractmethod - def scale_up(self, n, order): - """Scale up the field distribution using spline interpolation of the requested order. - - Parameters - ---------- - n : int, optional - Power of 2 with which the grid is scaled. Default is 1, which means every axis is - increased by a factor of ``2**1 = 2``. - order : int, optional - The order of the spline interpolation, which has to be in the range between 0 and 5 - and defaults to 0. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions and grid spacing accordingly. - """ - pass - - @abc.abstractmethod - def get_vector(self, mask): - """Returns the field as a vector, specified by a mask. - - Parameters - ---------- - mask : :class:`~numpy.ndarray` (N=3, boolean) - Masks the pixels from which the entries should be taken. - - Returns - ------- - vector : :class:`~numpy.ndarray` (N=1) - The vector containing the field of the specified pixels. - - """ - pass - - @abc.abstractmethod - def set_vector(self, vector, mask): - """Set the field of the masked pixels to the values specified by `vector`. - - Parameters - ---------- - mask : :class:`~numpy.ndarray` (N=3, boolean), optional - Masks the pixels from which the field should be taken. - vector : :class:`~numpy.ndarray` (N=1) - The vector containing the field of the specified pixels. - - Returns - ------- - None - - """ - pass - - @classmethod - def from_signal(cls, signal): - """Convert a :class:`~hyperspy.signals.Signal` object to a :class:`~.FieldData` object. - - Parameters - ---------- - signal: :class:`~hyperspy.signals.Signal` - The :class:`~hyperspy.signals.Signal` object which should be converted to FieldData. - - Returns - ------- - magdata: :class:`~.FieldData` - A :class:`~.FieldData` object containing the loaded data. - - Notes - ----- - This method recquires the hyperspy package! - - """ - cls._log.debug('Calling from_signal') - return cls(signal.axes_manager[0].scale, signal.data) - - @abc.abstractmethod - def to_signal(self): - """Convert :class:`~.FieldData` data into a HyperSpy signal. - - Returns - ------- - signal: :class:`~hyperspy.signals.Signal` - Representation of the :class:`~.FieldData` object as a HyperSpy Signal. - - Notes - ----- - This method recquires the hyperspy package! - - """ - self._log.debug('Calling to_signal') - try: # Try importing HyperSpy: - # noinspection PyUnresolvedReferences - import hyperspy.api as hs - except ImportError: - self._log.error('This method recquires the hyperspy package!') - return - # Create signal: - signal = hs.signals.BaseSignal(self.field) # All axes are signal axes! - # Set axes: - signal.axes_manager[0].name = 'x-axis' - signal.axes_manager[0].units = 'nm' - signal.axes_manager[0].scale = self.a - signal.axes_manager[1].name = 'y-axis' - signal.axes_manager[1].units = 'nm' - signal.axes_manager[1].scale = self.a - signal.axes_manager[2].name = 'z-axis' - signal.axes_manager[2].units = 'nm' - signal.axes_manager[2].scale = self.a - return signal - - -class VectorData(FieldData): - - """Class for storing vector ield data. - - Represents 3-dimensional vector field distributions with 3 components which are stored as a - 3-dimensional numpy array in `field`, but which can also be accessed as a vector via - `field_vec`. :class:`~.VectorData` objects support negation, arithmetic operators - (``+``, ``-``, ``*``) and their augmented counterparts (``+=``, ``-=``, ``*=``), withnumbers - and other :class:`~.VectorData` objects, if their dimensions and grid spacings match. It is - possible to load data from HDF5 or LLG (.txt) files or to save the data in these formats. - Plotting methods are also provided. - - Attributes - ---------- - a: float - The grid spacing in nm. - field: :class:`~numpy.ndarray` (N=4) - The `x`-, `y`- and `z`-component of the vector field for every 3D-gridpoint - as a 4-dimensional numpy array (first dimension has to be 3, because of the 3 components). - - """ - _log = logging.getLogger(__name__ + '.VectorData') - - def scale_down(self, n=1): - """Scale down the field distribution by averaging over two pixels along each axis. - - Parameters - ---------- - n : int, optional - Number of times the field distribution is scaled down. The default is 1. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions and grid spacing accordingly. - Only possible, if each axis length is a power of 2! - - """ - self._log.debug('Calling scale_down') - assert n > 0 and isinstance(n, int), 'n must be a positive integer!' - self.a *= 2 ** n - for t in range(n): - # Pad if necessary: - pz, py, px = self.dim[0] % 2, self.dim[1] % 2, self.dim[2] % 2 - if pz != 0 or py != 0 or px != 0: - self.field = np.pad(self.field, ((0, 0), (0, pz), (0, py), (0, px)), - mode='constant') - # Create coarser grid for the vector field: - shape_4d = (3, self.dim[0] // 2, 2, self.dim[1] // 2, 2, self.dim[2] // 2, 2) - self.field = self.field.reshape(shape_4d).mean(axis=(6, 4, 2)) - - def scale_up(self, n=1, order=0): - """Scale up the field distribution using spline interpolation of the requested order. - - Parameters - ---------- - n : int, optional - Power of 2 with which the grid is scaled. Default is 1, which means every axis is - increased by a factor of ``2**1 = 2``. - order : int, optional - The order of the spline interpolation, which has to be in the range between 0 and 5 - and defaults to 0. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions and grid spacing accordingly. - """ - self._log.debug('Calling scale_up') - assert n > 0 and isinstance(n, int), 'n must be a positive integer!' - assert 5 > order >= 0 and isinstance(order, int), \ - 'order must be a positive integer between 0 and 5!' - self.a /= 2 ** n - self.field = np.array((zoom(self.field[0], zoom=2 ** n, order=order), - zoom(self.field[1], zoom=2 ** n, order=order), - zoom(self.field[2], zoom=2 ** n, order=order))) - - def pad(self, pad_values): - """Pad the current field distribution with zeros for each individual axis. - - Parameters - ---------- - pad_values : tuple of int - Number of zeros which should be padded. Provided as a tuple where each entry - corresponds to an axis. An entry can be one int (same padding for both sides) or again - a tuple which specifies the pad values for both sides of the corresponding axis. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions accordingly. - """ - self._log.debug('Calling pad') - assert len(pad_values) == 3, 'Pad values for each dimension have to be provided!' - pv = np.zeros(6, dtype=np.int) - for i, values in enumerate(pad_values): - assert np.shape(values) in [(), (2,)], 'Only one or two values per axis can be given!' - pv[2 * i:2 * (i + 1)] = values - self.field = np.pad(self.field, ((0, 0), (pv[0], pv[1]), (pv[2], pv[3]), (pv[4], pv[5])), - mode='constant') - - def crop(self, crop_values): - """Crop the current field distribution with zeros for each individual axis. - - Parameters - ---------- - crop_values : tuple of int - Number of zeros which should be cropped. Provided as a tuple where each entry - corresponds to an axis. An entry can be one int (same cropping for both sides) or again - a tuple which specifies the crop values for both sides of the corresponding axis. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions accordingly. - """ - self._log.debug('Calling crop') - assert len(crop_values) == 3, 'Crop values for each dimension have to be provided!' - cv = np.zeros(6, dtype=np.int) - for i, values in enumerate(crop_values): - assert np.shape(values) in [(), (2,)], 'Only one or two values per axis can be given!' - cv[2 * i:2 * (i + 1)] = values - cv *= np.resize([1, -1], len(cv)) - cv = np.where(cv == 0, None, cv) - self.field = self.field[:, cv[0]:cv[1], cv[2]:cv[3], cv[4]:cv[5]] - - def get_vector(self, mask): - """Returns the vector field components arranged in a vector, specified by a mask. - - Parameters - ---------- - mask : :class:`~numpy.ndarray` (N=3, boolean) - Masks the pixels from which the components should be taken. - - Returns - ------- - vector : :class:`~numpy.ndarray` (N=1) - The vector containing vector field components of the specified pixels. - Order is: first all `x`-, then all `y`-, then all `z`-components. - - """ - self._log.debug('Calling get_vector') - if mask is not None: - return np.reshape([self.field[0][mask], - self.field[1][mask], - self.field[2][mask]], -1) - else: - return self.field_vec - - def set_vector(self, vector, mask=None): - """Set the field components of the masked pixels to the values specified by `vector`. - - Parameters - ---------- - mask : :class:`~numpy.ndarray` (N=3, boolean), optional - Masks the pixels from which the components should be taken. - vector : :class:`~numpy.ndarray` (N=1) - The vector containing vector field components of the specified pixels. - Order is: first all `x`-, then all `y-, then all `z`-components. - - Returns - ------- - None - - """ - self._log.debug('Calling set_vector') - assert np.size(vector) % 3 == 0, 'Vector has to contain all 3 components for every pixel!' - count = np.size(vector) // 3 - if mask is not None: - self.field[0][mask] = vector[:count] # x-component - self.field[1][mask] = vector[count:2 * count] # y-component - self.field[2][mask] = vector[2 * count:] # z-component - else: - self.field_vec = vector - - def flip(self, axis='x'): - """Flip/mirror the vector field around the specified axis. - - Parameters - ---------- - axis: {'x', 'y', 'z'}, optional - The axis around which the vector field is flipped. - - Returns - ------- - magdata_flip: :class:`~.VectorData` - A flipped copy of the :class:`~.VectorData` object. - - """ - self._log.debug('Calling flip') - if axis == 'x': - mag_x, mag_y, mag_z = self.field[:, :, :, ::-1] - field_flip = np.array((-mag_x, mag_y, mag_z)) - elif axis == 'y': - mag_x, mag_y, mag_z = self.field[:, :, ::-1, :] - field_flip = np.array((mag_x, -mag_y, mag_z)) - elif axis == 'z': - mag_x, mag_y, mag_z = self.field[:, ::-1, :, :] - field_flip = np.array((mag_x, mag_y, -mag_z)) - else: - raise ValueError("Wrong input! 'x', 'y', 'z' allowed!") - return VectorData(self.a, field_flip) - - def rot90(self, axis='x'): - """Rotate the vector field 90° around the specified axis (right hand rotation). - - Parameters - ---------- - axis: {'x', 'y', 'z'}, optional - The axis around which the vector field is rotated. - - Returns - ------- - magdata_rot: :class:`~.VectorData` - A rotated copy of the :class:`~.VectorData` object. - - """ - self._log.debug('Calling rot90') - if axis == 'x': - field_rot = np.zeros((3, self.dim[1], self.dim[0], self.dim[2])) - for i in range(self.dim[2]): - mag_x, mag_y, mag_z = self.field[:, :, :, i] - mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z) - field_rot[:, :, :, i] = np.array((mag_xrot, mag_zrot, -mag_yrot)) - elif axis == 'y': - field_rot = np.zeros((3, self.dim[2], self.dim[1], self.dim[0])) - for i in range(self.dim[1]): - mag_x, mag_y, mag_z = self.field[:, :, i, :] - mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z) - field_rot[:, :, i, :] = np.array((mag_zrot, mag_yrot, -mag_xrot)) - elif axis == 'z': - field_rot = np.zeros((3, self.dim[0], self.dim[2], self.dim[1])) - for i in range(self.dim[0]): - mag_x, mag_y, mag_z = self.field[:, i, :, :] - mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z) - field_rot[:, i, :, :] = np.array((mag_yrot, -mag_xrot, mag_zrot)) - else: - raise ValueError("Wrong input! 'x', 'y', 'z' allowed!") - return VectorData(self.a, field_rot) - - def get_slice(self, ax_slice=None, proj_axis='z'): - """Extract a slice from the :class:`~.VectorData` object. - - Parameters - ---------- - proj_axis : {'z', 'y', 'x'}, optional - The axis, from which the slice is taken. The default is 'z'. - ax_slice : None or int, optional - The slice-index of the axis specified in `proj_axis`. Defaults to the center slice. - - Returns - ------- - u_mag, v_mag, w_mag, submask : :class:`~numpy.ndarray` (N=2) - The extracted vector field components in plane perpendicular to the `proj_axis` and - the perpendicular component. - - """ - self._log.debug('Calling get_slice') - # Find slice: - assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ - 'Axis has to be x, y or z (as string).' - if ax_slice is None: - ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 - if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice - self._log.debug('proj_axis == z') - u_mag = np.copy(self.field[0][ax_slice, ...]) # x-component - v_mag = np.copy(self.field[1][ax_slice, ...]) # y-component - w_mag = np.copy(self.field[2][ax_slice, ...]) # z-component - elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice - self._log.debug('proj_axis == y') - u_mag = np.copy(self.field[0][:, ax_slice, :]) # x-component - v_mag = np.copy(self.field[2][:, ax_slice, :]) # z-component - w_mag = np.copy(self.field[1][:, ax_slice, :]) # y-component - elif proj_axis == 'x': # Slice of the zy-plane with x = ax_slice - self._log.debug('proj_axis == x') - u_mag = np.swapaxes(np.copy(self.field[2][..., ax_slice]), 0, 1) # z-component - v_mag = np.swapaxes(np.copy(self.field[1][..., ax_slice]), 0, 1) # y-component - w_mag = np.swapaxes(np.copy(self.field[0][..., ax_slice]), 0, 1) # x-component - else: - raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) - return u_mag, v_mag, w_mag - - def to_signal(self): - """Convert :class:`~.VectorData` data into a HyperSpy signal. - - Returns - ------- - signal: :class:`~hyperspy.signals.Signal` - Representation of the :class:`~.VectorData` object as a HyperSpy Signal. - - Notes - ----- - This method recquires the hyperspy package! - - """ - self._log.debug('Calling to_signal') - signal = super().to_signal() - # Set component axis: - signal.axes_manager[3].name = 'x/y/z-component' - signal.axes_manager[3].units = '' - # Set metadata: - signal.metadata.Signal.title = 'VectorData' - # Return signal: - return signal - - def save(self, filename, **kwargs): - """Saves the VectorData in the specified format. - - The function gets the format from the extension: - - hdf5 for HDF5. - - EMD Electron Microscopy Dataset format (also HDF5). - - llg format. - - ovf format. - - npy or npz for numpy formats. - - If no extension is provided, 'hdf5' is used. Most formats are - saved with the HyperSpy package (internally the fielddata is first - converted to a HyperSpy Signal. - - Each format accepts a different set of parameters. For details - see the specific format documentation. - - Parameters - ---------- - filename : str, optional - Name of the file which the VectorData is saved into. The extension - determines the saving procedure. - - """ - from .file_io.io_vectordata import save_vectordata - save_vectordata(self, filename, **kwargs) - - def plot_quiver(self, ar_dens=1, log=False, scaled=True, scale=1., b_0=None, # Only used here! - coloring='angle', cmap=None, # Used here and plot_streamlines! - proj_axis='z', ax_slice=None, show_mask=True, bgcolor=None, axis=None, - figsize=None, **kwargs): - """Plot a slice of the vector field as a quiver plot. - - Parameters - ---------- - ar_dens: int, optional - Number defining the arrow density which is plotted. A higher ar_dens number skips more - arrows (a number of 2 plots every second arrow). Default is 1. - log : boolean, optional - The loratihm of the arrow length is plotted instead. This is helpful if only the - direction of the arrows is important and the amplitude varies a lot. Default is False. - scaled : boolean, optional - Normalizes the plotted arrows in respect to the highest one. Default is True. - scale: float, optional - Additional multiplicative factor scaling the arrow length. Default is 1 - (no further scaling). - b_0 : float, optional - Saturation induction (saturation magnetisation times the vacuum permeability). - If this is specified, a quiverkey is used to indicate the length of the longest arrow. - coloring : {'angle', 'amplitude', 'uniform', matplotlib color} - Color coding mode of the arrows. Use 'full' (default), 'angle', 'amplitude', 'uniform' - (black or white, depending on `bgcolor`), or a matplotlib color keyword. - cmap : string, optional - The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. - If not set, an appropriate one is used. Note that a subclass of - :class:`~.colors.Colormap3D` should be used for angle encoding. - proj_axis : {'z', 'y', 'x'}, optional - The axis, from which a slice is plotted. The default is 'z'. - ax_slice : int, optional - The slice-index of the axis specified in `proj_axis`. Is set to the center of - `proj_axis` if not specified. - show_mask: boolean - Default is True. Shows the outlines of the mask slice if available. - bgcolor: {'white', 'black'}, optional - Determines the background color of the plot. - axis : :class:`~matplotlib.axes.AxesSubplot`, optional - Axis on which the graph is plotted. Creates a new figure if none is specified. - figsize : tuple of floats (N=2) - Size of the plot figure. - - Returns - ------- - axis: :class:`~matplotlib.axes.AxesSubplot` - The axis on which the graph is plotted. - - Notes - ----- - Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. - - """ - self._log.debug('Calling plot_quiver') - a = self.a - if figsize is None: - figsize = plottools.FIGSIZE_DEFAULT - assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ - 'Axis has to be x, y or z (as string).' - if ax_slice is None: - ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 - # Extract slice and mask: - u_mag, v_mag = self.get_slice(ax_slice, proj_axis)[:2] - submask = np.where(np.hypot(u_mag, v_mag) > 0, True, False) - # Prepare quiver (select only used arrows if ar_dens is specified): - dim_uv = u_mag.shape - vv, uu = np.indices(dim_uv) + 0.5 # shift to center of pixel - uu = uu[::ar_dens, ::ar_dens] - vv = vv[::ar_dens, ::ar_dens] - u_mag = u_mag[::ar_dens, ::ar_dens] - v_mag = v_mag[::ar_dens, ::ar_dens] - amplitudes = np.hypot(u_mag, v_mag) - angles = np.angle(u_mag + 1j * v_mag, deg=True).tolist() - # Calculate the arrow colors: - if bgcolor is None: - bgcolor = 'white' # Default! - cmap_overwrite = cmap - if coloring == 'angle': - self._log.debug('Encoding angles') - hue = np.asarray(np.arctan2(v_mag, u_mag) / (2 * np.pi)) - hue[hue < 0] += 1 - cmap = colors.CMAP_CIRCULAR_DEFAULT - elif coloring == 'amplitude': - self._log.debug('Encoding amplitude') - hue = amplitudes / amplitudes.max() - if bgcolor == 'white': - cmap = colors.cmaps['cubehelix_reverse'] - else: - cmap = colors.cmaps['cubehelix_standard'] - elif coloring == 'uniform': - self._log.debug('Automatic uniform color encoding') - hue = amplitudes / amplitudes.max() - if bgcolor == 'white': - cmap = colors.cmaps['transparent_black'] - else: - cmap = colors.cmaps['transparent_white'] - else: - self._log.debug('Specified uniform color encoding') - hue = np.zeros_like(u_mag) - cmap = ListedColormap([coloring]) - if cmap_overwrite is not None: - cmap = cmap_overwrite - # If no axis is specified, a new figure is created: - if axis is None: - self._log.debug('axis is None') - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1) - tight = True - else: - tight = False - axis.set_aspect('equal') - # Take the logarithm of the arrows to clearly show directions (if specified): - if log and np.any(amplitudes): # If the slice is empty, skip! - cutoff = 10 - amp = np.round(amplitudes, decimals=cutoff) - min_value = amp[np.nonzero(amp)].min() - u_mag = np.round(u_mag, decimals=cutoff) / min_value - u_mag = np.log10(np.abs(u_mag) + 1) * np.sign(u_mag) - v_mag = np.round(v_mag, decimals=cutoff) / min_value - v_mag = np.log10(np.abs(v_mag) + 1) * np.sign(v_mag) - amplitudes = np.hypot(u_mag, v_mag) # Recalculate (used if scaled)! - # Scale the amplitude of the arrows to the highest one (if specified): - if scaled: - u_mag /= amplitudes.max() + 1E-30 - v_mag /= amplitudes.max() + 1E-30 - # Plot quiver: - quiv = axis.quiver(uu, vv, u_mag, v_mag, hue, cmap=cmap, clim=(0, 1), angles=angles, - pivot='middle', units='xy', scale_units='xy', scale=scale / ar_dens, - minlength=0.05, width=1*ar_dens, headlength=2, headaxislength=2, - headwidth=2, minshaft=2) - axis.set_xlim(0, dim_uv[1]) - axis.set_ylim(0, dim_uv[0]) - # Determine colormap if necessary: - if coloring == 'amplitude': - cbar_mappable, cbar_label = quiv, 'amplitude' - else: - cbar_mappable, cbar_label = None, None - # Change background color: - axis.set_axis_bgcolor(bgcolor) - # Show mask: - if show_mask and not np.all(submask): # Plot mask if desired and not trivial! - vv, uu = np.indices(dim_uv) + 0.5 # shift to center of pixel - mask_color = 'white' if bgcolor == 'black' else 'black' - axis.contour(uu, vv, submask, levels=[0.5], colors=mask_color, - linestyles='dotted', linewidths=2) - # Plot quiverkey if B_0 is specified): - if b_0 and not log: # The angles needed for log would break the quiverkey! - label = '{:.3g} T'.format(amplitudes.max() * b_0) - quiv.angles = 'uv' # With a list of angles, the quiverkey would break! - stroke = plottools.STROKE_DEFAULT - txtcolor = 'w' if stroke == 'k' else 'k' - edgecolor = stroke if stroke is not None else 'none' - fontsize = kwargs.get('fontsize', None) - if fontsize is None: - fontsize = plottools.FONTSIZE_DEFAULT - qk = plt.quiverkey(Q=quiv, X=0.88, Y=0.065, U=1, label=label, labelpos='W', - coordinates='axes', facecolor=txtcolor, edgecolor=edgecolor, - labelcolor=txtcolor, linewidth=0.5, - clip_box=axis.bbox, clip_on=True, - fontproperties={'size': kwargs.get('fontsize', fontsize)}) - if stroke is not None: - qk.text.set_path_effects( - [patheffects.withStroke(linewidth=2, foreground=stroke)]) - # Return formatted axis: - return plottools.format_axis(axis, sampling=a, cbar_mappable=cbar_mappable, - cbar_label=cbar_label, tight_layout=tight, **kwargs) - - def plot_field(self, proj_axis='z', ax_slice=None, show_mask=True, bgcolor=None, axis=None, - figsize=None, **kwargs): - """Plot a slice of the vector field as a color field imshow plot. - - Parameters - ---------- - proj_axis : {'z', 'y', 'x'}, optional - The axis, from which a slice is plotted. The default is 'z'. - ax_slice : int, optional - The slice-index of the axis specified in `proj_axis`. Is set to the center of - `proj_axis` if not specified. - show_mask: boolean - Default is True. Shows the outlines of the mask slice if available. - bgcolor: {'white', 'black'}, optional - Determines the background color of the plot. - axis : :class:`~matplotlib.axes.AxesSubplot`, optional - Axis on which the graph is plotted. Creates a new figure if none is specified. - figsize : tuple of floats (N=2) - Size of the plot figure. - - Returns - ------- - axis: :class:`~matplotlib.axes.AxesSubplot` - The axis on which the graph is plotted. - - Notes - ----- - Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. - - """ - self._log.debug('Calling plot_field') - a = self.a - if figsize is None: - figsize = plottools.FIGSIZE_DEFAULT - assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ - 'Axis has to be x, y or z (as string).' - if ax_slice is None: - ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 - # Extract slice and mask: - u_mag, v_mag, w_mag = self.get_slice(ax_slice, proj_axis) - submask = np.where(np.hypot(u_mag, v_mag) > 0, True, False) - # If no axis is specified, a new figure is created: - if axis is None: - self._log.debug('axis is None') - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1) - tight = True - else: - tight = False - axis.set_aspect('equal') - # Determine 'z'-component for luminance (keep as gray if None): - z_mag = w_mag - if bgcolor == 'white': - z_mag = np.where(submask, z_mag, np.max(np.hypot(u_mag, v_mag))) - if bgcolor == 'black': - z_mag = np.where(submask, z_mag, -np.max(np.hypot(u_mag, v_mag))) - # Plot the field: - dim_uv = u_mag.shape - rgb = colors.CMAP_CIRCULAR_DEFAULT.rgb_from_vector(np.asarray((u_mag, v_mag, z_mag))) - axis.imshow(Image.fromarray(rgb), origin='lower', interpolation='none', - extent=(0, dim_uv[1], 0, dim_uv[0])) - # Change background color: - if bgcolor is not None: - axis.set_axis_bgcolor(bgcolor) - # Show mask: - if show_mask and not np.all(submask): # Plot mask if desired and not trivial! - vv, uu = np.indices(dim_uv) + 0.5 # shift to center of pixel - mask_color = 'white' if bgcolor == 'black' else 'black' - axis.contour(uu, vv, submask, levels=[0.5], colors=mask_color, - linestyles='dotted', linewidths=2) - # Return formatted axis: - return plottools.format_axis(axis, sampling=a, tight_layout=tight, **kwargs) - - def plot_quiver_field(self, **kwargs): - """Plot the vector field as a field plot with uniformly colored arrows overlayed. - - Parameters - ---------- - See :func:`~.plot_quiver` and :func:`~.plot_quiver` for parameters! - - Returns - ------- - axis: :class:`~matplotlib.axes.AxesSubplot` - The axis on which the graph is plotted. - - """ - # Extract parameters: - show_mask = kwargs.pop('show_mask', True) # Only needed once! - axis = kwargs.pop('axis', None) - # Set default bgcolor to white (only for combined plot), only if bgcolor was not specified: - kwargs.setdefault('bgcolor', 'white') - # Plot field first (with mask and axis formatting), then quiver: - axis = self.plot_field(axis=axis, show_mask=show_mask, **kwargs) - self.plot_quiver(coloring='uniform', show_mask=False, axis=axis, - format_axis=False, **kwargs) - # Return plotting axis: - return axis - - def plot_streamline(self, density=2, linewidth=2, coloring='angle', cmap=None, - proj_axis='z', ax_slice=None, show_mask=True, bgcolor=None, axis=None, - figsize=None, **kwargs): - """Plot a slice of the vector field as a quiver plot. - - Parameters - ---------- - density : float or 2-tuple, optional - Controls the closeness of streamlines. When density = 1, the domain is divided into a - 30x30 grid—density linearly scales this grid. Each cebll in the grid can have, at most, - one traversing streamline. For different densities in each direction, use - [density_x, density_y]. - linewidth : numeric or 2d array, optional - Vary linewidth when given a 2d array with the same shape as velocities. - coloring : {'angle', 'amplitude', 'uniform'} - Color coding mode of the arrows. Use 'full' (default), 'angle', 'amplitude' or - 'uniform'. - cmap : string, optional - The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. - If not set, an appropriate one is used. Note that a subclass of - :class:`~.colors.Colormap3D` should be used for angle encoding. - proj_axis : {'z', 'y', 'x'}, optional - The axis, from which a slice is plotted. The default is 'z'. - ax_slice : int, optional - The slice-index of the axis specified in `proj_axis`. Is set to the center of - `proj_axis` if not specified. - show_mask: boolean - Default is True. Shows the outlines of the mask slice if available. - bgcolor: {'white', 'black'}, optional - Determines the background color of the plot. - axis : :class:`~matplotlib.axes.AxesSubplot`, optional - Axis on which the graph is plotted. Creates a new figure if none is specified. - figsize : tuple of floats (N=2) - Size of the plot figure. - - Returns - ------- - axis: :class:`~matplotlib.axes.AxesSubplot` - The axis on which the graph is plotted. - - Notes - ----- - Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. - - """ - self._log.debug('Calling plot_quiver') - a = self.a - if figsize is None: - figsize = plottools.FIGSIZE_DEFAULT - assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ - 'Axis has to be x, y or z (as string).' - if ax_slice is None: - ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 - u_mag, v_mag = self.get_slice(ax_slice, proj_axis)[:2] - submask = np.where(np.hypot(u_mag, v_mag) > 0, True, False) - # Prepare streamlines: - dim_uv = u_mag.shape - uu = np.arange(dim_uv[1]) + 0.5 # shift to center of pixel - vv = np.arange(dim_uv[0]) + 0.5 # shift to center of pixel - u_mag, v_mag = self.get_slice(ax_slice, proj_axis)[:2] - # v_mag = np.ma.array(v_mag, mask=submask) - amplitudes = np.hypot(u_mag, v_mag) - # Calculate the arrow colors: - if bgcolor is None: - bgcolor = 'white' # Default! - cmap_overwrite = cmap - if coloring == 'angle': - self._log.debug('Encoding angles') - hue = np.asarray(np.arctan2(v_mag, u_mag) / (2 * np.pi)) - hue[hue < 0] += 1 - cmap = colors.CMAP_CIRCULAR_DEFAULT - elif coloring == 'amplitude': - self._log.debug('Encoding amplitude') - hue = amplitudes / amplitudes.max() - if bgcolor == 'white': - cmap = colors.cmaps['cubehelix_reverse'] - else: - cmap = colors.cmaps['cubehelix_standard'] - elif coloring == 'uniform': - self._log.debug('Automatic uniform color encoding') - hue = amplitudes / amplitudes.max() - if bgcolor == 'white': - cmap = colors.cmaps['transparent_black'] - else: - cmap = colors.cmaps['transparent_white'] - else: - self._log.debug('Specified uniform color encoding') - hue = np.zeros_like(u_mag) - cmap = ListedColormap([coloring]) - if cmap_overwrite is not None: - cmap = cmap_overwrite - # If no axis is specified, a new figure is created: - if axis is None: - self._log.debug('axis is None') - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1) - tight = True - else: - tight = False - axis.set_aspect('equal') - # Plot the streamlines: - im = plt.streamplot(uu, vv, u_mag, v_mag, density=density, linewidth=linewidth, - color=hue, cmap=cmap) - # Determine colormap if necessary: - if coloring == 'amplitude': - cbar_mappable, cbar_label = im, 'amplitude' - else: - cbar_mappable, cbar_label = None, None - # Change background color: - axis.set_axis_bgcolor(bgcolor) - # Show mask: - if show_mask and not np.all(submask): # Plot mask if desired and not trivial! - vv, uu = np.indices(dim_uv) + 0.5 # shift to center of pixel - mask_color = 'white' if bgcolor == 'black' else 'black' - axis.contour(uu, vv, submask, levels=[0.5], colors=mask_color, - linestyles='dotted', linewidths=2) - # Return formatted axis: - return plottools.format_axis(axis, sampling=a, cbar_mappable=cbar_mappable, - cbar_label=cbar_label, tight_layout=tight, **kwargs) - - def plot_quiver3d(self, title='Vector Field', limit=None, cmap='jet', mode='2darrow', - coloring='angle', ar_dens=1, opacity=1.0, grid=True, labels=True, - figsize=None): - """Plot the vector field as 3D-vectors in a quiverplot. - - Parameters - ---------- - title : string, optional - The title for the plot. - limit : float, optional - Plotlimit for the vector field arrow length used to scale the colormap. - cmap : string, optional - String describing the colormap which is used for amplitude encoding (default is 'jet'). - ar_dens: int, optional - Number defining the arrow density which is plotted. A higher ar_dens number skips more - arrows (a number of 2 plots every second arrow). Default is 1. - mode: string, optional - Mode, determining the glyphs used in the 3D plot. Default is '2darrow', which - corresponds to 2D arrows. For smaller amounts of arrows, 'arrow' (3D) is prettier. - coloring : {'angle', 'amplitude'}, optional - Color coding mode of the arrows. Use 'angle' (default) or 'amplitude'. - opacity: float, optional - Defines the opacity of the arrows. Default is 1.0 (completely opaque). - - Returns - ------- - plot : :class:`mayavi.modules.vectors.Vectors` - The plot object. - - """ - self._log.debug('Calling quiver_plot3D') - from mayavi import mlab - if limit is None: - limit = np.max(np.nan_to_num(self.field_amp)) - ad = ar_dens - # Create points and vector components as lists: - zzz, yyy, xxx = (np.indices(self.dim) + self.a / 2) - zzz = zzz[::ad, ::ad, ::ad].ravel() - yyy = yyy[::ad, ::ad, ::ad].ravel() - xxx = xxx[::ad, ::ad, ::ad].ravel() - x_mag = self.field[0][::ad, ::ad, ::ad].ravel() - y_mag = self.field[1][::ad, ::ad, ::ad].ravel() - z_mag = self.field[2][::ad, ::ad, ::ad].ravel() - # Plot them as vectors: - if figsize is None: - figsize = (750, 700) - mlab.figure(size=figsize, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.)) - extent = np.ravel(list(zip((0, 0, 0), (self.dim[2], self.dim[1], self.dim[0])))) - if coloring == 'angle': # Encodes the full angle via colorwheel and saturation: - self._log.debug('Encoding full 3D angles') - vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag, mode=mode, opacity=opacity, - scalars=np.arange(len(xxx)), line_width=2) - vector = np.asarray((x_mag.ravel(), y_mag.ravel(), z_mag.ravel())) - rgb = colors.CMAP_CIRCULAR_DEFAULT.rgb_from_vector(vector) - rgba = np.hstack((rgb, 255 * np.ones((len(xxx), 1), dtype=np.uint8))) - vecs.glyph.color_mode = 'color_by_scalar' - vecs.module_manager.scalar_lut_manager.lut.table = rgba - mlab.draw() - elif coloring == 'amplitude': # Encodes the amplitude of the arrows with the jet colormap: - self._log.debug('Encoding amplitude') - vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag, - mode=mode, colormap=cmap, opacity=opacity, line_width=2) - mlab.colorbar(label_fmt='%.2f') - mlab.colorbar(orientation='vertical') - else: - raise AttributeError('Coloring mode not supported!') - vecs.glyph.glyph_source.glyph_position = 'center' - vecs.module_manager.vector_lut_manager.data_range = np.array([0, limit]) - if grid: - mlab.outline(vecs, extent=extent) - if labels: - mlab.axes(vecs, extent=extent) - mlab.title(title, height=0.95, size=0.35) - mlab.orientation_axes() - return vecs - - def plot_quiver3d_to_2d(self, dim_uv=None, axis=None, figsize=None, azimuth=45, - elevation=60, distance=420, high_res=False, quiv_kwargs=None, - **kwargs): - from mayavi import mlab - if figsize is None: - figsize = plottools.FIGSIZE_DEFAULT - if axis is None: - self._log.debug('axis is None') - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1) - axis.set_axis_bgcolor('gray') - kwargs.setdefault('labels', 'False') - if quiv_kwargs is None: - quiv_kwargs = {} - mlab.options.offscreen(True) - self.plot_quiver3d(figsize=(800, 800), **quiv_kwargs) - mlab.view(azimuth=azimuth, elevation=elevation, distance=distance) - if high_res: # Use temp files: - tmpdir = tempfile.mkdtemp() - temp_path = os.path.join(tmpdir, 'temp.png') - try: - mlab.savefig(temp_path, size=(2000, 2000)) - imgmap = np.asarray(Image.open(temp_path)) - except Exception as e: - raise e - finally: - os.remove(temp_path) - os.rmdir(tmpdir) - else: # Use screenshot (returns array WITH alpha!): - imgmap = mlab.screenshot(mode='rgba', antialiased=True) - mlab.close(mlab.gcf()) - mlab.options.offscreen(False) - if dim_uv is None: - dim_uv = self.dim[1:] - axis.imshow(imgmap, extent=[0, dim_uv[0], 0, dim_uv[1]], origin='upper') - kwargs.setdefault('scalebar', False) - kwargs.setdefault('hideaxes', True) - return plottools.format_axis(axis, **kwargs) - - -class ScalarData(FieldData): - """Class for storing scalar field data. - - Represents 3-dimensional scalar field distributions which is stored as a 3-dimensional - numpy array in `field`, but which can also be accessed as a vector via `field_vec`. - :class:`~.ScalarData` objects support negation, arithmetic operators (``+``, ``-``, ``*``) - and their augmented counterparts (``+=``, ``-=``, ``*=``), with numbers and other - :class:`~.ScalarData` objects, if their dimensions and grid spacings match. It is possible - to load data from HDF5 or LLG (.txt) files or to save the data in these formats. - Plotting methods are also provided. - - Attributes - ---------- - a: float - The grid spacing in nm. - field: :class:`~numpy.ndarray` (N=4) - The scalar field. - - """ - _log = logging.getLogger(__name__ + '.ScalarData') - - def scale_down(self, n=1): - """Scale down the field distribution by averaging over two pixels along each axis. - - Parameters - ---------- - n : int, optional - Number of times the field distribution is scaled down. The default is 1. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions and grid spacing accordingly. - Only possible, if each axis length is a power of 2! - - """ - self._log.debug('Calling scale_down') - assert n > 0 and isinstance(n, int), 'n must be a positive integer!' - self.a *= 2 ** n - for t in range(n): - # Pad if necessary: - pz, py, px = self.dim[0] % 2, self.dim[1] % 2, self.dim[2] % 2 - if pz != 0 or py != 0 or px != 0: - self.field = np.pad(self.field, ((0, pz), (0, py), (0, px)), mode='constant') - # Create coarser grid for the field: - shape_4d = (self.dim[0] / 2, 2, self.dim[1] / 2, 2, self.dim[2] / 2, 2) - self.field = self.field.reshape(shape_4d).mean(axis=(5, 3, 1)) - - def scale_up(self, n=1, order=0): - """Scale up the field distribution using spline interpolation of the requested order. - - Parameters - ---------- - n : int, optional - Power of 2 with which the grid is scaled. Default is 1, which means every axis is - increased by a factor of ``2**1 = 2``. - order : int, optional - The order of the spline interpolation, which has to be in the range between 0 and 5 - and defaults to 0. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions and grid spacing accordingly. - """ - self._log.debug('Calling scale_up') - assert n > 0 and isinstance(n, int), 'n must be a positive integer!' - assert 5 > order >= 0 and isinstance(order, int), \ - 'order must be a positive integer between 0 and 5!' - self.a /= 2 ** n - self.field = zoom(self.field, zoom=2 ** n, order=order) - - def get_vector(self, mask): - """Returns the field as a vector, specified by a mask. - - Parameters - ---------- - mask : :class:`~numpy.ndarray` (N=3, boolean) - Masks the pixels from which the components should be taken. - - Returns - ------- - vector : :class:`~numpy.ndarray` (N=1) - The vector containing the field of the specified pixels. - - """ - self._log.debug('Calling get_vector') - if mask is not None: - return np.reshape(self.field[mask], -1) - else: - return self.field_vec - - def set_vector(self, vector, mask=None): - """Set the field components of the masked pixels to the values specified by `vector`. - - Parameters - ---------- - mask : :class:`~numpy.ndarray` (N=3, boolean), optional - Masks the pixels from which the components should be taken. - vector : :class:`~numpy.ndarray` (N=1) - The vector containing the field of the specified pixels. - - Returns - ------- - None - - """ - self._log.debug('Calling set_vector') - if mask is not None: - self.field[mask] = vector - else: - self.field_vec = vector - - def to_signal(self): - """Convert :class:`~.ScalarData` data into a HyperSpy signal. - - Returns - ------- - signal: :class:`~hyperspy.signals.Signal` - Representation of the :class:`~.ScalarData` object as a HyperSpy Signal. - - Notes - ----- - This method recquires the hyperspy package! - - """ - self._log.debug('Calling to_signal') - signal = super().to_signal() - # Set metadata: - signal.metadata.Signal.title = 'ScalarData' - # Return signal: - return signal - - def save(self, filename, **kwargs): - """Saves the ScalarData in the specified format. - - The function gets the format from the extension: - - hdf5 for HDF5. - - EMD Electron Microscopy Dataset format (also HDF5). - - npy or npz for numpy formats. - - If no extension is provided, 'hdf5' is used. Most formats are - saved with the HyperSpy package (internally the fielddata is first - converted to a HyperSpy Signal. - - Each format accepts a different set of parameters. For details - see the specific format documentation. - - Parameters - ---------- - filename : str, optional - Name of the file which the ScalarData is saved into. The extension - determines the saving procedure. - - """ - from .file_io.io_scalardata import save_scalardata - save_scalardata(self, filename, **kwargs) +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides classes for storing vector and scalar 3D-field.""" + +import logging + +import os + +import tempfile + +import abc +from numbers import Number + +import numpy as np + +from matplotlib import pyplot as plt +from matplotlib.colors import ListedColormap +from matplotlib import patheffects + +from PIL import Image + +from scipy.ndimage.interpolation import zoom + +from . import colors +from . import plottools + +__all__ = ['VectorData', 'ScalarData'] + + +class FieldData(object, metaclass=abc.ABCMeta): + """Class for storing field data. + + Abstract base class for the representatio of magnetic or electric fields (see subclasses). + Fields can be accessed as 3D numpy arrays via the `field` property or as a vector via + `field_vec`. :class:`~.FieldData` objects support negation, arithmetic operators + (``+``, ``-``, ``*``) and their augmented counterparts (``+=``, ``-=``, ``*=``), with numbers + and other :class:`~.FieldData` objects of the same subclass, if their dimensions and grid + spacings match. It is possible to load data from HDF5 or LLG (.txt) files or to save the data + in these formats. Specialised plotting methods are also provided. + + Attributes + ---------- + a: float + The grid spacing in nm. + field: :class:`~numpy.ndarray` (N=4) + The field distribution for every 3D-gridpoint. + + """ + + _log = logging.getLogger(__name__ + '.FieldData') + + @property + def a(self): + """The grid spacing in nm.""" + return self._a + + @a.setter + def a(self, a): + assert isinstance(a, Number), 'Grid spacing has to be a number!' + assert a >= 0, 'Grid spacing has to be a positive number!' + self._a = float(a) + + @property + def shape(self): + """The shape of the `field` (3D for scalar, 4D vor vector field).""" + return self.field.shape + + @property + def dim(self): + """Dimensions (z, y, x) of the grid, only 3D coordinates, without components if present.""" + return self.shape[-3:] + + @property + def field(self): + """The field strength for every 3D-gridpoint (scalar: 3D, vector: 4D).""" + return self._field + + @field.setter + def field(self, field): + assert isinstance(field, np.ndarray), 'Field has to be a numpy array!' + assert 3 <= len(field.shape) <= 4, 'Field has to be 3- or 4-dimensional (scalar / vector)!' + if len(field.shape) == 4: + assert field.shape[0] == 3, 'A vector field has to have exactly 3 components!' + self._field = field + + @property + def field_amp(self): + """The field amplitude (returns the field itself for scalar and the vector amplitude + calculated via a square sum for a vector field.""" + if len(self.shape) == 4: + return np.sqrt(np.sum(self.field ** 2, axis=0)) + else: + return self.field + + @property + def field_vec(self): + """Vector containing the vector field distribution.""" + return np.reshape(self.field, -1) + + @field_vec.setter + def field_vec(self, mag_vec): + assert np.size(mag_vec) == np.prod(self.shape), \ + 'Vector has to match field shape! {} {}'.format(mag_vec.shape, np.prod(self.shape)) + self.field = mag_vec.reshape((3,) + self.dim) + + def __init__(self, a, field): + self._log.debug('Calling __init__') + self.a = a + self.field = field + self._log.debug('Created ' + str(self)) + + def __repr__(self): + self._log.debug('Calling __repr__') + return '%s(a=%r, field=%r)' % (self.__class__, self.a, self.field) + + def __str__(self): + self._log.debug('Calling __str__') + return '%s(a=%s, dim=%s)' % (self.__class__, self.a, self.dim) + + def __neg__(self): # -self + self._log.debug('Calling __neg__') + return self.__class__(self.a, -self.field) + + def __add__(self, other): # self + other + self._log.debug('Calling __add__') + assert isinstance(other, (FieldData, Number)), \ + 'Only FieldData objects and scalar numbers (as offsets) can be added/subtracted!' + if isinstance(other, Number): # other is a Number + self._log.debug('Adding an offset') + return self.__class__(self.a, self.field + other) + elif isinstance(other, FieldData): + self._log.debug('Adding two FieldData objects') + assert other.a == self.a, 'Added phase has to have the same grid spacing!' + assert other.shape == self.shape, 'Added field has to have the same dimensions!' + return self.__class__(self.a, self.field + other.field) + + def __sub__(self, other): # self - other + self._log.debug('Calling __sub__') + return self.__add__(-other) + + def __mul__(self, other): # self * other + self._log.debug('Calling __mul__') + assert isinstance(other, Number), 'FieldData objects can only be multiplied by numbers!' + return self.__class__(self.a, self.field * other) + + def __truediv__(self, other): # self / other + self._log.debug('Calling __truediv__') + assert isinstance(other, Number), 'FieldData objects can only be divided by numbers!' + return self.__class__(self.a, self.field / other) + + def __floordiv__(self, other): # self // other + self._log.debug('Calling __floordiv__') + assert isinstance(other, Number), 'FieldData objects can only be divided by numbers!' + return self.__class__(self.a, self.field // other) + + def __radd__(self, other): # other + self + self._log.debug('Calling __radd__') + return self.__add__(other) + + def __rsub__(self, other): # other - self + self._log.debug('Calling __rsub__') + return -self.__sub__(other) + + def __rmul__(self, other): # other * self + self._log.debug('Calling __rmul__') + return self.__mul__(other) + + def __iadd__(self, other): # self += other + self._log.debug('Calling __iadd__') + return self.__add__(other) + + def __isub__(self, other): # self -= other + self._log.debug('Calling __isub__') + return self.__sub__(other) + + def __imul__(self, other): # self *= other + self._log.debug('Calling __imul__') + return self.__mul__(other) + + def __itruediv__(self, other): # self /= other + self._log.debug('Calling __itruediv__') + return self.__truediv__(other) + + def __ifloordiv__(self, other): # self //= other + self._log.debug('Calling __ifloordiv__') + return self.__floordiv__(other) + + def __getitem__(self, item): + return self.__class__(self.a, self.field[item]) + + def __array__(self, dtype=None): # Used for numpy ufuncs, together with __array_wrap__! + if dtype: + return self.field.astype(dtype) + else: + return self.field + + def __array_wrap__(self, array, _=None): # _ catches the context, which is not used. + return type(self)(self.a, array) + + def copy(self): + """Returns a copy of the :class:`~.FieldData` object + + Returns + ------- + field_data: :class:`~.FieldData` + A copy of the :class:`~.FieldData`. + + """ + self._log.debug('Calling copy') + return self.__class__(self.a, self.field.copy()) + + def get_mask(self, threshold=0): + """Mask all pixels where the amplitude of the field lies above `threshold`. + + Parameters + ---------- + threshold : float, optional + A pixel only gets masked, if it lies above this threshold . The default is 0. + + Returns + ------- + mask : :class:`~numpy.ndarray` (N=3, boolean) + Mask of the pixels where the amplitude of the field lies above `threshold`. + + """ + self._log.debug('Calling get_mask') + return np.where(self.field_amp > threshold, True, False) + + def plot_mask(self, title='Mask', threshold=0, **kwargs): + """Plot the mask as a 3D-contour plot. + + Parameters + ---------- + title: string, optional + The title for the plot. + threshold : float, optional + A pixel only gets masked, if it lies above this threshold . The default is 0. + + Returns + ------- + plot : :class:`mayavi.modules.vectors.Vectors` + The plot object. + + """ + self._log.debug('Calling plot_mask') + from mayavi import mlab + mlab.figure(size=(750, 700)) + zzz, yyy, xxx = (np.indices(self.dim) + self.a / 2) + zzz, yyy, xxx = zzz.T, yyy.T, xxx.T + mask = self.get_mask(threshold=threshold).astype(int).T # Transpose because of VTK order! + extent = np.ravel(list(zip((0, 0, 0), mask.shape))) + cont = mlab.contour3d(xxx, yyy, zzz, mask, contours=[1], **kwargs) + mlab.outline(cont, extent=extent) + mlab.axes(cont, extent=extent) + mlab.title(title, height=0.95, size=0.35) + mlab.orientation_axes() + cont.scene.isometric_view() + return cont + + def plot_contour3d(self, title='Field Distribution', contours=10, opacity=0.25, **kwargs): + """Plot the field as a 3D-contour plot. + + Parameters + ---------- + title: string, optional + The title for the plot. + contours: int, optional + Number of contours which should be plotted. + opacity: float, optional + Defines the opacity of the contours. Default is 0.25. + + Returns + ------- + plot : :class:`mayavi.modules.vectors.Vectors` + The plot object. + + """ + self._log.debug('Calling plot_contour3d') + from mayavi import mlab + mlab.figure(size=(750, 700)) + zzz, yyy, xxx = (np.indices(self.dim) + self.a / 2) + zzz, yyy, xxx = zzz.T, yyy.T, xxx.T + field_amp = self.field_amp.T # Transpose because of VTK order! + if not isinstance(contours, (list, tuple, np.ndarray)): # Calculate the contours: + contours = list(np.linspace(field_amp.min(), field_amp.max(), contours)) + extent = np.ravel(list(zip((0, 0, 0), field_amp.shape))) + cont = mlab.contour3d(xxx, yyy, zzz, field_amp, contours=contours, + opacity=opacity, **kwargs) + mlab.outline(cont, extent=extent) + mlab.axes(cont, extent=extent) + mlab.title(title, height=0.95, size=0.35) + mlab.orientation_axes() + cont.scene.isometric_view() + return cont + + @abc.abstractmethod + def scale_down(self, n): + """Scale down the field distribution by averaging over two pixels along each axis. + + Parameters + ---------- + n : int, optional + Number of times the field distribution is scaled down. The default is 1. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions and grid spacing accordingly. + Only possible, if each axis length is a power of 2! + + """ + pass + + @abc.abstractmethod + def scale_up(self, n, order): + """Scale up the field distribution using spline interpolation of the requested order. + + Parameters + ---------- + n : int, optional + Power of 2 with which the grid is scaled. Default is 1, which means every axis is + increased by a factor of ``2**1 = 2``. + order : int, optional + The order of the spline interpolation, which has to be in the range between 0 and 5 + and defaults to 0. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions and grid spacing accordingly. + """ + pass + + @abc.abstractmethod + def get_vector(self, mask): + """Returns the field as a vector, specified by a mask. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (N=3, boolean) + Masks the pixels from which the entries should be taken. + + Returns + ------- + vector : :class:`~numpy.ndarray` (N=1) + The vector containing the field of the specified pixels. + + """ + pass + + @abc.abstractmethod + def set_vector(self, vector, mask): + """Set the field of the masked pixels to the values specified by `vector`. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (N=3, boolean), optional + Masks the pixels from which the field should be taken. + vector : :class:`~numpy.ndarray` (N=1) + The vector containing the field of the specified pixels. + + Returns + ------- + None + + """ + pass + + @classmethod + def from_signal(cls, signal): + """Convert a :class:`~hyperspy.signals.Signal` object to a :class:`~.FieldData` object. + + Parameters + ---------- + signal: :class:`~hyperspy.signals.Signal` + The :class:`~hyperspy.signals.Signal` object which should be converted to FieldData. + + Returns + ------- + magdata: :class:`~.FieldData` + A :class:`~.FieldData` object containing the loaded data. + + Notes + ----- + This method recquires the hyperspy package! + + """ + cls._log.debug('Calling from_signal') + return cls(signal.axes_manager[0].scale, signal.data) + + @abc.abstractmethod + def to_signal(self): + """Convert :class:`~.FieldData` data into a HyperSpy signal. + + Returns + ------- + signal: :class:`~hyperspy.signals.Signal` + Representation of the :class:`~.FieldData` object as a HyperSpy Signal. + + Notes + ----- + This method recquires the hyperspy package! + + """ + self._log.debug('Calling to_signal') + try: # Try importing HyperSpy: + # noinspection PyUnresolvedReferences + import hyperspy.api as hs + except ImportError: + self._log.error('This method recquires the hyperspy package!') + return + # Create signal: + signal = hs.signals.BaseSignal(self.field) # All axes are signal axes! + # Set axes: + signal.axes_manager[0].name = 'x-axis' + signal.axes_manager[0].units = 'nm' + signal.axes_manager[0].scale = self.a + signal.axes_manager[1].name = 'y-axis' + signal.axes_manager[1].units = 'nm' + signal.axes_manager[1].scale = self.a + signal.axes_manager[2].name = 'z-axis' + signal.axes_manager[2].units = 'nm' + signal.axes_manager[2].scale = self.a + return signal + + +class VectorData(FieldData): + + """Class for storing vector ield data. + + Represents 3-dimensional vector field distributions with 3 components which are stored as a + 3-dimensional numpy array in `field`, but which can also be accessed as a vector via + `field_vec`. :class:`~.VectorData` objects support negation, arithmetic operators + (``+``, ``-``, ``*``) and their augmented counterparts (``+=``, ``-=``, ``*=``), withnumbers + and other :class:`~.VectorData` objects, if their dimensions and grid spacings match. It is + possible to load data from HDF5 or LLG (.txt) files or to save the data in these formats. + Plotting methods are also provided. + + Attributes + ---------- + a: float + The grid spacing in nm. + field: :class:`~numpy.ndarray` (N=4) + The `x`-, `y`- and `z`-component of the vector field for every 3D-gridpoint + as a 4-dimensional numpy array (first dimension has to be 3, because of the 3 components). + + """ + _log = logging.getLogger(__name__ + '.VectorData') + + def scale_down(self, n=1): + """Scale down the field distribution by averaging over two pixels along each axis. + + Parameters + ---------- + n : int, optional + Number of times the field distribution is scaled down. The default is 1. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions and grid spacing accordingly. + Only possible, if each axis length is a power of 2! + + """ + self._log.debug('Calling scale_down') + assert n > 0 and isinstance(n, int), 'n must be a positive integer!' + self.a *= 2 ** n + for t in range(n): + # Pad if necessary: + pz, py, px = self.dim[0] % 2, self.dim[1] % 2, self.dim[2] % 2 + if pz != 0 or py != 0 or px != 0: + self.field = np.pad(self.field, ((0, 0), (0, pz), (0, py), (0, px)), + mode='constant') + # Create coarser grid for the vector field: + shape_4d = (3, self.dim[0] // 2, 2, self.dim[1] // 2, 2, self.dim[2] // 2, 2) + self.field = self.field.reshape(shape_4d).mean(axis=(6, 4, 2)) + + def scale_up(self, n=1, order=0): + """Scale up the field distribution using spline interpolation of the requested order. + + Parameters + ---------- + n : int, optional + Power of 2 with which the grid is scaled. Default is 1, which means every axis is + increased by a factor of ``2**1 = 2``. + order : int, optional + The order of the spline interpolation, which has to be in the range between 0 and 5 + and defaults to 0. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions and grid spacing accordingly. + """ + self._log.debug('Calling scale_up') + assert n > 0 and isinstance(n, int), 'n must be a positive integer!' + assert 5 > order >= 0 and isinstance(order, int), \ + 'order must be a positive integer between 0 and 5!' + self.a /= 2 ** n + self.field = np.array((zoom(self.field[0], zoom=2 ** n, order=order), + zoom(self.field[1], zoom=2 ** n, order=order), + zoom(self.field[2], zoom=2 ** n, order=order))) + + def pad(self, pad_values): + """Pad the current field distribution with zeros for each individual axis. + + Parameters + ---------- + pad_values : tuple of int + Number of zeros which should be padded. Provided as a tuple where each entry + corresponds to an axis. An entry can be one int (same padding for both sides) or again + a tuple which specifies the pad values for both sides of the corresponding axis. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions accordingly. + """ + self._log.debug('Calling pad') + assert len(pad_values) == 3, 'Pad values for each dimension have to be provided!' + pv = np.zeros(6, dtype=np.int) + for i, values in enumerate(pad_values): + assert np.shape(values) in [(), (2,)], 'Only one or two values per axis can be given!' + pv[2 * i:2 * (i + 1)] = values + self.field = np.pad(self.field, ((0, 0), (pv[0], pv[1]), (pv[2], pv[3]), (pv[4], pv[5])), + mode='constant') + + def crop(self, crop_values): + """Crop the current field distribution with zeros for each individual axis. + + Parameters + ---------- + crop_values : tuple of int + Number of zeros which should be cropped. Provided as a tuple where each entry + corresponds to an axis. An entry can be one int (same cropping for both sides) or again + a tuple which specifies the crop values for both sides of the corresponding axis. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions accordingly. + """ + self._log.debug('Calling crop') + assert len(crop_values) == 3, 'Crop values for each dimension have to be provided!' + cv = np.zeros(6, dtype=np.int) + for i, values in enumerate(crop_values): + assert np.shape(values) in [(), (2,)], 'Only one or two values per axis can be given!' + cv[2 * i:2 * (i + 1)] = values + cv *= np.resize([1, -1], len(cv)) + cv = np.where(cv == 0, None, cv) + self.field = self.field[:, cv[0]:cv[1], cv[2]:cv[3], cv[4]:cv[5]] + + def get_vector(self, mask): + """Returns the vector field components arranged in a vector, specified by a mask. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (N=3, boolean) + Masks the pixels from which the components should be taken. + + Returns + ------- + vector : :class:`~numpy.ndarray` (N=1) + The vector containing vector field components of the specified pixels. + Order is: first all `x`-, then all `y`-, then all `z`-components. + + """ + self._log.debug('Calling get_vector') + if mask is not None: + return np.reshape([self.field[0][mask], + self.field[1][mask], + self.field[2][mask]], -1) + else: + return self.field_vec + + def set_vector(self, vector, mask=None): + """Set the field components of the masked pixels to the values specified by `vector`. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (N=3, boolean), optional + Masks the pixels from which the components should be taken. + vector : :class:`~numpy.ndarray` (N=1) + The vector containing vector field components of the specified pixels. + Order is: first all `x`-, then all `y-, then all `z`-components. + + Returns + ------- + None + + """ + self._log.debug('Calling set_vector') + assert np.size(vector) % 3 == 0, 'Vector has to contain all 3 components for every pixel!' + count = np.size(vector) // 3 + if mask is not None: + self.field[0][mask] = vector[:count] # x-component + self.field[1][mask] = vector[count:2 * count] # y-component + self.field[2][mask] = vector[2 * count:] # z-component + else: + self.field_vec = vector + + def flip(self, axis='x'): + """Flip/mirror the vector field around the specified axis. + + Parameters + ---------- + axis: {'x', 'y', 'z'}, optional + The axis around which the vector field is flipped. + + Returns + ------- + magdata_flip: :class:`~.VectorData` + A flipped copy of the :class:`~.VectorData` object. + + """ + self._log.debug('Calling flip') + if axis == 'x': + mag_x, mag_y, mag_z = self.field[:, :, :, ::-1] + field_flip = np.array((-mag_x, mag_y, mag_z)) + elif axis == 'y': + mag_x, mag_y, mag_z = self.field[:, :, ::-1, :] + field_flip = np.array((mag_x, -mag_y, mag_z)) + elif axis == 'z': + mag_x, mag_y, mag_z = self.field[:, ::-1, :, :] + field_flip = np.array((mag_x, mag_y, -mag_z)) + else: + raise ValueError("Wrong input! 'x', 'y', 'z' allowed!") + return VectorData(self.a, field_flip) + + def rot90(self, axis='x'): + """Rotate the vector field 90° around the specified axis (right hand rotation). + + Parameters + ---------- + axis: {'x', 'y', 'z'}, optional + The axis around which the vector field is rotated. + + Returns + ------- + magdata_rot: :class:`~.VectorData` + A rotated copy of the :class:`~.VectorData` object. + + """ + self._log.debug('Calling rot90') + if axis == 'x': + field_rot = np.zeros((3, self.dim[1], self.dim[0], self.dim[2])) + for i in range(self.dim[2]): + mag_x, mag_y, mag_z = self.field[:, :, :, i] + mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z) + field_rot[:, :, :, i] = np.array((mag_xrot, mag_zrot, -mag_yrot)) + elif axis == 'y': + field_rot = np.zeros((3, self.dim[2], self.dim[1], self.dim[0])) + for i in range(self.dim[1]): + mag_x, mag_y, mag_z = self.field[:, :, i, :] + mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z) + field_rot[:, :, i, :] = np.array((mag_zrot, mag_yrot, -mag_xrot)) + elif axis == 'z': + field_rot = np.zeros((3, self.dim[0], self.dim[2], self.dim[1])) + for i in range(self.dim[0]): + mag_x, mag_y, mag_z = self.field[:, i, :, :] + mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z) + field_rot[:, i, :, :] = np.array((mag_yrot, -mag_xrot, mag_zrot)) + else: + raise ValueError("Wrong input! 'x', 'y', 'z' allowed!") + return VectorData(self.a, field_rot) + + def get_slice(self, ax_slice=None, proj_axis='z'): + """Extract a slice from the :class:`~.VectorData` object. + + Parameters + ---------- + proj_axis : {'z', 'y', 'x'}, optional + The axis, from which the slice is taken. The default is 'z'. + ax_slice : None or int, optional + The slice-index of the axis specified in `proj_axis`. Defaults to the center slice. + + Returns + ------- + u_mag, v_mag, w_mag, submask : :class:`~numpy.ndarray` (N=2) + The extracted vector field components in plane perpendicular to the `proj_axis` and + the perpendicular component. + + """ + self._log.debug('Calling get_slice') + # Find slice: + assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ + 'Axis has to be x, y or z (as string).' + if ax_slice is None: + ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 + if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice + self._log.debug('proj_axis == z') + u_mag = np.copy(self.field[0][ax_slice, ...]) # x-component + v_mag = np.copy(self.field[1][ax_slice, ...]) # y-component + w_mag = np.copy(self.field[2][ax_slice, ...]) # z-component + elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice + self._log.debug('proj_axis == y') + u_mag = np.copy(self.field[0][:, ax_slice, :]) # x-component + v_mag = np.copy(self.field[2][:, ax_slice, :]) # z-component + w_mag = np.copy(self.field[1][:, ax_slice, :]) # y-component + elif proj_axis == 'x': # Slice of the zy-plane with x = ax_slice + self._log.debug('proj_axis == x') + u_mag = np.swapaxes(np.copy(self.field[2][..., ax_slice]), 0, 1) # z-component + v_mag = np.swapaxes(np.copy(self.field[1][..., ax_slice]), 0, 1) # y-component + w_mag = np.swapaxes(np.copy(self.field[0][..., ax_slice]), 0, 1) # x-component + else: + raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) + return u_mag, v_mag, w_mag + + def to_signal(self): + """Convert :class:`~.VectorData` data into a HyperSpy signal. + + Returns + ------- + signal: :class:`~hyperspy.signals.Signal` + Representation of the :class:`~.VectorData` object as a HyperSpy Signal. + + Notes + ----- + This method recquires the hyperspy package! + + """ + self._log.debug('Calling to_signal') + signal = super().to_signal() + # Set component axis: + signal.axes_manager[3].name = 'x/y/z-component' + signal.axes_manager[3].units = '' + # Set metadata: + signal.metadata.Signal.title = 'VectorData' + # Return signal: + return signal + + def save(self, filename, **kwargs): + """Saves the VectorData in the specified format. + + The function gets the format from the extension: + - hdf5 for HDF5. + - EMD Electron Microscopy Dataset format (also HDF5). + - llg format. + - ovf format. + - npy or npz for numpy formats. + + If no extension is provided, 'hdf5' is used. Most formats are + saved with the HyperSpy package (internally the fielddata is first + converted to a HyperSpy Signal. + + Each format accepts a different set of parameters. For details + see the specific format documentation. + + Parameters + ---------- + filename : str, optional + Name of the file which the VectorData is saved into. The extension + determines the saving procedure. + + """ + from .file_io.io_vectordata import save_vectordata + save_vectordata(self, filename, **kwargs) + + def plot_quiver(self, ar_dens=1, log=False, scaled=True, scale=1., b_0=None, # Only used here! + coloring='angle', cmap=None, # Used here and plot_streamlines! + proj_axis='z', ax_slice=None, show_mask=True, bgcolor=None, axis=None, + figsize=None, **kwargs): + """Plot a slice of the vector field as a quiver plot. + + Parameters + ---------- + ar_dens: int, optional + Number defining the arrow density which is plotted. A higher ar_dens number skips more + arrows (a number of 2 plots every second arrow). Default is 1. + log : boolean, optional + The loratihm of the arrow length is plotted instead. This is helpful if only the + direction of the arrows is important and the amplitude varies a lot. Default is False. + scaled : boolean, optional + Normalizes the plotted arrows in respect to the highest one. Default is True. + scale: float, optional + Additional multiplicative factor scaling the arrow length. Default is 1 + (no further scaling). + b_0 : float, optional + Saturation induction (saturation magnetisation times the vacuum permeability). + If this is specified, a quiverkey is used to indicate the length of the longest arrow. + coloring : {'angle', 'amplitude', 'uniform', matplotlib color} + Color coding mode of the arrows. Use 'full' (default), 'angle', 'amplitude', 'uniform' + (black or white, depending on `bgcolor`), or a matplotlib color keyword. + cmap : string, optional + The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. + If not set, an appropriate one is used. Note that a subclass of + :class:`~.colors.Colormap3D` should be used for angle encoding. + proj_axis : {'z', 'y', 'x'}, optional + The axis, from which a slice is plotted. The default is 'z'. + ax_slice : int, optional + The slice-index of the axis specified in `proj_axis`. Is set to the center of + `proj_axis` if not specified. + show_mask: boolean + Default is True. Shows the outlines of the mask slice if available. + bgcolor: {'white', 'black'}, optional + Determines the background color of the plot. + axis : :class:`~matplotlib.axes.AxesSubplot`, optional + Axis on which the graph is plotted. Creates a new figure if none is specified. + figsize : tuple of floats (N=2) + Size of the plot figure. + + Returns + ------- + axis: :class:`~matplotlib.axes.AxesSubplot` + The axis on which the graph is plotted. + + Notes + ----- + Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. + + """ + self._log.debug('Calling plot_quiver') + a = self.a + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT + assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ + 'Axis has to be x, y or z (as string).' + if ax_slice is None: + ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 + # Extract slice and mask: + u_mag, v_mag = self.get_slice(ax_slice, proj_axis)[:2] + submask = np.where(np.hypot(u_mag, v_mag) > 0, True, False) + # Prepare quiver (select only used arrows if ar_dens is specified): + dim_uv = u_mag.shape + vv, uu = np.indices(dim_uv) + 0.5 # shift to center of pixel + uu = uu[::ar_dens, ::ar_dens] + vv = vv[::ar_dens, ::ar_dens] + u_mag = u_mag[::ar_dens, ::ar_dens] + v_mag = v_mag[::ar_dens, ::ar_dens] + amplitudes = np.hypot(u_mag, v_mag) + angles = np.angle(u_mag + 1j * v_mag, deg=True).tolist() + # Calculate the arrow colors: + if bgcolor is None: + bgcolor = 'white' # Default! + cmap_overwrite = cmap + if coloring == 'angle': + self._log.debug('Encoding angles') + hue = np.asarray(np.arctan2(v_mag, u_mag) / (2 * np.pi)) + hue[hue < 0] += 1 + cmap = colors.CMAP_CIRCULAR_DEFAULT + elif coloring == 'amplitude': + self._log.debug('Encoding amplitude') + hue = amplitudes / amplitudes.max() + if bgcolor == 'white': + cmap = colors.cmaps['cubehelix_reverse'] + else: + cmap = colors.cmaps['cubehelix_standard'] + elif coloring == 'uniform': + self._log.debug('Automatic uniform color encoding') + hue = amplitudes / amplitudes.max() + if bgcolor == 'white': + cmap = colors.cmaps['transparent_black'] + else: + cmap = colors.cmaps['transparent_white'] + else: + self._log.debug('Specified uniform color encoding') + hue = np.zeros_like(u_mag) + cmap = ListedColormap([coloring]) + if cmap_overwrite is not None: + cmap = cmap_overwrite + # If no axis is specified, a new figure is created: + if axis is None: + self._log.debug('axis is None') + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1) + tight = True + else: + tight = False + axis.set_aspect('equal') + # Take the logarithm of the arrows to clearly show directions (if specified): + if log and np.any(amplitudes): # If the slice is empty, skip! + cutoff = 10 + amp = np.round(amplitudes, decimals=cutoff) + min_value = amp[np.nonzero(amp)].min() + u_mag = np.round(u_mag, decimals=cutoff) / min_value + u_mag = np.log10(np.abs(u_mag) + 1) * np.sign(u_mag) + v_mag = np.round(v_mag, decimals=cutoff) / min_value + v_mag = np.log10(np.abs(v_mag) + 1) * np.sign(v_mag) + amplitudes = np.hypot(u_mag, v_mag) # Recalculate (used if scaled)! + # Scale the amplitude of the arrows to the highest one (if specified): + if scaled: + u_mag /= amplitudes.max() + 1E-30 + v_mag /= amplitudes.max() + 1E-30 + # Plot quiver: + # TODO: quiver does not work with matplotlib 2.0! FIX! + quiv = axis.quiver(uu, vv, u_mag, v_mag, hue, cmap=cmap, clim=(0, 1), angles=angles, + pivot='middle', units='xy', scale_units='xy', scale=scale / ar_dens, + minlength=0.05, width=1*ar_dens, headlength=2, headaxislength=2, + headwidth=2, minshaft=2) + axis.set_xlim(0, dim_uv[1]) + axis.set_ylim(0, dim_uv[0]) + # Determine colormap if necessary: + if coloring == 'amplitude': + cbar_mappable, cbar_label = quiv, 'amplitude' + else: + cbar_mappable, cbar_label = None, None + # Change background color: + axis.set_axis_bgcolor(bgcolor) + # Show mask: + if show_mask and not np.all(submask): # Plot mask if desired and not trivial! + vv, uu = np.indices(dim_uv) + 0.5 # shift to center of pixel + mask_color = 'white' if bgcolor == 'black' else 'black' + axis.contour(uu, vv, submask, levels=[0.5], colors=mask_color, + linestyles='dotted', linewidths=2) + # Plot quiverkey if B_0 is specified): + if b_0 and not log: # The angles needed for log would break the quiverkey! + label = '{:.3g} T'.format(amplitudes.max() * b_0) + quiv.angles = 'uv' # With a list of angles, the quiverkey would break! + stroke = plottools.STROKE_DEFAULT + txtcolor = 'w' if stroke == 'k' else 'k' + edgecolor = stroke if stroke is not None else 'none' + fontsize = kwargs.get('fontsize', None) + if fontsize is None: + fontsize = plottools.FONTSIZE_DEFAULT + qk = plt.quiverkey(Q=quiv, X=0.88, Y=0.065, U=1, label=label, labelpos='W', + coordinates='axes', facecolor=txtcolor, edgecolor=edgecolor, + labelcolor=txtcolor, linewidth=0.5, + clip_box=axis.bbox, clip_on=True, + fontproperties={'size': kwargs.get('fontsize', fontsize)}) + if stroke is not None: + qk.text.set_path_effects( + [patheffects.withStroke(linewidth=2, foreground=stroke)]) + # Return formatted axis: + return plottools.format_axis(axis, sampling=a, cbar_mappable=cbar_mappable, + cbar_label=cbar_label, tight_layout=tight, **kwargs) + + def plot_field(self, proj_axis='z', ax_slice=None, show_mask=True, bgcolor=None, axis=None, + figsize=None, **kwargs): + """Plot a slice of the vector field as a color field imshow plot. + + Parameters + ---------- + proj_axis : {'z', 'y', 'x'}, optional + The axis, from which a slice is plotted. The default is 'z'. + ax_slice : int, optional + The slice-index of the axis specified in `proj_axis`. Is set to the center of + `proj_axis` if not specified. + show_mask: boolean + Default is True. Shows the outlines of the mask slice if available. + bgcolor: {'white', 'black'}, optional + Determines the background color of the plot. + axis : :class:`~matplotlib.axes.AxesSubplot`, optional + Axis on which the graph is plotted. Creates a new figure if none is specified. + figsize : tuple of floats (N=2) + Size of the plot figure. + + Returns + ------- + axis: :class:`~matplotlib.axes.AxesSubplot` + The axis on which the graph is plotted. + + Notes + ----- + Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. + + """ + self._log.debug('Calling plot_field') + a = self.a + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT + assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ + 'Axis has to be x, y or z (as string).' + if ax_slice is None: + ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 + # Extract slice and mask: + u_mag, v_mag, w_mag = self.get_slice(ax_slice, proj_axis) + submask = np.where(np.hypot(u_mag, v_mag) > 0, True, False) + # If no axis is specified, a new figure is created: + if axis is None: + self._log.debug('axis is None') + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1) + tight = True + else: + tight = False + axis.set_aspect('equal') + # Determine 'z'-component for luminance (keep as gray if None): + z_mag = w_mag + if bgcolor == 'white': + z_mag = np.where(submask, z_mag, np.max(np.hypot(u_mag, v_mag))) + if bgcolor == 'black': + z_mag = np.where(submask, z_mag, -np.max(np.hypot(u_mag, v_mag))) + # Plot the field: + dim_uv = u_mag.shape + rgb = colors.CMAP_CIRCULAR_DEFAULT.rgb_from_vector(np.asarray((u_mag, v_mag, z_mag))) + axis.imshow(Image.fromarray(rgb), origin='lower', interpolation='none', + extent=(0, dim_uv[1], 0, dim_uv[0])) + # Change background color: + if bgcolor is not None: + axis.set_axis_bgcolor(bgcolor) + # Show mask: + if show_mask and not np.all(submask): # Plot mask if desired and not trivial! + vv, uu = np.indices(dim_uv) + 0.5 # shift to center of pixel + mask_color = 'white' if bgcolor == 'black' else 'black' + axis.contour(uu, vv, submask, levels=[0.5], colors=mask_color, + linestyles='dotted', linewidths=2) + # Return formatted axis: + return plottools.format_axis(axis, sampling=a, tight_layout=tight, **kwargs) + + def plot_quiver_field(self, **kwargs): + """Plot the vector field as a field plot with uniformly colored arrows overlayed. + + Parameters + ---------- + See :func:`~.plot_quiver` and :func:`~.plot_quiver` for parameters! + + Returns + ------- + axis: :class:`~matplotlib.axes.AxesSubplot` + The axis on which the graph is plotted. + + """ + # Extract parameters: + show_mask = kwargs.pop('show_mask', True) # Only needed once! + axis = kwargs.pop('axis', None) + # Set default bgcolor to white (only for combined plot), only if bgcolor was not specified: + kwargs.setdefault('bgcolor', 'white') + # Plot field first (with mask and axis formatting), then quiver: + axis = self.plot_field(axis=axis, show_mask=show_mask, **kwargs) + self.plot_quiver(coloring='uniform', show_mask=False, axis=axis, + format_axis=False, **kwargs) + # Return plotting axis: + return axis + + def plot_streamline(self, density=2, linewidth=2, coloring='angle', cmap=None, + proj_axis='z', ax_slice=None, show_mask=True, bgcolor=None, axis=None, + figsize=None, **kwargs): + """Plot a slice of the vector field as a quiver plot. + + Parameters + ---------- + density : float or 2-tuple, optional + Controls the closeness of streamlines. When density = 1, the domain is divided into a + 30x30 grid—density linearly scales this grid. Each cebll in the grid can have, at most, + one traversing streamline. For different densities in each direction, use + [density_x, density_y]. + linewidth : numeric or 2d array, optional + Vary linewidth when given a 2d array with the same shape as velocities. + coloring : {'angle', 'amplitude', 'uniform'} + Color coding mode of the arrows. Use 'full' (default), 'angle', 'amplitude' or + 'uniform'. + cmap : string, optional + The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. + If not set, an appropriate one is used. Note that a subclass of + :class:`~.colors.Colormap3D` should be used for angle encoding. + proj_axis : {'z', 'y', 'x'}, optional + The axis, from which a slice is plotted. The default is 'z'. + ax_slice : int, optional + The slice-index of the axis specified in `proj_axis`. Is set to the center of + `proj_axis` if not specified. + show_mask: boolean + Default is True. Shows the outlines of the mask slice if available. + bgcolor: {'white', 'black'}, optional + Determines the background color of the plot. + axis : :class:`~matplotlib.axes.AxesSubplot`, optional + Axis on which the graph is plotted. Creates a new figure if none is specified. + figsize : tuple of floats (N=2) + Size of the plot figure. + + Returns + ------- + axis: :class:`~matplotlib.axes.AxesSubplot` + The axis on which the graph is plotted. + + Notes + ----- + Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. + + """ + self._log.debug('Calling plot_quiver') + a = self.a + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT + assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ + 'Axis has to be x, y or z (as string).' + if ax_slice is None: + ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 + u_mag, v_mag = self.get_slice(ax_slice, proj_axis)[:2] + submask = np.where(np.hypot(u_mag, v_mag) > 0, True, False) + # Prepare streamlines: + dim_uv = u_mag.shape + uu = np.arange(dim_uv[1]) + 0.5 # shift to center of pixel + vv = np.arange(dim_uv[0]) + 0.5 # shift to center of pixel + u_mag, v_mag = self.get_slice(ax_slice, proj_axis)[:2] + # v_mag = np.ma.array(v_mag, mask=submask) + amplitudes = np.hypot(u_mag, v_mag) + # Calculate the arrow colors: + if bgcolor is None: + bgcolor = 'white' # Default! + cmap_overwrite = cmap + if coloring == 'angle': + self._log.debug('Encoding angles') + hue = np.asarray(np.arctan2(v_mag, u_mag) / (2 * np.pi)) + hue[hue < 0] += 1 + cmap = colors.CMAP_CIRCULAR_DEFAULT + elif coloring == 'amplitude': + self._log.debug('Encoding amplitude') + hue = amplitudes / amplitudes.max() + if bgcolor == 'white': + cmap = colors.cmaps['cubehelix_reverse'] + else: + cmap = colors.cmaps['cubehelix_standard'] + elif coloring == 'uniform': + self._log.debug('Automatic uniform color encoding') + hue = amplitudes / amplitudes.max() + if bgcolor == 'white': + cmap = colors.cmaps['transparent_black'] + else: + cmap = colors.cmaps['transparent_white'] + else: + self._log.debug('Specified uniform color encoding') + hue = np.zeros_like(u_mag) + cmap = ListedColormap([coloring]) + if cmap_overwrite is not None: + cmap = cmap_overwrite + # If no axis is specified, a new figure is created: + if axis is None: + self._log.debug('axis is None') + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1) + tight = True + else: + tight = False + axis.set_aspect('equal') + # Plot the streamlines: + im = plt.streamplot(uu, vv, u_mag, v_mag, density=density, linewidth=linewidth, + color=hue, cmap=cmap) + # Determine colormap if necessary: + if coloring == 'amplitude': + cbar_mappable, cbar_label = im, 'amplitude' + else: + cbar_mappable, cbar_label = None, None + # Change background color: + axis.set_axis_bgcolor(bgcolor) + # Show mask: + if show_mask and not np.all(submask): # Plot mask if desired and not trivial! + vv, uu = np.indices(dim_uv) + 0.5 # shift to center of pixel + mask_color = 'white' if bgcolor == 'black' else 'black' + axis.contour(uu, vv, submask, levels=[0.5], colors=mask_color, + linestyles='dotted', linewidths=2) + # Return formatted axis: + return plottools.format_axis(axis, sampling=a, cbar_mappable=cbar_mappable, + cbar_label=cbar_label, tight_layout=tight, **kwargs) + + def plot_quiver3d(self, title='Vector Field', limit=None, cmap='jet', mode='2darrow', + coloring='angle', ar_dens=1, opacity=1.0, grid=True, labels=True, + orientation=True, figsize=None): + """Plot the vector field as 3D-vectors in a quiverplot. + + Parameters + ---------- + title : string, optional + The title for the plot. + limit : float, optional + Plotlimit for the vector field arrow length used to scale the colormap. + cmap : string, optional + String describing the colormap which is used for amplitude encoding (default is 'jet'). + ar_dens: int, optional + Number defining the arrow density which is plotted. A higher ar_dens number skips more + arrows (a number of 2 plots every second arrow). Default is 1. + mode: string, optional + Mode, determining the glyphs used in the 3D plot. Default is '2darrow', which + corresponds to 2D arrows. For smaller amounts of arrows, 'arrow' (3D) is prettier. + coloring : {'angle', 'amplitude'}, optional + Color coding mode of the arrows. Use 'angle' (default) or 'amplitude'. + opacity: float, optional + Defines the opacity of the arrows. Default is 1.0 (completely opaque). + + Returns + ------- + plot : :class:`mayavi.modules.vectors.Vectors` + The plot object. + + """ + self._log.debug('Calling quiver_plot3D') + from mayavi import mlab + if limit is None: + limit = np.max(np.nan_to_num(self.field_amp)) + ad = ar_dens + # Create points and vector components as lists: + zzz, yyy, xxx = (np.indices(self.dim) + self.a / 2) + zzz = zzz[::ad, ::ad, ::ad].ravel() + yyy = yyy[::ad, ::ad, ::ad].ravel() + xxx = xxx[::ad, ::ad, ::ad].ravel() + x_mag = self.field[0][::ad, ::ad, ::ad].ravel() + y_mag = self.field[1][::ad, ::ad, ::ad].ravel() + z_mag = self.field[2][::ad, ::ad, ::ad].ravel() + # Plot them as vectors: + if figsize is None: + figsize = (750, 700) + mlab.figure(size=figsize, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.)) + extent = np.ravel(list(zip((0, 0, 0), (self.dim[2], self.dim[1], self.dim[0])))) + if coloring == 'angle': # Encodes the full angle via colorwheel and saturation: + self._log.debug('Encoding full 3D angles') + vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag, mode=mode, opacity=opacity, + scalars=np.arange(len(xxx)), line_width=2) + vector = np.asarray((x_mag.ravel(), y_mag.ravel(), z_mag.ravel())) + rgb = colors.CMAP_CIRCULAR_DEFAULT.rgb_from_vector(vector) + rgba = np.hstack((rgb, 255 * np.ones((len(xxx), 1), dtype=np.uint8))) + vecs.glyph.color_mode = 'color_by_scalar' + vecs.module_manager.scalar_lut_manager.lut.table = rgba + mlab.draw() + elif coloring == 'amplitude': # Encodes the amplitude of the arrows with the jet colormap: + self._log.debug('Encoding amplitude') + vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag, + mode=mode, colormap=cmap, opacity=opacity, line_width=2) + mlab.colorbar(label_fmt='%.2f') + mlab.colorbar(orientation='vertical') + else: + raise AttributeError('Coloring mode not supported!') + vecs.glyph.glyph_source.glyph_position = 'center' + vecs.module_manager.vector_lut_manager.data_range = np.array([0, limit]) + if grid: + mlab.outline(vecs, extent=extent) + if labels: + mlab.axes(vecs, extent=extent) + mlab.title(title, height=0.95, size=0.35) + if orientation: + oa = mlab.orientation_axes() + oa.marker.set_viewport(0, 0, 0.4, 0.4) + mlab.draw() + engine = mlab.get_engine() + scene = engine.scenes[0] + scene.scene.isometric_view() + return vecs + + def plot_quiver3d_to_2d(self, dim_uv=None, axis=None, figsize=None, high_res=False, **kwargs): + from mayavi import mlab + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT + if axis is None: + self._log.debug('axis is None') + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1) + axis.set_axis_bgcolor('gray') + kwargs.setdefault('labels', 'False') + self.plot_quiver3d(figsize=(800, 800), **kwargs) + if high_res: # Use temp files: + tmpdir = tempfile.mkdtemp() + temp_path = os.path.join(tmpdir, 'temp.png') + try: + mlab.savefig(temp_path, size=(2000, 2000)) + imgmap = np.asarray(Image.open(temp_path)) + except Exception as e: + raise e + finally: + os.remove(temp_path) + os.rmdir(tmpdir) + else: # Use screenshot (returns array WITH alpha!): + imgmap = mlab.screenshot(mode='rgba', antialiased=True) + mlab.close(mlab.gcf()) + if dim_uv is None: + dim_uv = self.dim[1:] + axis.imshow(imgmap, extent=[0, dim_uv[0], 0, dim_uv[1]], origin='upper') + kwargs.setdefault('scalebar', False) + kwargs.setdefault('hideaxes', True) + return plottools.format_axis(axis, hideaxes=True, scalebar=False) + + +class ScalarData(FieldData): + """Class for storing scalar field data. + + Represents 3-dimensional scalar field distributions which is stored as a 3-dimensional + numpy array in `field`, but which can also be accessed as a vector via `field_vec`. + :class:`~.ScalarData` objects support negation, arithmetic operators (``+``, ``-``, ``*``) + and their augmented counterparts (``+=``, ``-=``, ``*=``), with numbers and other + :class:`~.ScalarData` objects, if their dimensions and grid spacings match. It is possible + to load data from HDF5 or LLG (.txt) files or to save the data in these formats. + Plotting methods are also provided. + + Attributes + ---------- + a: float + The grid spacing in nm. + field: :class:`~numpy.ndarray` (N=4) + The scalar field. + + """ + _log = logging.getLogger(__name__ + '.ScalarData') + + def scale_down(self, n=1): + """Scale down the field distribution by averaging over two pixels along each axis. + + Parameters + ---------- + n : int, optional + Number of times the field distribution is scaled down. The default is 1. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions and grid spacing accordingly. + Only possible, if each axis length is a power of 2! + + """ + self._log.debug('Calling scale_down') + assert n > 0 and isinstance(n, int), 'n must be a positive integer!' + self.a *= 2 ** n + for t in range(n): + # Pad if necessary: + pz, py, px = self.dim[0] % 2, self.dim[1] % 2, self.dim[2] % 2 + if pz != 0 or py != 0 or px != 0: + self.field = np.pad(self.field, ((0, pz), (0, py), (0, px)), mode='constant') + # Create coarser grid for the field: + shape_4d = (self.dim[0] // 2, 2, self.dim[1] // 2, 2, self.dim[2] // 2, 2) + self.field = self.field.reshape(shape_4d).mean(axis=(5, 3, 1)) + + def scale_up(self, n=1, order=0): + """Scale up the field distribution using spline interpolation of the requested order. + + Parameters + ---------- + n : int, optional + Power of 2 with which the grid is scaled. Default is 1, which means every axis is + increased by a factor of ``2**1 = 2``. + order : int, optional + The order of the spline interpolation, which has to be in the range between 0 and 5 + and defaults to 0. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions and grid spacing accordingly. + """ + self._log.debug('Calling scale_up') + assert n > 0 and isinstance(n, int), 'n must be a positive integer!' + assert 5 > order >= 0 and isinstance(order, int), \ + 'order must be a positive integer between 0 and 5!' + self.a /= 2 ** n + self.field = zoom(self.field, zoom=2 ** n, order=order) + + def get_vector(self, mask): + """Returns the field as a vector, specified by a mask. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (N=3, boolean) + Masks the pixels from which the components should be taken. + + Returns + ------- + vector : :class:`~numpy.ndarray` (N=1) + The vector containing the field of the specified pixels. + + """ + self._log.debug('Calling get_vector') + if mask is not None: + return np.reshape(self.field[mask], -1) + else: + return self.field_vec + + def set_vector(self, vector, mask=None): + """Set the field components of the masked pixels to the values specified by `vector`. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (N=3, boolean), optional + Masks the pixels from which the components should be taken. + vector : :class:`~numpy.ndarray` (N=1) + The vector containing the field of the specified pixels. + + Returns + ------- + None + + """ + self._log.debug('Calling set_vector') + if mask is not None: + self.field[mask] = vector + else: + self.field_vec = vector + + def rot90(self, axis='x'): + """Rotate the scalar field 90° around the specified axis (right hand rotation). + + Parameters + ---------- + axis: {'x', 'y', 'z'}, optional + The axis around which the vector field is rotated. + + Returns + ------- + scalardata_rot: :class:`~.ScalarData` + A rotated copy of the :class:`~.ScalarData` object. + + """ + self._log.debug('Calling rot90') + if axis == 'x': + field_rot = np.zeros((self.dim[1], self.dim[0], self.dim[2])) + for i in range(self.dim[2]): + field_rot[:, :, i] = np.rot90(self.field[:, :, i]) + elif axis == 'y': + field_rot = np.zeros((self.dim[2], self.dim[1], self.dim[0])) + for i in range(self.dim[1]): + field_rot[:, i, :] = np.rot90(self.field[:, i, :]) + elif axis == 'z': + field_rot = np.zeros((self.dim[0], self.dim[2], self.dim[1])) + for i in range(self.dim[0]): + field_rot[i, :, :] = np.rot90(self.field[i, :, :]) + else: + raise ValueError("Wrong input! 'x', 'y', 'z' allowed!") + return ScalarData(self.a, field_rot) + + def to_signal(self): + """Convert :class:`~.ScalarData` data into a HyperSpy signal. + + Returns + ------- + signal: :class:`~hyperspy.signals.Signal` + Representation of the :class:`~.ScalarData` object as a HyperSpy Signal. + + Notes + ----- + This method recquires the hyperspy package! + + """ + self._log.debug('Calling to_signal') + signal = super().to_signal() + # Set metadata: + signal.metadata.Signal.title = 'ScalarData' + # Return signal: + return signal + + def save(self, filename, **kwargs): + """Saves the ScalarData in the specified format. + + The function gets the format from the extension: + - hdf5 for HDF5. + - EMD Electron Microscopy Dataset format (also HDF5). + - npy or npz for numpy formats. + + If no extension is provided, 'hdf5' is used. Most formats are + saved with the HyperSpy package (internally the fielddata is first + converted to a HyperSpy Signal. + + Each format accepts a different set of parameters. For details + see the specific format documentation. + + Parameters + ---------- + filename : str, optional + Name of the file which the ScalarData is saved into. The extension + determines the saving procedure. + + """ + from .file_io.io_scalardata import save_scalardata + save_scalardata(self, filename, **kwargs) diff --git a/pyramid/file_io/__init__.py b/pyramid/file_io/__init__.py index bbdda83bccb2b8616c0206a9ec92ac57fc2ca21b..6e750ab99848775568594fc4a2e216b83880013a 100644 --- a/pyramid/file_io/__init__.py +++ b/pyramid/file_io/__init__.py @@ -1,13 +1,13 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Subpackage containing Pyramid IO functionality.""" - -from .io_phasemap import load_phasemap -from .io_vectordata import load_vectordata -from .io_scalardata import load_scalardata -from .io_projector import load_projector -from .io_dataset import load_dataset - -__all__ = ['load_phasemap', 'load_vectordata', 'load_scalardata', 'load_projector', 'load_dataset'] +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Subpackage containing Pyramid IO functionality.""" + +from .io_phasemap import load_phasemap +from .io_vectordata import load_vectordata +from .io_scalardata import load_scalardata +from .io_projector import load_projector +from .io_dataset import load_dataset + +__all__ = ['load_phasemap', 'load_vectordata', 'load_scalardata', 'load_projector', 'load_dataset'] diff --git a/pyramid/file_io/io_dataset.py b/pyramid/file_io/io_dataset.py index 49b318a8151667c9006bc874beaf79534b445dc1..3c4be310df07f02b8df06a0c7672c2a74a427447 100644 --- a/pyramid/file_io/io_dataset.py +++ b/pyramid/file_io/io_dataset.py @@ -1,107 +1,107 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""IO functionality for DataSet objects.""" - -import logging - -import os - -import h5py - -import numpy as np - -from ..dataset import DataSet -from ..file_io.io_projector import load_projector -from ..file_io.io_phasemap import load_phasemap - -__all__ = ['load_projector'] -_log = logging.getLogger(__name__) - - -def save_dataset(dataset, filename, overwrite=True): - """%s""" - _log.debug('Calling save_dataset') - path, filename = os.path.split(filename) - name, extension = os.path.splitext(filename) - assert extension in ['.hdf5', ''], 'For now only HDF5 format is supported!' - if name.startswith('dataset_'): - name = name.split('dataset_')[1] - # Header file: - header_name = os.path.join(path, 'dataset_{}.hdf5'.format(name)) - if not os.path.isfile(header_name) or overwrite: # Write if file does not exist or if forced: - with h5py.File(header_name, 'w') as f: - f.attrs['a'] = dataset.a - f.attrs['dim'] = dataset.dim - f.attrs['b_0'] = dataset.b_0 - if dataset.mask is not None: - f.create_dataset('mask', data=dataset.mask) - if dataset.Se_inv is not None: - f.create_dataset('Se_inv', data=dataset.Se_inv) - # PhaseMaps and Projectors: - for i, projector in enumerate(dataset.projectors): - projector_name = 'projector_{}_{}_{}{}'.format(name, i, projector.get_info(), extension) - projector.save(os.path.join(path, projector_name), overwrite=overwrite) - phasemap_name = 'phasemap_{}_{}_{}{}'.format(name, i, projector.get_info(), extension) - dataset.phasemaps[i].save(os.path.join(path, phasemap_name), overwrite=overwrite) -save_dataset.__doc__ %= DataSet.save.__doc__ - - -def load_dataset(filename): - """Load HDF5 file into a :class:`~pyramid.dataset.DataSet` instance. - - Parameters - ---------- - filename: str - The filename to be loaded. - - Returns - ------- - projector : :class:`~.Projector` - A :class:`~.Projector` object containing the loaded data. - - Notes - ----- - This loads a header file and all matching HDF5 files which can be found. The filename - conventions have to be strictly followed for the process to be successful! - - """ - _log.debug('Calling load_dataset') - path, filename = os.path.split(filename) - if path == '': - path = '.' # Make sure this can be used later! - name, extension = os.path.splitext(filename) - assert extension in ['.hdf5', ''], 'For now only HDF5 format is supported!' - if name.startswith('dataset_'): - name = name.split('dataset_')[1] - # Header file: - header_name = os.path.join(path, 'dataset_{}.hdf5'.format(name)) - with h5py.File(header_name, 'r') as f: - a = f.attrs.get('a') - dim = f.attrs.get('dim') - b_0 = f.attrs.get('b_0') - mask = np.copy(f.get('mask', None)) - Se_inv = np.copy(f.get('Se_inv', None)) - dataset = DataSet(a, dim, b_0, mask, Se_inv) - # Projectors: - projectors = [] - for f in os.listdir(path): - if f.startswith('projector') and f.endswith('.hdf5'): - projector_name, i = f.split('_')[1:3] - if projector_name == name: - projector = load_projector(os.path.join(path, f)) - projectors.append((int(i), projector)) - projectors = [p[1] for p in sorted(projectors, key=lambda x: x[0])] - # PhaseMaps: - phasemaps = [] - for f in os.listdir(path): - if f.startswith('phasemap') and f.endswith('.hdf5'): - phasemap_name, i = f.split('_')[1:3] - if phasemap_name == name: - phasemap = load_phasemap(os.path.join(path, f)) - phasemaps.append((int(i), phasemap)) - phasemaps = [p[1] for p in sorted(phasemaps, key=lambda x: x[0])] - dataset.append(phasemaps, projectors) - # Return DataSet: - return dataset +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""IO functionality for DataSet objects.""" + +import logging + +import os + +import h5py + +import numpy as np + +from ..dataset import DataSet +from ..file_io.io_projector import load_projector +from ..file_io.io_phasemap import load_phasemap + +__all__ = ['load_projector'] +_log = logging.getLogger(__name__) + + +def save_dataset(dataset, filename, overwrite=True): + """%s""" + _log.debug('Calling save_dataset') + path, filename = os.path.split(filename) + name, extension = os.path.splitext(filename) + assert extension in ['.hdf5', ''], 'For now only HDF5 format is supported!' + if name.startswith('dataset_'): + name = name.split('dataset_')[1] + # Header file: + header_name = os.path.join(path, 'dataset_{}.hdf5'.format(name)) + if not os.path.isfile(header_name) or overwrite: # Write if file does not exist or if forced: + with h5py.File(header_name, 'w') as f: + f.attrs['a'] = dataset.a + f.attrs['dim'] = dataset.dim + f.attrs['b_0'] = dataset.b_0 + if dataset.mask is not None: + f.create_dataset('mask', data=dataset.mask) + if dataset.Se_inv is not None: + f.create_dataset('Se_inv', data=dataset.Se_inv) + # PhaseMaps and Projectors: + for i, projector in enumerate(dataset.projectors): + projector_name = 'projector_{}_{}_{}{}'.format(name, i, projector.get_info(), extension) + projector.save(os.path.join(path, projector_name), overwrite=overwrite) + phasemap_name = 'phasemap_{}_{}_{}{}'.format(name, i, projector.get_info(), extension) + dataset.phasemaps[i].save(os.path.join(path, phasemap_name), overwrite=overwrite) +save_dataset.__doc__ %= DataSet.save.__doc__ + + +def load_dataset(filename): + """Load HDF5 file into a :class:`~pyramid.dataset.DataSet` instance. + + Parameters + ---------- + filename: str + The filename to be loaded. + + Returns + ------- + projector : :class:`~.Projector` + A :class:`~.Projector` object containing the loaded data. + + Notes + ----- + This loads a header file and all matching HDF5 files which can be found. The filename + conventions have to be strictly followed for the process to be successful! + + """ + _log.debug('Calling load_dataset') + path, filename = os.path.split(filename) + if path == '': + path = '.' # Make sure this can be used later! + name, extension = os.path.splitext(filename) + assert extension in ['.hdf5', ''], 'For now only HDF5 format is supported!' + if name.startswith('dataset_'): + name = name.split('dataset_')[1] + # Header file: + header_name = os.path.join(path, 'dataset_{}.hdf5'.format(name)) + with h5py.File(header_name, 'r') as f: + a = f.attrs.get('a') + dim = f.attrs.get('dim') + b_0 = f.attrs.get('b_0') + mask = np.copy(f.get('mask', None)) + Se_inv = np.copy(f.get('Se_inv', None)) + dataset = DataSet(a, dim, b_0, mask, Se_inv) + # Projectors: + projectors = [] + for f in os.listdir(path): + if f.startswith('projector') and f.endswith('.hdf5'): + projector_name, i = f.split('_')[1:3] + if projector_name == name: + projector = load_projector(os.path.join(path, f)) + projectors.append((int(i), projector)) + projectors = [p[1] for p in sorted(projectors, key=lambda x: x[0])] + # PhaseMaps: + phasemaps = [] + for f in os.listdir(path): + if f.startswith('phasemap') and f.endswith('.hdf5'): + phasemap_name, i = f.split('_')[1:3] + if phasemap_name == name: + phasemap = load_phasemap(os.path.join(path, f)) + phasemaps.append((int(i), phasemap)) + phasemaps = [p[1] for p in sorted(phasemaps, key=lambda x: x[0])] + dataset.append(phasemaps, projectors) + # Return DataSet: + return dataset diff --git a/pyramid/file_io/io_phasemap.py b/pyramid/file_io/io_phasemap.py index 86f01e25cc229d7d6ebaefbb309e1d61d44cfa41..9a98ae5845740f4c980ed43192e67201576d16a8 100644 --- a/pyramid/file_io/io_phasemap.py +++ b/pyramid/file_io/io_phasemap.py @@ -1,211 +1,211 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""IO functionality for FieldData objects.""" - -import logging - -import os - -import numpy as np - -from PIL import Image - -from ..phasemap import PhaseMap - -__all__ = ['load_phasemap'] -_log = logging.getLogger(__name__) - - -def load_phasemap(filename, mask=None, confidence=None, a=None, **kwargs): - """Load supported file into a :class:`~pyramid.phasemap.PhaseMap` instance. - - The function loads the file according to the extension: - - hdf5 for HDF5. - - rpl for Ripple (useful to export to Digital Micrograph). - - dm3 and dm4 for Digital Micrograph files. - - unf for SEMPER unf binary format. - - txt format. - - npy or npz for numpy formats. - - Many image formats such as png, tiff, jpeg... - - Any extra keyword is passed to the corresponsing reader. For - available options see their individual documentation. - - Parameters - ---------- - filename: str - The filename to be loaded. - mask: str or tuple of str and dict, optional - If this is a filename, a mask will be loaded from an appropriate file. If additional - keywords should be provided, use a tuple of the filename and an according dictionary. - If this is `None`, no mask will be loaded. - default is False. - confidence: str or tuple of str and dict, optional - If this is a filename, a confidence matrix will be loaded from an appropriate file. If - additional keywords should be provided, use a tuple of the filename and an according - dictionary. If this is `None`, no mask will be loaded. - a: float or None, optional - If the grid spacing is not None, it will override a possibly loaded value from the files. - - Returns - ------- - phasemap : :class:`~.PhaseMap` - A :class:`~.PhaseMap` object containing the loaded data. - - """ - _log.debug('Calling load_phasemap') - phasemap = _load(filename, as_phasemap=True, a=a, **kwargs) - if mask is not None: - filemask, kwargs_mask = _parse_add_param(mask) - phasemap.mask = _load(filemask, **kwargs_mask) - if confidence is not None: - fileconf, kwargs_conf = _parse_add_param(confidence) - phasemap.confidence = _load(fileconf, **kwargs_conf) - return phasemap - - -def _load(filename, as_phasemap=False, a=1., **kwargs): - _log.debug('Calling _load') - extension = os.path.splitext(filename)[1] - # Load from txt-files: - if extension == '.txt': - return _load_from_txt(filename, as_phasemap, a, **kwargs) - # Load from npy-files: - elif extension in ['.npy', '.npz']: - return _load_from_npy(filename, as_phasemap, a, **kwargs) - elif extension in ['.jpeg', '.jpg', '.png', '.bmp', '.tif']: - return _load_from_img(filename, as_phasemap, a, **kwargs) - # Load with HyperSpy: - else: - if extension == '': - filename = '{}.hdf5'.format(filename) # Default! - return _load_from_hs(filename, as_phasemap, a, **kwargs) - - -def _parse_add_param(param): - if param is None: - return None, {} - elif isinstance(param, str): - return param, {} - elif isinstance(param, (list, tuple)) and len(param) == 2: - return param - else: - raise ValueError('Parameter can be a string or a tuple/list of a string and a dict!') - - -def _load_from_txt(filename, as_phasemap, a, **kwargs): - - def _load_arr(filename, **kwargs): - with open(filename, 'r') as phase_file: - if phase_file.readline().startswith('PYRAMID'): # File has pyramid structure: - return np.loadtxt(filename, delimiter='\t', skiprows=2) - else: # Try default args: - return np.loadtxt(filename, **kwargs) - - result = _load_arr(filename, **kwargs) - if as_phasemap: - if a is None: - a = 1. # Default! - with open(filename, 'r') as phase_file: - header = phase_file.readline() - if header.startswith('PYRAMID'): # File has pyramid structure: - a = float(phase_file.readline()[15:-4]) - return PhaseMap(a, result) - else: - return result - - -def _load_from_npy(filename, as_phasemap, a, **kwargs): - - result = np.load(filename, **kwargs) - if as_phasemap: - if a is None: - a = 1. # Use default! - return PhaseMap(a, result) - else: - return result - - -def _load_from_img(filename, as_phasemap, a, **kwargs): - - result = np.asarray(Image.open(filename, **kwargs).convert('L')) - if as_phasemap: - if a is None: - a = 1. # Use default! - return PhaseMap(a, result) - else: - return result - - -def _load_from_hs(filename, as_phasemap, a, **kwargs): - try: - import hyperspy.api as hs - except ImportError: - _log.error('This method recquires the hyperspy package!') - return - phasemap = PhaseMap.from_signal(hs.load(filename, **kwargs)) - if as_phasemap: - if a is not None: - phasemap.a = a - return phasemap - else: - return phasemap.phase - - -def save_phasemap(phasemap, filename, save_mask, save_conf, pyramid_format, **kwargs): - """%s""" - _log.debug('Calling save_phasemap') - extension = os.path.splitext(filename)[1] - if extension == '.txt': # Save to txt-files: - _save_to_txt(phasemap, filename, pyramid_format, save_mask, save_conf, **kwargs) - elif extension in ['.npy', '.npz']: # Save to npy-files: - _save_to_npy(phasemap, filename, save_mask, save_conf, **kwargs) - else: # Try HyperSpy: - _save_to_hs(phasemap, filename, save_mask, save_conf, **kwargs) -save_phasemap.__doc__ %= PhaseMap.save.__doc__ - - -def _save_to_txt(phasemap, filename, pyramid_format, save_mask, save_conf, **kwargs): - - def _save_arr(filename, array, tag, **kwargs): - if pyramid_format: - with open(filename, 'w') as phase_file: - name = os.path.splitext(os.path.split(filename)[1])[0] - phase_file.write('PYRAMID-{}: {}\n'.format(tag, name)) - phase_file.write('grid spacing = {} nm\n'.format(phasemap.a)) - save_kwargs = {'fmt': '%7.6e', 'delimiter': '\t'} - else: - save_kwargs = kwargs - with open(filename, 'ba') as phase_file: - np.savetxt(phase_file, array, **save_kwargs) - - name, extension = os.path.splitext(filename) - _save_arr('{}{}'.format(name, extension), phasemap.phase, 'PHASEMAP', **kwargs) - if save_mask: - _save_arr('{}_mask{}'.format(name, extension), phasemap.mask, 'MASK', **kwargs) - if save_conf: - _save_arr('{}_conf{}'.format(name, extension), phasemap.confidence, 'CONFIDENCE', **kwargs) - - -def _save_to_npy(phasemap, filename, save_mask, save_conf, **kwargs): - name, extension = os.path.splitext(filename) - np.save('{}{}'.format(name, extension), phasemap.phase, **kwargs) - if save_mask: - np.save('{}_mask{}'.format(name, extension), phasemap.mask, **kwargs) - if save_conf: - np.save('{}_conf{}'.format(name, extension), phasemap.confidence, **kwargs) - - -def _save_to_hs(phasemap, filename, save_mask, save_conf, **kwargs): - name, extension = os.path.splitext(filename) - phasemap.to_signal().save(filename, **kwargs) - if extension not in ['.hdf5', '.HDF5', '']: # mask and confidence are saved separately: - import hyperspy.api as hs - if save_mask: - mask_name = '{}_mask{}'.format(name, extension) - hs.signals.Signal2D(phasemap.mask, **kwargs).save(mask_name) - if save_conf: - conf_name = '{}_conf{}'.format(name, extension) - hs.signals.Signal2D(phasemap.confidence, **kwargs).save(conf_name) +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""IO functionality for FieldData objects.""" + +import logging + +import os + +import numpy as np + +from PIL import Image + +from ..phasemap import PhaseMap + +__all__ = ['load_phasemap'] +_log = logging.getLogger(__name__) + + +def load_phasemap(filename, mask=None, confidence=None, a=None, **kwargs): + """Load supported file into a :class:`~pyramid.phasemap.PhaseMap` instance. + + The function loads the file according to the extension: + - hdf5 for HDF5. + - rpl for Ripple (useful to export to Digital Micrograph). + - dm3 and dm4 for Digital Micrograph files. + - unf for SEMPER unf binary format. + - txt format. + - npy or npz for numpy formats. + - Many image formats such as png, tiff, jpeg... + + Any extra keyword is passed to the corresponsing reader. For + available options see their individual documentation. + + Parameters + ---------- + filename: str + The filename to be loaded. + mask: str or tuple of str and dict, optional + If this is a filename, a mask will be loaded from an appropriate file. If additional + keywords should be provided, use a tuple of the filename and an according dictionary. + If this is `None`, no mask will be loaded. + default is False. + confidence: str or tuple of str and dict, optional + If this is a filename, a confidence matrix will be loaded from an appropriate file. If + additional keywords should be provided, use a tuple of the filename and an according + dictionary. If this is `None`, no mask will be loaded. + a: float or None, optional + If the grid spacing is not None, it will override a possibly loaded value from the files. + + Returns + ------- + phasemap : :class:`~.PhaseMap` + A :class:`~.PhaseMap` object containing the loaded data. + + """ + _log.debug('Calling load_phasemap') + phasemap = _load(filename, as_phasemap=True, a=a, **kwargs) + if mask is not None: + filemask, kwargs_mask = _parse_add_param(mask) + phasemap.mask = _load(filemask, **kwargs_mask) + if confidence is not None: + fileconf, kwargs_conf = _parse_add_param(confidence) + phasemap.confidence = _load(fileconf, **kwargs_conf) + return phasemap + + +def _load(filename, as_phasemap=False, a=1., **kwargs): + _log.debug('Calling _load') + extension = os.path.splitext(filename)[1] + # Load from txt-files: + if extension == '.txt': + return _load_from_txt(filename, as_phasemap, a, **kwargs) + # Load from npy-files: + elif extension in ['.npy', '.npz']: + return _load_from_npy(filename, as_phasemap, a, **kwargs) + elif extension in ['.jpeg', '.jpg', '.png', '.bmp', '.tif']: + return _load_from_img(filename, as_phasemap, a, **kwargs) + # Load with HyperSpy: + else: + if extension == '': + filename = '{}.hdf5'.format(filename) # Default! + return _load_from_hs(filename, as_phasemap, a, **kwargs) + + +def _parse_add_param(param): + if param is None: + return None, {} + elif isinstance(param, str): + return param, {} + elif isinstance(param, (list, tuple)) and len(param) == 2: + return param + else: + raise ValueError('Parameter can be a string or a tuple/list of a string and a dict!') + + +def _load_from_txt(filename, as_phasemap, a, **kwargs): + + def _load_arr(filename, **kwargs): + with open(filename, 'r') as phase_file: + if phase_file.readline().startswith('PYRAMID'): # File has pyramid structure: + return np.loadtxt(filename, delimiter='\t', skiprows=2) + else: # Try default args: + return np.loadtxt(filename, **kwargs) + + result = _load_arr(filename, **kwargs) + if as_phasemap: + if a is None: + a = 1. # Default! + with open(filename, 'r') as phase_file: + header = phase_file.readline() + if header.startswith('PYRAMID'): # File has pyramid structure: + a = float(phase_file.readline()[15:-4]) + return PhaseMap(a, result) + else: + return result + + +def _load_from_npy(filename, as_phasemap, a, **kwargs): + + result = np.load(filename, **kwargs) + if as_phasemap: + if a is None: + a = 1. # Use default! + return PhaseMap(a, result) + else: + return result + + +def _load_from_img(filename, as_phasemap, a, **kwargs): + + result = np.asarray(Image.open(filename, **kwargs).convert('L')) + if as_phasemap: + if a is None: + a = 1. # Use default! + return PhaseMap(a, result) + else: + return result + + +def _load_from_hs(filename, as_phasemap, a, **kwargs): + try: + import hyperspy.api as hs + except ImportError: + _log.error('This method recquires the hyperspy package!') + return + phasemap = PhaseMap.from_signal(hs.load(filename, **kwargs)) + if as_phasemap: + if a is not None: + phasemap.a = a + return phasemap + else: + return phasemap.phase + + +def save_phasemap(phasemap, filename, save_mask, save_conf, pyramid_format, **kwargs): + """%s""" + _log.debug('Calling save_phasemap') + extension = os.path.splitext(filename)[1] + if extension == '.txt': # Save to txt-files: + _save_to_txt(phasemap, filename, pyramid_format, save_mask, save_conf, **kwargs) + elif extension in ['.npy', '.npz']: # Save to npy-files: + _save_to_npy(phasemap, filename, save_mask, save_conf, **kwargs) + else: # Try HyperSpy: + _save_to_hs(phasemap, filename, save_mask, save_conf, **kwargs) +save_phasemap.__doc__ %= PhaseMap.save.__doc__ + + +def _save_to_txt(phasemap, filename, pyramid_format, save_mask, save_conf, **kwargs): + + def _save_arr(filename, array, tag, **kwargs): + if pyramid_format: + with open(filename, 'w') as phase_file: + name = os.path.splitext(os.path.split(filename)[1])[0] + phase_file.write('PYRAMID-{}: {}\n'.format(tag, name)) + phase_file.write('grid spacing = {} nm\n'.format(phasemap.a)) + save_kwargs = {'fmt': '%7.6e', 'delimiter': '\t'} + else: + save_kwargs = kwargs + with open(filename, 'ba') as phase_file: + np.savetxt(phase_file, array, **save_kwargs) + + name, extension = os.path.splitext(filename) + _save_arr('{}{}'.format(name, extension), phasemap.phase, 'PHASEMAP', **kwargs) + if save_mask: + _save_arr('{}_mask{}'.format(name, extension), phasemap.mask, 'MASK', **kwargs) + if save_conf: + _save_arr('{}_conf{}'.format(name, extension), phasemap.confidence, 'CONFIDENCE', **kwargs) + + +def _save_to_npy(phasemap, filename, save_mask, save_conf, **kwargs): + name, extension = os.path.splitext(filename) + np.save('{}{}'.format(name, extension), phasemap.phase, **kwargs) + if save_mask: + np.save('{}_mask{}'.format(name, extension), phasemap.mask, **kwargs) + if save_conf: + np.save('{}_conf{}'.format(name, extension), phasemap.confidence, **kwargs) + + +def _save_to_hs(phasemap, filename, save_mask, save_conf, **kwargs): + name, extension = os.path.splitext(filename) + phasemap.to_signal().save(filename, **kwargs) + if extension not in ['.hdf5', '.HDF5', '']: # mask and confidence are saved separately: + import hyperspy.api as hs + if save_mask: + mask_name = '{}_mask{}'.format(name, extension) + hs.signals.Signal2D(phasemap.mask, **kwargs).save(mask_name) + if save_conf: + conf_name = '{}_conf{}'.format(name, extension) + hs.signals.Signal2D(phasemap.confidence, **kwargs).save(conf_name) diff --git a/pyramid/file_io/io_projector.py b/pyramid/file_io/io_projector.py index 773152cd6db25dd1cce8319d1ae412aef29d22a2..001a9ab4e2da61609d2b4e6fa5ca73bfa4cd92bb 100644 --- a/pyramid/file_io/io_projector.py +++ b/pyramid/file_io/io_projector.py @@ -1,90 +1,90 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""IO functionality for Projector objects.""" - -import logging - -import os - -from scipy.sparse import csr_matrix - -import numpy as np - -import h5py - -from .. import projector - -__all__ = ['load_projector'] -_log = logging.getLogger(__name__) - - -def save_projector(projector, filename, overwrite=True): - """%s""" - _log.debug('Calling save_projector') - name, extension = os.path.splitext(filename) - assert extension in ['.hdf5', ''], 'For now only HDF5 format is supported!' - filename = name + '.hdf5' # In case no extension is provided, set to HDF5! - if not os.path.isfile(filename) or overwrite: # Write if file does not exist or if forced: - with h5py.File(filename, 'w') as f: - class_name = projector.__class__.__name__ - f.attrs['class'] = class_name - if class_name == 'SimpleProjector': - f.attrs['axis'] = projector.axis - else: - f.attrs['tilt'] = projector.tilt - if class_name == 'RotTiltProjector': - f.attrs['rotation'] = projector.rotation - f.attrs['dim'] = projector.dim - f.attrs['dim_uv'] = projector.dim_uv - f.create_dataset('data', data=projector.weight.data) - f.create_dataset('indptr', data=projector.weight.indptr) - f.create_dataset('indices', data=projector.weight.indices) - f.create_dataset('coeff', data=projector.coeff) -save_projector.__doc__ %= projector.Projector.save.__doc__ - - -def load_projector(filename): - """Load HDF5 file into a :class:`~pyramid.projector.Projector` instance (or a subclass). - - Parameters - ---------- - filename: str - The filename to be loaded. - - Returns - ------- - projector : :class:`~.Projector` - A :class:`~.Projector` object containing the loaded data. - - """ - _log.debug('Calling load_projector') - name, extension = os.path.splitext(filename) - assert extension in ['.hdf5', ''], 'For now only HDF5 format is supported!' - filename = name + '.hdf5' # In case no extension is provided, set to HDF5! - with h5py.File(filename, 'r') as f: - # Retrieve dimensions: - dim = f.attrs.get('dim') - dim_uv = f.attrs.get('dim_uv') - size_2d, size_3d = np.prod(dim_uv), np.prod(dim) - # Retrieve weight matrix: - data = f.get('data') - indptr = f.get('indptr') - indices = f.get('indices') - weight = csr_matrix((data, indices, indptr), shape=(size_2d, size_3d)) - # Retrieve coefficients: - coeff = np.copy(f.get('coeff')) - # Construct projector: - result = projector.Projector(dim, dim_uv, weight, coeff) - # Specify projector type: - class_name = f.attrs.get('class') - result.__class__ = getattr(projector, class_name) - if class_name == 'SimpleProjector': - result.axis = f.attrs.get('axis') - else: - result.tilt = f.attrs.get('tilt') - if class_name == 'RotTiltProjector': - result.rotation = f.attrs.get('rotation') - # Return projector object: - return result +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""IO functionality for Projector objects.""" + +import logging + +import os + +from scipy.sparse import csr_matrix + +import numpy as np + +import h5py + +from .. import projector + +__all__ = ['load_projector'] +_log = logging.getLogger(__name__) + + +def save_projector(projector, filename, overwrite=True): + """%s""" + _log.debug('Calling save_projector') + name, extension = os.path.splitext(filename) + assert extension in ['.hdf5', ''], 'For now only HDF5 format is supported!' + filename = name + '.hdf5' # In case no extension is provided, set to HDF5! + if not os.path.isfile(filename) or overwrite: # Write if file does not exist or if forced: + with h5py.File(filename, 'w') as f: + class_name = projector.__class__.__name__ + f.attrs['class'] = class_name + if class_name == 'SimpleProjector': + f.attrs['axis'] = projector.axis + else: + f.attrs['tilt'] = projector.tilt + if class_name == 'RotTiltProjector': + f.attrs['rotation'] = projector.rotation + f.attrs['dim'] = projector.dim + f.attrs['dim_uv'] = projector.dim_uv + f.create_dataset('data', data=projector.weight.data) + f.create_dataset('indptr', data=projector.weight.indptr) + f.create_dataset('indices', data=projector.weight.indices) + f.create_dataset('coeff', data=projector.coeff) +save_projector.__doc__ %= projector.Projector.save.__doc__ + + +def load_projector(filename): + """Load HDF5 file into a :class:`~pyramid.projector.Projector` instance (or a subclass). + + Parameters + ---------- + filename: str + The filename to be loaded. + + Returns + ------- + projector : :class:`~.Projector` + A :class:`~.Projector` object containing the loaded data. + + """ + _log.debug('Calling load_projector') + name, extension = os.path.splitext(filename) + assert extension in ['.hdf5', ''], 'For now only HDF5 format is supported!' + filename = name + '.hdf5' # In case no extension is provided, set to HDF5! + with h5py.File(filename, 'r') as f: + # Retrieve dimensions: + dim = f.attrs.get('dim') + dim_uv = f.attrs.get('dim_uv') + size_2d, size_3d = np.prod(dim_uv), np.prod(dim) + # Retrieve weight matrix: + data = f.get('data') + indptr = f.get('indptr') + indices = f.get('indices') + weight = csr_matrix((data, indices, indptr), shape=(size_2d, size_3d)) + # Retrieve coefficients: + coeff = np.copy(f.get('coeff')) + # Construct projector: + result = projector.Projector(dim, dim_uv, weight, coeff) + # Specify projector type: + class_name = f.attrs.get('class') + result.__class__ = getattr(projector, class_name) + if class_name == 'SimpleProjector': + result.axis = f.attrs.get('axis') + else: + result.tilt = f.attrs.get('tilt') + if class_name == 'RotTiltProjector': + result.rotation = f.attrs.get('rotation') + # Return projector object: + return result diff --git a/pyramid/file_io/io_scalardata.py b/pyramid/file_io/io_scalardata.py index 8945bf5e1ba0fa8281034e454b6761a251f9c67f..dcb12aa87fdf861f2e32addb716c90d46b8196c1 100644 --- a/pyramid/file_io/io_scalardata.py +++ b/pyramid/file_io/io_scalardata.py @@ -1,93 +1,93 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""IO functionality for ScalarData objects.""" - -import logging - -import os - -import numpy as np - -from ..fielddata import ScalarData - -__all__ = ['load_scalardata'] -_log = logging.getLogger(__name__) - - -def load_scalardata(filename, a=None, **kwargs): - """Load supported file into a :class:`~pyramid.fielddata.ScalarData` instance. - - The function loads the file according to the extension: - - hdf5 for HDF5. - - EMD Electron Microscopy Dataset format (also HDF5). - - npy or npz for numpy formats. - - Any extra keyword is passed to the corresponsing reader. For - available options see their individual documentation. - - Parameters - ---------- - filename: str - The filename to be loaded. - a: float or None, optional - If the grid spacing is not None, it will override a possibly loaded value from the files. - - Returns - ------- - scalardata : :class:`~.ScalarData` - A :class:`~.ScalarData` object containing the loaded data. - - """ - _log.debug('Calling load_scalardata') - extension = os.path.splitext(filename)[1] - # Load from npy-files: - if extension in ['.npy', '.npz']: - return _load_from_npy(filename, a, **kwargs) - # Load with HyperSpy: - else: - if extension == '': - filename = '{}.hdf5'.format(filename) # Default! - return _load_from_hs(filename, a, **kwargs) - - -def _load_from_npy(filename, a, **kwargs): - _log.debug('Calling load_from_npy') - if a is None: - a = 1. # Use default! - return ScalarData(a, np.load(filename, **kwargs)) - - -def _load_from_hs(filename, a, **kwargs): - _log.debug('Calling load_from_hs') - try: - import hyperspy.api as hs - except ImportError: - _log.error('This method recquires the hyperspy package!') - return - scalardata = ScalarData.from_signal(hs.load(filename, **kwargs)) - if a is not None: - scalardata.a = a - return scalardata - - -def save_scalardata(scalardata, filename, **kwargs): - """%s""" - _log.debug('Calling save_scalardata') - extension = os.path.splitext(filename)[1] - if extension in ['.npy', '.npz']: # Save to npy-files: - _save_to_npy(scalardata, filename, **kwargs) - else: # Try HyperSpy: - _save_to_hs(scalardata, filename, **kwargs) -save_scalardata.__doc__ %= ScalarData.save.__doc__ - - -def _save_to_npy(scalardata, filename, **kwargs): - _log.debug('Calling save_to_npy') - np.save(filename, scalardata.field, **kwargs) - - -def _save_to_hs(scalardata, filename, **kwargs): - _log.debug('Calling save_to_hyperspy') - scalardata.to_signal().save(filename, **kwargs) +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""IO functionality for ScalarData objects.""" + +import logging + +import os + +import numpy as np + +from ..fielddata import ScalarData + +__all__ = ['load_scalardata'] +_log = logging.getLogger(__name__) + + +def load_scalardata(filename, a=None, **kwargs): + """Load supported file into a :class:`~pyramid.fielddata.ScalarData` instance. + + The function loads the file according to the extension: + - hdf5 for HDF5. + - EMD Electron Microscopy Dataset format (also HDF5). + - npy or npz for numpy formats. + + Any extra keyword is passed to the corresponsing reader. For + available options see their individual documentation. + + Parameters + ---------- + filename: str + The filename to be loaded. + a: float or None, optional + If the grid spacing is not None, it will override a possibly loaded value from the files. + + Returns + ------- + scalardata : :class:`~.ScalarData` + A :class:`~.ScalarData` object containing the loaded data. + + """ + _log.debug('Calling load_scalardata') + extension = os.path.splitext(filename)[1] + # Load from npy-files: + if extension in ['.npy', '.npz']: + return _load_from_npy(filename, a, **kwargs) + # Load with HyperSpy: + else: + if extension == '': + filename = '{}.hdf5'.format(filename) # Default! + return _load_from_hs(filename, a, **kwargs) + + +def _load_from_npy(filename, a, **kwargs): + _log.debug('Calling load_from_npy') + if a is None: + a = 1. # Use default! + return ScalarData(a, np.load(filename, **kwargs)) + + +def _load_from_hs(filename, a, **kwargs): + _log.debug('Calling load_from_hs') + try: + import hyperspy.api as hs + except ImportError: + _log.error('This method recquires the hyperspy package!') + return + scalardata = ScalarData.from_signal(hs.load(filename, **kwargs)) + if a is not None: + scalardata.a = a + return scalardata + + +def save_scalardata(scalardata, filename, **kwargs): + """%s""" + _log.debug('Calling save_scalardata') + extension = os.path.splitext(filename)[1] + if extension in ['.npy', '.npz']: # Save to npy-files: + _save_to_npy(scalardata, filename, **kwargs) + else: # Try HyperSpy: + _save_to_hs(scalardata, filename, **kwargs) +save_scalardata.__doc__ %= ScalarData.save.__doc__ + + +def _save_to_npy(scalardata, filename, **kwargs): + _log.debug('Calling save_to_npy') + np.save(filename, scalardata.field, **kwargs) + + +def _save_to_hs(scalardata, filename, **kwargs): + _log.debug('Calling save_to_hyperspy') + scalardata.to_signal().save(filename, **kwargs) diff --git a/pyramid/file_io/io_vectordata.py b/pyramid/file_io/io_vectordata.py index ad63ba50ca7bed3164d22b202575ced5211d361a..ca02f828116c823cee34f83ce2448776b13fc9d5 100644 --- a/pyramid/file_io/io_vectordata.py +++ b/pyramid/file_io/io_vectordata.py @@ -1,252 +1,252 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""IO functionality for VectorData objects.""" - -import logging - -import os - -import numpy as np - -from ..fielddata import VectorData - -__all__ = ['load_vectordata'] -_log = logging.getLogger(__name__) - - -def load_vectordata(filename, a=None, **kwargs): - """Load supported file into a :class:`~pyramid.fielddata.VectorData` instance. - - The function loads the file according to the extension: - - hdf5 for HDF5. - - EMD Electron Microscopy Dataset format (also HDF5). - - llg format. - - ovf format. - - npy or npz for numpy formats. - - Any extra keyword is passed to the corresponsing reader. For - available options see their individual documentation. - - Parameters - ---------- - filename: str - The filename to be loaded. - a: float or None, optional - If the grid spacing is not None, it will override a possibly loaded value from the files. - - Returns - ------- - vectordata : :class:`~.VectorData` - A :class:`~.VectorData` object containing the loaded data. - - """ - _log.debug('Calling load_vectordata') - extension = os.path.splitext(filename)[1] - # Load from llg-files: - if extension in ['.llg', '.txt']: - return _load_from_llg(filename, a) - # Load from ovf-files: - if extension in ['.ovf', '.omf', '.ohf', 'obf']: - return _load_from_ovf(filename, a) - # Load from npy-files: - elif extension in ['.npy', '.npz']: - return _load_from_npy(filename, a, **kwargs) - # Load with HyperSpy: - else: - if extension == '': - filename = '{}.hdf5'.format(filename) # Default! - return _load_from_hs(filename, a, **kwargs) - - -def _load_from_llg(filename, a): - _log.debug('Calling load_from_llg') - SCALE = 1.0E-9 / 1.0E-2 # From cm to nm - data = np.genfromtxt(filename, skip_header=2) - dim = tuple(np.genfromtxt(filename, dtype=int, skip_header=1, skip_footer=len(data[:, 0]))) - if a is None: - a = (data[1, 0] - data[0, 0]) / SCALE - field = data[:, 3:6].T.reshape((3,) + dim) - return VectorData(a, field) - - -def _load_from_ovf(filename, a): - _log.debug('Calling load_from_ovf') - with open(filename, 'rb') as mag_file: - assert mag_file.readline().startswith(b'# OOMMF') # Make sure file has .ovf-format! - read_header, read_data = False, False - header = {} - x_mag, y_mag, z_mag = [], [], [] - data_mode = '' - for line in mag_file: - # Read in additional info: - if not read_header and not read_data: - if line.startswith(b'# Segment count'): - assert int(line.split()[3]) == 1, 'Only one vector field can be read at time!' - elif line.startswith(b'# Begin: Header'): - read_header = True - elif line.startswith(b'# Begin: Data'): - read_data = True - data_mode = ' '.join(line.decode('utf-8').split()[3:]) - assert data_mode in ['text', 'Text', 'Binary 4', 'Binary 8'], \ - 'Data mode {} is currently not supported by this reader!'.format(data_mode) - assert header.get('meshtype') == 'rectangular', \ - 'Only rectangular grids can be currently read!' - # Read in header: - elif read_header: # Read header: - if line.startswith(b'# End: Header'): # Header is done: - read_header = False - continue - line = line.decode('utf-8') # Decode to use strings here! - line_list = line.split() - if '##' in line_list: # Strip trailing comments: - del line_list[line_list.index('##'):] - if len(line_list) <= 1: # Just '#' or empty line: - continue - key, value = line_list[1].strip(':'), ' '.join(line_list[2:]) - if key not in header: # Add new key, value pair: - header[key] = value - elif key == 'Desc': # Can go over several lines: - header['Desc'] = ' '.join([header['Desc'], value]) - # Read in data: - # TODO: Make it work for both text and binary! Put into HyperSpy? - # TODO: http://math.nist.gov/oommf/doc/userguide11b2/userguide/vectorfieldformat.html - elif read_data: # Currently in a data block: - if data_mode in ['text', 'Text']: # Read data as text: - if line.startswith(b'# End: Data'): # Header is done: - read_data = False - else: - x, y, z = [float(i) for i in line.split()] - x_mag.append(x) - y_mag.append(y) - z_mag.append(z) - elif 'Binary' in data_mode: - count = int(data_mode.split()[-1]) - dtype = '>f{}'.format(count) - dim = (int(header['znodes']), int(header['ynodes']), int(header['xnodes'])) - test = np.fromfile(mag_file, dtype='<f4', count=count*2+1) - if count == 4: # Binary 4: - assert test == '123456789.0', 'Wrong test bytes!' - if count == 8: # Binary 4: - assert test == '123456789012345.0', 'Wrong test bytes!' - data = np.fromfile(mag_file, dtype=dtype, count=3*np.prod(dim)) - data.reshape((3,) + dim) - x_mag, y_mag, z_mag = data - read_data = False # Stop reading data and search for new Segments (if any). - # Format after reading: - dim = (int(header['znodes']), int(header['ynodes']), int(header['xnodes'])) - x_mag = np.asarray(x_mag).reshape(dim) - y_mag = np.asarray(y_mag).reshape(dim) - z_mag = np.asarray(z_mag).reshape(dim) - field = np.asarray((x_mag, y_mag, z_mag)) * float(header.get('valuemultiplier', 1)) - if a is None: - assert header.get('xstepsize') == header.get('ystepsize') == header.get('zstepsize'), \ - 'Grid spacing is not equidistant!' - a = float(header.get('xstepsize', 1.)) - meshunit = header.get('meshunit', 'nm') - a *= {'m': 1e9, 'mm': 1e6, 'µm': 1e3, 'nm': 1}[meshunit] # Conversion to nm - return VectorData(a, field) - - -def _load_from_npy(filename, a, **kwargs): - _log.debug('Calling load_from_npy') - if a is None: - a = 1. # Use default! - return VectorData(a, np.load(filename, **kwargs)) - - -def _load_from_hs(filename, a, **kwargs): - _log.debug('Calling load_from_hs') - try: - import hyperspy.api as hs - except ImportError: - _log.error('This method recquires the hyperspy package!') - return - vectordata = VectorData.from_signal(hs.load(filename, **kwargs)) - if a is not None: - vectordata.a = a - return vectordata - - -def save_vectordata(vectordata, filename, **kwargs): - """%s""" - _log.debug('Calling save_vectordata') - extension = os.path.splitext(filename)[1] - if extension == '.llg': # Save to llg-files: - _save_to_llg(vectordata, filename) - elif extension == '.ovf': # Save to ovf-files: - _save_to_ovf(vectordata, filename, **kwargs) - elif extension in ['.npy', '.npz']: # Save to npy-files: - _save_to_npy(vectordata, filename, **kwargs) - else: # Try HyperSpy: - _save_to_hs(vectordata, filename, **kwargs) -save_vectordata.__doc__ %= VectorData.save.__doc__ - - -def _save_to_llg(vectordata, filename): - _log.debug('Calling save_to_llg') - dim = vectordata.dim - SCALE = 1.0E-9 / 1.0E-2 # from nm to cm - # Create 3D meshgrid and reshape it and the field into a list where x varies first: - zz, yy, xx = vectordata.a * SCALE * (np.indices(dim) + 0.5).reshape(3, -1) - x_vec, y_vec, z_vec = vectordata.field.reshape(3, -1) - data = np.array([xx, yy, zz, x_vec, y_vec, z_vec]).T - # Save data to file: - with open(filename, 'w') as mag_file: - mag_file.write('LLGFileCreator: %s\n'.format(filename)) - mag_file.write(' %d %d %d\n'.format(np.asarray(dim)[::-1])) - mag_file.writelines('\n'.join(' '.join('{:7.6e}'.format(cell) - for cell in row) for row in data)) - - -def _save_to_ovf(vectordata, filename): - _log.debug('Calling save_to_ovf') - with open(filename, 'w') as mag_file: - mag_file.write('# OOMMF OVF 2.0\n') - mag_file.write('# Segment count: 1\n') - mag_file.write('# Begin: Segment\n') - # Write Header: - mag_file.write('# Begin: Header\n') - name = os.path.split(os.path.split(filename)[1]) - mag_file.write('# Title: PYRAMID-VECTORDATA {}\n'.format(name)) - mag_file.write('# meshtype: rectangular\n') - mag_file.write('# meshunit: nm\n') - mag_file.write('# valueunit: A/m\n') - mag_file.write('# valuemultiplier: 1.\n') - mag_file.write('# xmin: 0.\n') - mag_file.write('# ymin: 0.\n') - mag_file.write('# zmin: 0.\n') - mag_file.write('# xmax: {}\n'.format(vectordata.a * vectordata.dim[2])) - mag_file.write('# ymax: {}\n'.format(vectordata.a * vectordata.dim[1])) - mag_file.write('# zmax: {}\n'.format(vectordata.a * vectordata.dim[0])) - mag_file.write('# xbase: 0.\n') - mag_file.write('# ybase: 0.\n') - mag_file.write('# zbase: 0.\n') - mag_file.write('# xstepsize: {}\n'.format(vectordata.a)) - mag_file.write('# ystepsize: {}\n'.format(vectordata.a)) - mag_file.write('# zstepsize: {}\n'.format(vectordata.a)) - mag_file.write('# xnodes: {}\n'.format(vectordata.dim[2])) - mag_file.write('# ynodes: {}\n'.format(vectordata.dim[1])) - mag_file.write('# znodes: {}\n'.format(vectordata.dim[0])) - mag_file.write('# End: Header\n') - # Write data: - mag_file.write('# Begin: Data Text\n') - x_mag, y_mag, z_mag = vectordata.field - x_mag = x_mag.ravel() - y_mag = y_mag.ravel() - z_mag = z_mag.ravel() - for i in range(np.prod(vectordata.dim)): - mag_file.write('{:g} {:g} {:g}\n'.format(x_mag[i], y_mag[i], z_mag[i])) - mag_file.write('# End: Data Text\n') - mag_file.write('# End: Segment\n') - - -def _save_to_npy(vectordata, filename, **kwargs): - _log.debug('Calling save_to_npy') - np.save(filename, vectordata.field, **kwargs) - - -def _save_to_hs(vectordata, filename, **kwargs): - _log.debug('Calling save_to_hyperspy') - vectordata.to_signal().save(filename, **kwargs) +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""IO functionality for VectorData objects.""" + +import logging + +import os + +import numpy as np + +from ..fielddata import VectorData + +__all__ = ['load_vectordata'] +_log = logging.getLogger(__name__) + + +def load_vectordata(filename, a=None, **kwargs): + """Load supported file into a :class:`~pyramid.fielddata.VectorData` instance. + + The function loads the file according to the extension: + - hdf5 for HDF5. + - EMD Electron Microscopy Dataset format (also HDF5). + - llg format. + - ovf format. + - npy or npz for numpy formats. + + Any extra keyword is passed to the corresponsing reader. For + available options see their individual documentation. + + Parameters + ---------- + filename: str + The filename to be loaded. + a: float or None, optional + If the grid spacing is not None, it will override a possibly loaded value from the files. + + Returns + ------- + vectordata : :class:`~.VectorData` + A :class:`~.VectorData` object containing the loaded data. + + """ + _log.debug('Calling load_vectordata') + extension = os.path.splitext(filename)[1] + # Load from llg-files: + if extension in ['.llg', '.txt']: + return _load_from_llg(filename, a) + # Load from ovf-files: + if extension in ['.ovf', '.omf', '.ohf', 'obf']: + return _load_from_ovf(filename, a) + # Load from npy-files: + elif extension in ['.npy', '.npz']: + return _load_from_npy(filename, a, **kwargs) + # Load with HyperSpy: + else: + if extension == '': + filename = '{}.hdf5'.format(filename) # Default! + return _load_from_hs(filename, a, **kwargs) + + +def _load_from_llg(filename, a): + _log.debug('Calling load_from_llg') + SCALE = 1.0E-9 / 1.0E-2 # From cm to nm + data = np.genfromtxt(filename, skip_header=2) + dim = tuple(np.genfromtxt(filename, dtype=int, skip_header=1, skip_footer=len(data[:, 0]))) + if a is None: + a = (data[1, 0] - data[0, 0]) / SCALE + field = data[:, 3:6].T.reshape((3,) + dim) + return VectorData(a, field) + + +def _load_from_ovf(filename, a): + _log.debug('Calling load_from_ovf') + with open(filename, 'rb') as mag_file: + assert mag_file.readline().startswith(b'# OOMMF') # Make sure file has .ovf-format! + read_header, read_data = False, False + header = {} + x_mag, y_mag, z_mag = [], [], [] + data_mode = '' + for line in mag_file: + # Read in additional info: + if not read_header and not read_data: + if line.startswith(b'# Segment count'): + assert int(line.split()[3]) == 1, 'Only one vector field can be read at time!' + elif line.startswith(b'# Begin: Header'): + read_header = True + elif line.startswith(b'# Begin: Data'): + read_data = True + data_mode = ' '.join(line.decode('utf-8').split()[3:]) + assert data_mode in ['text', 'Text', 'Binary 4', 'Binary 8'], \ + 'Data mode {} is currently not supported by this reader!'.format(data_mode) + assert header.get('meshtype') == 'rectangular', \ + 'Only rectangular grids can be currently read!' + # Read in header: + elif read_header: # Read header: + if line.startswith(b'# End: Header'): # Header is done: + read_header = False + continue + line = line.decode('utf-8') # Decode to use strings here! + line_list = line.split() + if '##' in line_list: # Strip trailing comments: + del line_list[line_list.index('##'):] + if len(line_list) <= 1: # Just '#' or empty line: + continue + key, value = line_list[1].strip(':'), ' '.join(line_list[2:]) + if key not in header: # Add new key, value pair: + header[key] = value + elif key == 'Desc': # Can go over several lines: + header['Desc'] = ' '.join([header['Desc'], value]) + # Read in data: + # TODO: Make it work for both text and binary! Put into HyperSpy? + # TODO: http://math.nist.gov/oommf/doc/userguide11b2/userguide/vectorfieldformat.html + elif read_data: # Currently in a data block: + if data_mode in ['text', 'Text']: # Read data as text: + if line.startswith(b'# End: Data'): # Header is done: + read_data = False + else: + x, y, z = [float(i) for i in line.split()] + x_mag.append(x) + y_mag.append(y) + z_mag.append(z) + elif 'Binary' in data_mode: + count = int(data_mode.split()[-1]) + dtype = '>f{}'.format(count) + dim = (int(header['znodes']), int(header['ynodes']), int(header['xnodes'])) + test = np.fromfile(mag_file, dtype='<f4', count=count*2+1) + if count == 4: # Binary 4: + assert test == '123456789.0', 'Wrong test bytes!' + if count == 8: # Binary 4: + assert test == '123456789012345.0', 'Wrong test bytes!' + data = np.fromfile(mag_file, dtype=dtype, count=3*np.prod(dim)) + data.reshape((3,) + dim) + x_mag, y_mag, z_mag = data + read_data = False # Stop reading data and search for new Segments (if any). + # Format after reading: + dim = (int(header['znodes']), int(header['ynodes']), int(header['xnodes'])) + x_mag = np.asarray(x_mag).reshape(dim) + y_mag = np.asarray(y_mag).reshape(dim) + z_mag = np.asarray(z_mag).reshape(dim) + field = np.asarray((x_mag, y_mag, z_mag)) * float(header.get('valuemultiplier', 1)) + if a is None: + assert header.get('xstepsize') == header.get('ystepsize') == header.get('zstepsize'), \ + 'Grid spacing is not equidistant!' + a = float(header.get('xstepsize', 1.)) + meshunit = header.get('meshunit', 'nm') + a *= {'m': 1e9, 'mm': 1e6, 'µm': 1e3, 'nm': 1}[meshunit] # Conversion to nm + return VectorData(a, field) + + +def _load_from_npy(filename, a, **kwargs): + _log.debug('Calling load_from_npy') + if a is None: + a = 1. # Use default! + return VectorData(a, np.load(filename, **kwargs)) + + +def _load_from_hs(filename, a, **kwargs): + _log.debug('Calling load_from_hs') + try: + import hyperspy.api as hs + except ImportError: + _log.error('This method recquires the hyperspy package!') + return + vectordata = VectorData.from_signal(hs.load(filename, **kwargs)) + if a is not None: + vectordata.a = a + return vectordata + + +def save_vectordata(vectordata, filename, **kwargs): + """%s""" + _log.debug('Calling save_vectordata') + extension = os.path.splitext(filename)[1] + if extension == '.llg': # Save to llg-files: + _save_to_llg(vectordata, filename) + elif extension == '.ovf': # Save to ovf-files: + _save_to_ovf(vectordata, filename, **kwargs) + elif extension in ['.npy', '.npz']: # Save to npy-files: + _save_to_npy(vectordata, filename, **kwargs) + else: # Try HyperSpy: + _save_to_hs(vectordata, filename, **kwargs) +save_vectordata.__doc__ %= VectorData.save.__doc__ + + +def _save_to_llg(vectordata, filename): + _log.debug('Calling save_to_llg') + dim = vectordata.dim + SCALE = 1.0E-9 / 1.0E-2 # from nm to cm + # Create 3D meshgrid and reshape it and the field into a list where x varies first: + zz, yy, xx = vectordata.a * SCALE * (np.indices(dim) + 0.5).reshape(3, -1) + x_vec, y_vec, z_vec = vectordata.field.reshape(3, -1) + data = np.array([xx, yy, zz, x_vec, y_vec, z_vec]).T + # Save data to file: + with open(filename, 'w') as mag_file: + mag_file.write('LLGFileCreator: %s\n'.format(filename)) + mag_file.write(' %d %d %d\n'.format(np.asarray(dim)[::-1])) + mag_file.writelines('\n'.join(' '.join('{:7.6e}'.format(cell) + for cell in row) for row in data)) + + +def _save_to_ovf(vectordata, filename): + _log.debug('Calling save_to_ovf') + with open(filename, 'w') as mag_file: + mag_file.write('# OOMMF OVF 2.0\n') + mag_file.write('# Segment count: 1\n') + mag_file.write('# Begin: Segment\n') + # Write Header: + mag_file.write('# Begin: Header\n') + name = os.path.split(os.path.split(filename)[1]) + mag_file.write('# Title: PYRAMID-VECTORDATA {}\n'.format(name)) + mag_file.write('# meshtype: rectangular\n') + mag_file.write('# meshunit: nm\n') + mag_file.write('# valueunit: A/m\n') + mag_file.write('# valuemultiplier: 1.\n') + mag_file.write('# xmin: 0.\n') + mag_file.write('# ymin: 0.\n') + mag_file.write('# zmin: 0.\n') + mag_file.write('# xmax: {}\n'.format(vectordata.a * vectordata.dim[2])) + mag_file.write('# ymax: {}\n'.format(vectordata.a * vectordata.dim[1])) + mag_file.write('# zmax: {}\n'.format(vectordata.a * vectordata.dim[0])) + mag_file.write('# xbase: 0.\n') + mag_file.write('# ybase: 0.\n') + mag_file.write('# zbase: 0.\n') + mag_file.write('# xstepsize: {}\n'.format(vectordata.a)) + mag_file.write('# ystepsize: {}\n'.format(vectordata.a)) + mag_file.write('# zstepsize: {}\n'.format(vectordata.a)) + mag_file.write('# xnodes: {}\n'.format(vectordata.dim[2])) + mag_file.write('# ynodes: {}\n'.format(vectordata.dim[1])) + mag_file.write('# znodes: {}\n'.format(vectordata.dim[0])) + mag_file.write('# End: Header\n') + # Write data: + mag_file.write('# Begin: Data Text\n') + x_mag, y_mag, z_mag = vectordata.field + x_mag = x_mag.ravel() + y_mag = y_mag.ravel() + z_mag = z_mag.ravel() + for i in range(np.prod(vectordata.dim)): + mag_file.write('{:g} {:g} {:g}\n'.format(x_mag[i], y_mag[i], z_mag[i])) + mag_file.write('# End: Data Text\n') + mag_file.write('# End: Segment\n') + + +def _save_to_npy(vectordata, filename, **kwargs): + _log.debug('Calling save_to_npy') + np.save(filename, vectordata.field, **kwargs) + + +def _save_to_hs(vectordata, filename, **kwargs): + _log.debug('Calling save_to_hyperspy') + vectordata.to_signal().save(filename, **kwargs) diff --git a/pyramid/kernel.py b/pyramid/kernel.py index a503a3f4bdd6861bd8e4ecab279ba4eab1754ab8..8567ad849dbcb5922e6cac123042053b24032080 100644 --- a/pyramid/kernel.py +++ b/pyramid/kernel.py @@ -1,160 +1,160 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""This module provides the :class:`~.Kernel` class, representing the phase contribution of one -single magnetized pixel.""" - -import logging - -import numpy as np - -from jutil import fft - -__all__ = ['Kernel', 'PHI_0'] - -PHI_0 = 2067.83 # magnetic flux in T*nm² - - -class Kernel(object): - """Class for calculating kernel matrices for the phase calculation. - - Represents the phase of a single magnetized pixel for two orthogonal directions (`u` and `v`), - which can be accessed via the corresponding attributes. The default elementary geometry is - `disc`, but can also be specified as the phase of a `slab` representation of a single - magnetized pixel. During the construction, a few attributes are calculated that are used in - the convolution during phase calculation in the different :class:`~Phasemapper` classes. - An instance of the :class:`~.Kernel` class can be called as a function with a `vector`, - which represents the projected magnetization onto a 2-dimensional grid. - - Attributes - ---------- - a : float - The grid spacing in nm. - dim_uv : tuple of int (N=2), optional - Dimensions of the 2-dimensional projected magnetization grid from which the phase should - be calculated. - dim_kern : tuple of int (N=2) - Dimensions of the kernel, which is ``2N-1`` for both axes compared to `dim_uv`. - dim_pad : tuple of int (N=2) - Dimensions of the padded FOV, which is ``2N`` (if FFTW is used) or the next highest power - of 2 (for numpy-FFT). - dim_fft : tuple of int (N=2) - Dimensions of the grid, which is used for the FFT, taking into account that a RFFT should - be used (one axis is halved in comparison to `dim_pad`). - b_0 : float, optional - Saturation magnetization in Tesla, which is used for the phase calculation. Default is 1. - geometry : {'disc', 'slab'}, optional - The elementary geometry of the single magnetized pixel. - u : :class:`~numpy.ndarray` (N=3) - The phase contribution of one pixel magnetized in u-direction. - v : :class:`~numpy.ndarray` (N=3) - The phase contribution of one pixel magnetized in v-direction. - u_fft : :class:`~numpy.ndarray` (N=3) - The real FFT of the phase contribution of one pixel magnetized in u-direction. - v_fft : :class:`~numpy.ndarray` (N=3) - The real FFT of the phase contribution of one pixel magnetized in v-direction. - slice_phase : tuple (N=2) of :class:`slice` - A tuple of :class:`slice` objects to extract the original FOV from the increased one with - size `dim_pad` for the elementary kernel phase. The kernel is shifted, thus the center is - not at (0, 0), which also shifts the slicing compared to `slice_mag`. - slice_mag : tuple (N=2) of :class:`slice` - A tuple of :class:`slice` objects to extract the original FOV from the increased one with - size `dim_pad` for the projected magnetization distribution. - prw_vec: tuple of 2 int, optional - A two-component vector describing the displacement of the reference wave to include - perturbation of this reference by the object itself (via fringing fields), (y, x). - dtype: numpy dtype, optional - Data type of the kernel. Default is np.float32. - - """ - - _log = logging.getLogger(__name__ + '.Kernel') - - def __init__(self, a, dim_uv, b_0=1., prw_vec=None, geometry='disc', dtype=np.float32): - self._log.debug('Calling __init__') - # Set basic properties: - self.b_0 = b_0 - self.prw_vec = prw_vec - self.dim_uv = dim_uv # Dimensions of the FOV - self.dim_kern = tuple(2 * np.array(dim_uv) - 1) # Dimensions of the kernel - self.a = a - self.geometry = geometry - # Set up FFT: - if fft.HAVE_FFTW: - self.dim_pad = tuple(2 * np.array(dim_uv)) # is at least even (not nec. power of 2) - else: - self.dim_pad = tuple(2 ** np.ceil(np.log2(2 * np.array(dim_uv))).astype(int)) # pow(2) - self.dim_fft = (self.dim_pad[0], self.dim_pad[1] // 2 + 1) # last axis is real - self.slice_phase = (slice(dim_uv[0] - 1, self.dim_kern[0]), # Shift because kernel center - slice(dim_uv[1] - 1, self.dim_kern[1])) # is not at (0, 0)! - self.slice_mag = (slice(0, dim_uv[0]), # Magnetization is padded on the far end! - slice(0, dim_uv[1])) # (Phase cutout is shifted as listed above) - # Calculate kernel (single pixel phase): - # [M_0] = A/m --> This is the magnetization, not the magnetic moment (A/m * m³ = Am²)! - # [PHI_0 / µ_0] = Tm² / Tm/A = Am - # [b_0] = [M_0] * [µ_0] = A/m * N/A² = N/Am = T - # [coeff] = [b_0 * a² / (2*PHI_0)] = T * m² / Tm² = 1 --> without unit (phase)! - coeff = b_0 * a ** 2 / (2 * PHI_0) # Minus is gone because of negative z-direction - v_dim, u_dim = dim_uv - u = np.linspace(-(u_dim - 1), u_dim - 1, num=2 * u_dim - 1) - v = np.linspace(-(v_dim - 1), v_dim - 1, num=2 * v_dim - 1) - uu, vv = np.meshgrid(u, v) - self.u = np.empty(self.dim_kern, dtype=dtype) - self.v = np.empty(self.dim_kern, dtype=dtype) - self.u[...] = coeff * self._get_elementary_phase(geometry, uu, vv, a) - self.v[...] = coeff * -self._get_elementary_phase(geometry, vv, uu, a) - # Include perturbed reference wave: - if prw_vec is not None: - uu += prw_vec[1] - vv += prw_vec[0] - self.u[...] -= coeff * self._get_elementary_phase(geometry, uu, vv, a) - self.v[...] -= coeff * -self._get_elementary_phase(geometry, vv, uu, a) - # Calculate Fourier trafo of kernel components: - self.u_fft = fft.rfftn(self.u, self.dim_pad) - self.v_fft = fft.rfftn(self.v, self.dim_pad) - self._log.debug('Created ' + str(self)) - - def __repr__(self): - self._log.debug('Calling __repr__') - return '%s(a=%r, dim_uv=%r, b_0=%r, prw_vec=%r, geometry=%r)' % \ - (self.__class__, self.a, self.dim_uv, self.b_0, self.prw_vec, self.geometry) - - def __str__(self): - self._log.debug('Calling __str__') - return 'Kernel(a=%s, dim_uv=%s, b_0=%s, prw_vec=%s, geometry=%s)' % \ - (self.a, self.dim_uv, self.b_0, self.prw_vec, self.geometry) - - def _get_elementary_phase(self, geometry, n, m, a): - self._log.debug('Calling _get_elementary_phase') - if geometry == 'disc': - in_or_out = ~ np.logical_and(n == 0, m == 0) - return m / (n ** 2 + m ** 2 + 1E-30) * in_or_out - elif geometry == 'slab': - def _F_a(n, m): - A = np.log(a ** 2 * (n ** 2 + m ** 2)) - B = np.arctan(n / m) - return n * A - 2 * n + 2 * m * B - - return 0.5 * (_F_a(n - 0.5, m - 0.5) - _F_a(n + 0.5, m - 0.5) - - _F_a(n - 0.5, m + 0.5) + _F_a(n + 0.5, m + 0.5)) - - def print_info(self): - """Print information about the kernel. - - Returns - ------- - None - - """ - self._log.debug('Calling log_info') - print('Shape of the FOV :', self.dim_uv) - print('Shape of the Kernel :', self.dim_kern) - print('Zero-padded shape :', self.dim_pad) - print('Shape of the FFT :', self.dim_fft) - print('Slice for the phase :', self.slice_phase) - print('Slice for the magn. :', self.slice_mag) - print('Saturation Induction:', self.b_0) - print('Grid spacing : {} nm'.format(self.a)) - print('Geometry :', self.geometry) - print('PRW vector : {} T'.format(self.prw_vec)) +# -*- coding: utf-8 -*- +# Copyright 2014 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides the :class:`~.Kernel` class, representing the phase contribution of one +single magnetized pixel.""" + +import logging + +import numpy as np + +from jutil import fft + +__all__ = ['Kernel', 'PHI_0'] + +PHI_0 = 2067.83 # magnetic flux in T*nm² + + +class Kernel(object): + """Class for calculating kernel matrices for the phase calculation. + + Represents the phase of a single magnetized pixel for two orthogonal directions (`u` and `v`), + which can be accessed via the corresponding attributes. The default elementary geometry is + `disc`, but can also be specified as the phase of a `slab` representation of a single + magnetized pixel. During the construction, a few attributes are calculated that are used in + the convolution during phase calculation in the different :class:`~Phasemapper` classes. + An instance of the :class:`~.Kernel` class can be called as a function with a `vector`, + which represents the projected magnetization onto a 2-dimensional grid. + + Attributes + ---------- + a : float + The grid spacing in nm. + dim_uv : tuple of int (N=2), optional + Dimensions of the 2-dimensional projected magnetization grid from which the phase should + be calculated. + dim_kern : tuple of int (N=2) + Dimensions of the kernel, which is ``2N-1`` for both axes compared to `dim_uv`. + dim_pad : tuple of int (N=2) + Dimensions of the padded FOV, which is ``2N`` (if FFTW is used) or the next highest power + of 2 (for numpy-FFT). + dim_fft : tuple of int (N=2) + Dimensions of the grid, which is used for the FFT, taking into account that a RFFT should + be used (one axis is halved in comparison to `dim_pad`). + b_0 : float, optional + Saturation magnetization in Tesla, which is used for the phase calculation. Default is 1. + geometry : {'disc', 'slab'}, optional + The elementary geometry of the single magnetized pixel. + u : :class:`~numpy.ndarray` (N=3) + The phase contribution of one pixel magnetized in u-direction. + v : :class:`~numpy.ndarray` (N=3) + The phase contribution of one pixel magnetized in v-direction. + u_fft : :class:`~numpy.ndarray` (N=3) + The real FFT of the phase contribution of one pixel magnetized in u-direction. + v_fft : :class:`~numpy.ndarray` (N=3) + The real FFT of the phase contribution of one pixel magnetized in v-direction. + slice_phase : tuple (N=2) of :class:`slice` + A tuple of :class:`slice` objects to extract the original FOV from the increased one with + size `dim_pad` for the elementary kernel phase. The kernel is shifted, thus the center is + not at (0, 0), which also shifts the slicing compared to `slice_mag`. + slice_mag : tuple (N=2) of :class:`slice` + A tuple of :class:`slice` objects to extract the original FOV from the increased one with + size `dim_pad` for the projected magnetization distribution. + prw_vec: tuple of 2 int, optional + A two-component vector describing the displacement of the reference wave to include + perturbation of this reference by the object itself (via fringing fields), (y, x). + dtype: numpy dtype, optional + Data type of the kernel. Default is np.float32. + + """ + + _log = logging.getLogger(__name__ + '.Kernel') + + def __init__(self, a, dim_uv, b_0=1., prw_vec=None, geometry='disc', dtype=np.float32): + self._log.debug('Calling __init__') + # Set basic properties: + self.b_0 = b_0 + self.prw_vec = prw_vec + self.dim_uv = dim_uv # Dimensions of the FOV + self.dim_kern = tuple(2 * np.array(dim_uv) - 1) # Dimensions of the kernel + self.a = a + self.geometry = geometry + # Set up FFT: + if fft.HAVE_FFTW: + self.dim_pad = tuple(2 * np.array(dim_uv)) # is at least even (not nec. power of 2) + else: + self.dim_pad = tuple(2 ** np.ceil(np.log2(2 * np.array(dim_uv))).astype(int)) # pow(2) + self.dim_fft = (self.dim_pad[0], self.dim_pad[1] // 2 + 1) # last axis is real + self.slice_phase = (slice(dim_uv[0] - 1, self.dim_kern[0]), # Shift because kernel center + slice(dim_uv[1] - 1, self.dim_kern[1])) # is not at (0, 0)! + self.slice_mag = (slice(0, dim_uv[0]), # Magnetization is padded on the far end! + slice(0, dim_uv[1])) # (Phase cutout is shifted as listed above) + # Calculate kernel (single pixel phase): + # [M_0] = A/m --> This is the magnetization, not the magnetic moment (A/m * m³ = Am²)! + # [PHI_0 / µ_0] = Tm² / Tm/A = Am + # [b_0] = [M_0] * [µ_0] = A/m * N/A² = N/Am = T + # [coeff] = [b_0 * a² / (2*PHI_0)] = T * m² / Tm² = 1 --> without unit (phase)! + coeff = b_0 * a ** 2 / (2 * PHI_0) # Minus is gone because of negative z-direction + v_dim, u_dim = dim_uv + u = np.linspace(-(u_dim - 1), u_dim - 1, num=2 * u_dim - 1) + v = np.linspace(-(v_dim - 1), v_dim - 1, num=2 * v_dim - 1) + uu, vv = np.meshgrid(u, v) + self.u = np.empty(self.dim_kern, dtype=dtype) + self.v = np.empty(self.dim_kern, dtype=dtype) + self.u[...] = coeff * self._get_elementary_phase(geometry, uu, vv, a) + self.v[...] = coeff * -self._get_elementary_phase(geometry, vv, uu, a) + # Include perturbed reference wave: + if prw_vec is not None: + uu += prw_vec[1] + vv += prw_vec[0] + self.u[...] -= coeff * self._get_elementary_phase(geometry, uu, vv, a) + self.v[...] -= coeff * -self._get_elementary_phase(geometry, vv, uu, a) + # Calculate Fourier trafo of kernel components: + self.u_fft = fft.rfftn(self.u, self.dim_pad) + self.v_fft = fft.rfftn(self.v, self.dim_pad) + self._log.debug('Created ' + str(self)) + + def __repr__(self): + self._log.debug('Calling __repr__') + return '%s(a=%r, dim_uv=%r, b_0=%r, prw_vec=%r, geometry=%r)' % \ + (self.__class__, self.a, self.dim_uv, self.b_0, self.prw_vec, self.geometry) + + def __str__(self): + self._log.debug('Calling __str__') + return 'Kernel(a=%s, dim_uv=%s, b_0=%s, prw_vec=%s, geometry=%s)' % \ + (self.a, self.dim_uv, self.b_0, self.prw_vec, self.geometry) + + def _get_elementary_phase(self, geometry, n, m, a): + self._log.debug('Calling _get_elementary_phase') + if geometry == 'disc': + in_or_out = ~ np.logical_and(n == 0, m == 0) + return m / (n ** 2 + m ** 2 + 1E-30) * in_or_out + elif geometry == 'slab': + def _F_a(n, m): + A = np.log(a ** 2 * (n ** 2 + m ** 2)) + B = np.arctan(n / m) + return n * A - 2 * n + 2 * m * B + + return 0.5 * (_F_a(n - 0.5, m - 0.5) - _F_a(n + 0.5, m - 0.5) - + _F_a(n - 0.5, m + 0.5) + _F_a(n + 0.5, m + 0.5)) + + def print_info(self): + """Print information about the kernel. + + Returns + ------- + None + + """ + self._log.debug('Calling log_info') + print('Shape of the FOV :', self.dim_uv) + print('Shape of the Kernel :', self.dim_kern) + print('Zero-padded shape :', self.dim_pad) + print('Shape of the FFT :', self.dim_fft) + print('Slice for the phase :', self.slice_phase) + print('Slice for the magn. :', self.slice_mag) + print('Saturation Induction:', self.b_0) + print('Grid spacing : {} nm'.format(self.a)) + print('Geometry :', self.geometry) + print('PRW vector : {} T'.format(self.prw_vec)) diff --git a/pyramid/magcreator/__init__.py b/pyramid/magcreator/__init__.py index b9f3c3cc9c7fa605721132d7ec8cb85f119d439d..a5d686e93c4fe46e3a1170675f1caba32bef9aea 100644 --- a/pyramid/magcreator/__init__.py +++ b/pyramid/magcreator/__init__.py @@ -1,12 +1,12 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Subpackage containing functionality for creating magnetic distributions.""" - -from . import shapes -from . import examples -from .magcreator import * - -__all__ = ['shapes', 'examples'] -__all__.extend(magcreator.__all__) +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Subpackage containing functionality for creating magnetic distributions.""" + +from . import shapes +from . import examples +from .magcreator import * + +__all__ = ['shapes', 'examples'] +__all__.extend(magcreator.__all__) diff --git a/pyramid/magcreator/examples.py b/pyramid/magcreator/examples.py index eded78d764c5cbb92aef2445bc4a0917e3fcadc2..68cf9517775049f3342985b5dea85759a1c050c7 100644 --- a/pyramid/magcreator/examples.py +++ b/pyramid/magcreator/examples.py @@ -1,305 +1,305 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Provide simple examples for magnetic distributions.""" - -import logging - -import numpy as np - -import random as rnd - -from . import magcreator as mc -from . import shapes -from ..fielddata import VectorData - - -__all__ = ['pyramid_logo', 'singularity', 'homog_pixel', 'homog_slab', 'homog_disc', - 'homog_sphere', 'homog_filament', 'homog_alternating_filament', - 'homog_array_sphere_disc_slab', 'homog_random_pixels', 'homog_random_slabs', - 'vortex_slab', 'vortex_disc', 'vortex_alternating_discs', 'vortex_sphere', - 'vortex_horseshoe', 'smooth_vortex_disc', 'source_disc', - 'core_shell_disc', 'core_shell_sphere'] -_log = logging.getLogger(__name__) - - -def pyramid_logo(a=1., dim=(1, 256, 256), phi=np.pi / 2, theta=np.pi / 2): - """Create pyramid logo.""" - _log.debug('Calling pyramid_logo') - mag_shape = np.zeros(dim) - x = range(dim[2]) - y = range(dim[1]) - xx, yy = np.meshgrid(x, y) - bottom = (yy >= 0.25 * dim[1]) - left = (yy <= 0.75 / 0.5 * dim[1] / dim[2] * xx) - right = np.fliplr(left) - mag_shape[0, ...] = np.logical_and(np.logical_and(left, right), bottom) - return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) - - -def singularity(a=1., dim=(8, 8, 8), center=None): - """Create magnetic singularity.""" - _log.debug('Calling singularity') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - x = np.linspace(-center[2], dim[2] - 1 - center[2], dim[2]) + 0.5 # pixel center! - y = np.linspace(-center[1], dim[1] - 1 - center[1], dim[1]) + 0.5 # pixel center! - z = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! - yy, zz, xx = np.meshgrid(x, y, z) # What's up with this strange order??? - magnitude = np.array((xx, yy, zz)).astype(float) - magnitude /= np.sqrt((magnitude ** 2 + 1E-30).sum(axis=0)) # Normalise! - return VectorData(a, magnitude) - - -def homog_pixel(a=1., dim=(1, 9, 9), pixel=None, phi=np.pi/4, theta=np.pi/2): - """Create single magnetised slab.""" - _log.debug('Calling homog_pixel') - if pixel is None: - pixel = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - mag_shape = shapes.pixel(dim, pixel) - return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) - - -def homog_slab(a=1., dim=(32, 32, 32), center=None, width=None, phi=np.pi/4, theta=np.pi/4): - """Create homogeneous slab magnetisation distribution.""" - _log.debug('Calling homog_slab') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if width is None: - width = (np.max((dim[0] // 8, 1)), np.max((dim[1] // 2, 1)), np.max((dim[2] // 4, 1))) - mag_shape = shapes.slab(dim, center, width) - return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) - - -def homog_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, - phi=np.pi / 4, theta=np.pi / 4): - """Create homogeneous disc magnetisation distribution.""" - _log.debug('Calling homog_disc') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if radius is None: - radius = dim[2] // 4 - if height is None: - height = np.max((dim[0] // 2, 1)) - mag_shape = shapes.disc(dim, center, radius, height) - return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) - - -def homog_sphere(a=1., dim=(32, 32, 32), center=None, radius=None, phi=np.pi/4, theta=np.pi/4): - """Create homogeneous sphere magnetisation distribution.""" - _log.debug('Calling homog_sphere') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if radius is None: - radius = dim[2] // 4 - mag_shape = shapes.sphere(dim, center, radius) - return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) - - -def homog_filament(a=1., dim=(1, 21, 21), pos=None, phi=np.pi / 2, theta=np.pi/2): - """Create magnetisation distribution of a single magnetised filaments.""" - _log.debug('Calling homog_filament') - if pos is None: - pos = (0, dim[1] // 2) - mag_shape = shapes.filament(dim, pos) - return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) - - -def homog_alternating_filament(a=1., dim=(1, 21, 21), spacing=5, phi=np.pi/2, theta=np.pi/2): - """Create magnetisation distribution of alternating filaments.""" - _log.debug('Calling homog_alternating_filament') - count = int((dim[1] - 1) / spacing) + 1 - magdata = VectorData(a, np.zeros((3,) + dim)) - for i in range(count): - pos = i * spacing - mag_shape = shapes.filament(dim, (0, pos)) - magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) - phi *= -1 # Switch the angle - return magdata - - -def homog_array_sphere_disc_slab(a=1., dim=(64, 128, 128), center_sp=(32, 96, 64), radius_sp=24, - center_di=(32, 32, 96), radius_di=24, height_di=24, - center_sl=(32, 32, 32), width_sl=(48, 48, 48)): - """Create array of several magnetisation distribution (sphere, disc and slab).""" - _log.debug('Calling homog_array_sphere_disc_slab') - mag_shape_sphere = shapes.sphere(dim, center_sp, radius_sp) - mag_shape_disc = shapes.disc(dim, center_di, radius_di, height_di) - mag_shape_slab = shapes.slab(dim, center_sl, width_sl) - magdata = VectorData(a, mc.create_mag_dist_homog(mag_shape_sphere, np.pi)) - magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape_disc, np.pi / 2)) - magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape_slab, np.pi / 4)) - return magdata - - -def homog_random_pixels(a=1., dim=(1, 64, 64), count=10, rnd_seed=24): - """Create random magnetised pixels.""" - _log.debug('Calling homog_random_pixels') - rnd.seed(rnd_seed) - magdata = VectorData(a, np.zeros((3,) + dim)) - for i in range(count): - pixel = (rnd.randrange(dim[0]), rnd.randrange(dim[1]), rnd.randrange(dim[2])) - mag_shape = shapes.pixel(dim, pixel) - phi = 2 * np.pi * rnd.random() - magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape, phi)) - return magdata - - -def homog_random_slabs(a=1., dim=(1, 64, 64), count=10, width_max=5, rnd_seed=2): - """Create random magnetised slabs.""" - _log.debug('Create homog_random_slabs') - rnd.seed(rnd_seed) - magdata = VectorData(a, np.zeros((3,) + dim)) - for i in range(count): - width = (1, rnd.randint(1, width_max), rnd.randint(1, width_max)) - center = (rnd.randrange(int(width[0] / 2), dim[0] - int(width[0] / 2)), - rnd.randrange(int(width[1] / 2), dim[1] - int(width[1] / 2)), - rnd.randrange(int(width[2] / 2), dim[2] - int(width[2] / 2))) - mag_shape = shapes.slab(dim, center, width) - phi = 2 * np.pi * rnd.random() - magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape, phi)) - return magdata - - -def vortex_slab(a=1., dim=(32, 32, 32), center=None, width=None, axis='z'): - """Create vortex slab magnetisation distribution.""" - _log.debug('Calling vortex_slab') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if width is None: - width = (np.max((dim[0] // 2, 1)), np.max((dim[1] // 2, 1)), np.max((dim[2] // 2, 1))) - mag_shape = shapes.slab(dim, center, width) - magnitude = mc.create_mag_dist_vortex(mag_shape, center, axis) - return VectorData(a, magnitude) - - -def vortex_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, axis='z'): - """Create vortex disc magnetisation distribution.""" - _log.debug('Calling vortex_disc') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if radius is None: - radius = dim[2] // 4 - if height is None: - height = np.max((dim[0] // 2, 1)) - mag_shape = shapes.disc(dim, center, radius, height, axis) - magnitude = mc.create_mag_dist_vortex(mag_shape, center, axis) - return VectorData(a, magnitude) - - -def vortex_alternating_discs(a=1., dim=(80, 32, 32), count=8): - """Create pillar of alternating vortex disc magnetisation distributions.""" - _log.debug('Calling vortex_alternating_discs') - segment_height = dim[0] // (count + 2) - magdata = VectorData(a, np.zeros((3,) + dim)) - for i in range(count): - axis = 'z' if i % 2 == 0 else '-z' - center = (segment_height * (i + 1 + 0.5), dim[1] // 2, dim[2] // 2) - radius = dim[2] // 4 - height = segment_height - mag_shape = shapes.disc(dim, center=center, radius=radius, height=height) - mag_amp = mc.create_mag_dist_vortex(mag_shape=mag_shape, center=center, axis=axis) - magdata += VectorData(1, mag_amp) - return magdata - - -def vortex_sphere(a=1., dim=(32, 32, 32), center=None, radius=None, axis='z'): - """Create vortex sphere magnetisation distribution.""" - _log.debug('Calling vortex_sphere') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if radius is None: - radius = dim[2] // 4 - mag_shape = shapes.sphere(dim, center, radius) - magnitude = mc.create_mag_dist_vortex(mag_shape, center, axis) - return VectorData(a, magnitude) - - -def vortex_horseshoe(a=1., dim=(16, 64, 64), center=None, radius_core=None, - radius_shell=None, height=None): - """Create magnetic horseshoe vortex magnetisation distribution.""" - _log.debug('Calling vortex_horseshoe') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if radius_core is None: - radius_core = dim[1] // 8 - if radius_shell is None: - radius_shell = dim[1] // 4 - if height is None: - height = np.max((dim[0] // 2, 1)) - mag_shape_core = shapes.disc(dim, center, radius_core, height) - mag_shape_outer = shapes.disc(dim, center, radius_shell, height) - mag_shape_horseshoe = np.logical_xor(mag_shape_outer, mag_shape_core) - mag_shape_horseshoe[:, dim[1] // 2:, :] = False - return VectorData(a, mc.create_mag_dist_vortex(mag_shape_horseshoe)) - - -def smooth_vortex_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, axis='z', - vortex_radius=None): - """Create smooth vortex disc magnetisation distribution.""" - _log.debug('Calling vortex_disc') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if radius is None: - radius = dim[2] // 4 - if height is None: - height = np.max((dim[0] // 2, 1)) - if vortex_radius is None: - vortex_radius = radius // 2 - mag_shape = shapes.disc(dim, center, radius, height, axis) - magnitude = mc.create_mag_dist_smooth_vortex(mag_shape, center, vortex_radius, axis) - return VectorData(a, magnitude) - - -def source_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, axis='z'): - """Create source disc magnetisation distribution.""" - _log.debug('Calling vortex_disc') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if radius is None: - radius = dim[2] // 4 - if height is None: - height = np.max((dim[0] // 2, 1)) - mag_shape = shapes.disc(dim, center, radius, height, axis) - magnitude = mc.create_mag_dist_source(mag_shape, center, axis) - return VectorData(a, magnitude) - - -def core_shell_disc(a=1., dim=(32, 32, 32), center=None, radius_core=None, - radius_shell=None, height=None, rate_core_to_shell=0.75): - """Create magnetic core shell disc magnetisation distribution.""" - _log.debug('Calling core_shell_disc') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if radius_core is None: - radius_core = dim[1] // 8 - if radius_shell is None: - radius_shell = dim[1] // 4 - if height is None: - height = np.max((dim[0] // 2, 1)) - mag_shape_core = shapes.disc(dim, center, radius_core, height) - mag_shape_outer = shapes.disc(dim, center, radius_shell, height) - mag_shape_shell = np.logical_xor(mag_shape_outer, mag_shape_core) - magdata = VectorData(a, mc.create_mag_dist_vortex(mag_shape_shell)) * rate_core_to_shell - magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape_core, phi=0, theta=0)) - return magdata - - -def core_shell_sphere(a=1., dim=(32, 32, 32), center=None, radius_core=None, - radius_shell=None, rate_core_to_shell=0.75): - """Create magnetic core shell sphere magnetisation distribution.""" - _log.debug('Calling core_shell_sphere') - if center is None: - center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) - if radius_core is None: - radius_core = dim[1] // 8 - if radius_shell is None: - radius_shell = dim[1] // 4 - mag_shape_sphere = shapes.sphere(dim, center, radius_shell) - mag_shape_disc = shapes.disc(dim, center, radius_core, height=dim[0]) - mag_shape_core = np.logical_and(mag_shape_sphere, mag_shape_disc) - mag_shape_shell = np.logical_and(mag_shape_sphere, np.logical_not(mag_shape_core)) - magdata = VectorData(a, mc.create_mag_dist_vortex(mag_shape_shell)) * rate_core_to_shell - magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape_core, phi=0, theta=0)) - return magdata +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Provide simple examples for magnetic distributions.""" + +import logging + +import numpy as np + +import random as rnd + +from . import magcreator as mc +from . import shapes +from ..fielddata import VectorData + + +__all__ = ['pyramid_logo', 'singularity', 'homog_pixel', 'homog_slab', 'homog_disc', + 'homog_sphere', 'homog_filament', 'homog_alternating_filament', + 'homog_array_sphere_disc_slab', 'homog_random_pixels', 'homog_random_slabs', + 'vortex_slab', 'vortex_disc', 'vortex_alternating_discs', 'vortex_sphere', + 'vortex_horseshoe', 'smooth_vortex_disc', 'source_disc', + 'core_shell_disc', 'core_shell_sphere'] +_log = logging.getLogger(__name__) + + +def pyramid_logo(a=1., dim=(1, 256, 256), phi=np.pi / 2, theta=np.pi / 2): + """Create pyramid logo.""" + _log.debug('Calling pyramid_logo') + mag_shape = np.zeros(dim) + x = range(dim[2]) + y = range(dim[1]) + xx, yy = np.meshgrid(x, y) + bottom = (yy >= 0.25 * dim[1]) + left = (yy <= 0.75 / 0.5 * dim[1] / dim[2] * xx) + right = np.fliplr(left) + mag_shape[0, ...] = np.logical_and(np.logical_and(left, right), bottom) + return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) + + +def singularity(a=1., dim=(8, 8, 8), center=None): + """Create magnetic singularity.""" + _log.debug('Calling singularity') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + x = np.linspace(-center[2], dim[2] - 1 - center[2], dim[2]) + 0.5 # pixel center! + y = np.linspace(-center[1], dim[1] - 1 - center[1], dim[1]) + 0.5 # pixel center! + z = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! + yy, zz, xx = np.meshgrid(x, y, z) # What's up with this strange order??? + magnitude = np.array((xx, yy, zz)).astype(float) + magnitude /= np.sqrt((magnitude ** 2 + 1E-30).sum(axis=0)) # Normalise! + return VectorData(a, magnitude) + + +def homog_pixel(a=1., dim=(1, 9, 9), pixel=None, phi=np.pi/4, theta=np.pi/2): + """Create single magnetised slab.""" + _log.debug('Calling homog_pixel') + if pixel is None: + pixel = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + mag_shape = shapes.pixel(dim, pixel) + return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) + + +def homog_slab(a=1., dim=(32, 32, 32), center=None, width=None, phi=np.pi/4, theta=np.pi/4): + """Create homogeneous slab magnetisation distribution.""" + _log.debug('Calling homog_slab') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if width is None: + width = (np.max((dim[0] // 8, 1)), np.max((dim[1] // 2, 1)), np.max((dim[2] // 4, 1))) + mag_shape = shapes.slab(dim, center, width) + return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) + + +def homog_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, + phi=np.pi / 4, theta=np.pi / 4): + """Create homogeneous disc magnetisation distribution.""" + _log.debug('Calling homog_disc') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if radius is None: + radius = dim[2] // 4 + if height is None: + height = np.max((dim[0] // 2, 1)) + mag_shape = shapes.disc(dim, center, radius, height) + return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) + + +def homog_sphere(a=1., dim=(32, 32, 32), center=None, radius=None, phi=np.pi/4, theta=np.pi/4): + """Create homogeneous sphere magnetisation distribution.""" + _log.debug('Calling homog_sphere') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if radius is None: + radius = dim[2] // 4 + mag_shape = shapes.sphere(dim, center, radius) + return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) + + +def homog_filament(a=1., dim=(1, 21, 21), pos=None, phi=np.pi / 2, theta=np.pi/2): + """Create magnetisation distribution of a single magnetised filaments.""" + _log.debug('Calling homog_filament') + if pos is None: + pos = (0, dim[1] // 2) + mag_shape = shapes.filament(dim, pos) + return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) + + +def homog_alternating_filament(a=1., dim=(1, 21, 21), spacing=5, phi=np.pi/2, theta=np.pi/2): + """Create magnetisation distribution of alternating filaments.""" + _log.debug('Calling homog_alternating_filament') + count = int((dim[1] - 1) / spacing) + 1 + magdata = VectorData(a, np.zeros((3,) + dim)) + for i in range(count): + pos = i * spacing + mag_shape = shapes.filament(dim, (0, pos)) + magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) + phi *= -1 # Switch the angle + return magdata + + +def homog_array_sphere_disc_slab(a=1., dim=(64, 128, 128), center_sp=(32, 96, 64), radius_sp=24, + center_di=(32, 32, 96), radius_di=24, height_di=24, + center_sl=(32, 32, 32), width_sl=(48, 48, 48)): + """Create array of several magnetisation distribution (sphere, disc and slab).""" + _log.debug('Calling homog_array_sphere_disc_slab') + mag_shape_sphere = shapes.sphere(dim, center_sp, radius_sp) + mag_shape_disc = shapes.disc(dim, center_di, radius_di, height_di) + mag_shape_slab = shapes.slab(dim, center_sl, width_sl) + magdata = VectorData(a, mc.create_mag_dist_homog(mag_shape_sphere, np.pi)) + magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape_disc, np.pi / 2)) + magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape_slab, np.pi / 4)) + return magdata + + +def homog_random_pixels(a=1., dim=(1, 64, 64), count=10, rnd_seed=24): + """Create random magnetised pixels.""" + _log.debug('Calling homog_random_pixels') + rnd.seed(rnd_seed) + magdata = VectorData(a, np.zeros((3,) + dim)) + for i in range(count): + pixel = (rnd.randrange(dim[0]), rnd.randrange(dim[1]), rnd.randrange(dim[2])) + mag_shape = shapes.pixel(dim, pixel) + phi = 2 * np.pi * rnd.random() + magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape, phi)) + return magdata + + +def homog_random_slabs(a=1., dim=(1, 64, 64), count=10, width_max=5, rnd_seed=2): + """Create random magnetised slabs.""" + _log.debug('Create homog_random_slabs') + rnd.seed(rnd_seed) + magdata = VectorData(a, np.zeros((3,) + dim)) + for i in range(count): + width = (1, rnd.randint(1, width_max), rnd.randint(1, width_max)) + center = (rnd.randrange(int(width[0] / 2), dim[0] - int(width[0] / 2)), + rnd.randrange(int(width[1] / 2), dim[1] - int(width[1] / 2)), + rnd.randrange(int(width[2] / 2), dim[2] - int(width[2] / 2))) + mag_shape = shapes.slab(dim, center, width) + phi = 2 * np.pi * rnd.random() + magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape, phi)) + return magdata + + +def vortex_slab(a=1., dim=(32, 32, 32), center=None, width=None, axis='z'): + """Create vortex slab magnetisation distribution.""" + _log.debug('Calling vortex_slab') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if width is None: + width = (np.max((dim[0] // 2, 1)), np.max((dim[1] // 2, 1)), np.max((dim[2] // 2, 1))) + mag_shape = shapes.slab(dim, center, width) + magnitude = mc.create_mag_dist_vortex(mag_shape, center, axis) + return VectorData(a, magnitude) + + +def vortex_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, axis='z'): + """Create vortex disc magnetisation distribution.""" + _log.debug('Calling vortex_disc') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if radius is None: + radius = dim[2] // 4 + if height is None: + height = np.max((dim[0] // 2, 1)) + mag_shape = shapes.disc(dim, center, radius, height, axis) + magnitude = mc.create_mag_dist_vortex(mag_shape, center, axis) + return VectorData(a, magnitude) + + +def vortex_alternating_discs(a=1., dim=(80, 32, 32), count=8): + """Create pillar of alternating vortex disc magnetisation distributions.""" + _log.debug('Calling vortex_alternating_discs') + segment_height = dim[0] // (count + 2) + magdata = VectorData(a, np.zeros((3,) + dim)) + for i in range(count): + axis = 'z' if i % 2 == 0 else '-z' + center = (segment_height * (i + 1 + 0.5), dim[1] // 2, dim[2] // 2) + radius = dim[2] // 4 + height = segment_height + mag_shape = shapes.disc(dim, center=center, radius=radius, height=height) + mag_amp = mc.create_mag_dist_vortex(mag_shape=mag_shape, center=center, axis=axis) + magdata += VectorData(1, mag_amp) + return magdata + + +def vortex_sphere(a=1., dim=(32, 32, 32), center=None, radius=None, axis='z'): + """Create vortex sphere magnetisation distribution.""" + _log.debug('Calling vortex_sphere') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if radius is None: + radius = dim[2] // 4 + mag_shape = shapes.sphere(dim, center, radius) + magnitude = mc.create_mag_dist_vortex(mag_shape, center, axis) + return VectorData(a, magnitude) + + +def vortex_horseshoe(a=1., dim=(16, 64, 64), center=None, radius_core=None, + radius_shell=None, height=None): + """Create magnetic horseshoe vortex magnetisation distribution.""" + _log.debug('Calling vortex_horseshoe') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if radius_core is None: + radius_core = dim[1] // 8 + if radius_shell is None: + radius_shell = dim[1] // 4 + if height is None: + height = np.max((dim[0] // 2, 1)) + mag_shape_core = shapes.disc(dim, center, radius_core, height) + mag_shape_outer = shapes.disc(dim, center, radius_shell, height) + mag_shape_horseshoe = np.logical_xor(mag_shape_outer, mag_shape_core) + mag_shape_horseshoe[:, dim[1] // 2:, :] = False + return VectorData(a, mc.create_mag_dist_vortex(mag_shape_horseshoe)) + + +def smooth_vortex_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, axis='z', + vortex_radius=None): + """Create smooth vortex disc magnetisation distribution.""" + _log.debug('Calling vortex_disc') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if radius is None: + radius = dim[2] // 4 + if height is None: + height = np.max((dim[0] // 2, 1)) + if vortex_radius is None: + vortex_radius = radius // 2 + mag_shape = shapes.disc(dim, center, radius, height, axis) + magnitude = mc.create_mag_dist_smooth_vortex(mag_shape, center, vortex_radius, axis) + return VectorData(a, magnitude) + + +def source_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, axis='z'): + """Create source disc magnetisation distribution.""" + _log.debug('Calling vortex_disc') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if radius is None: + radius = dim[2] // 4 + if height is None: + height = np.max((dim[0] // 2, 1)) + mag_shape = shapes.disc(dim, center, radius, height, axis) + magnitude = mc.create_mag_dist_source(mag_shape, center, axis) + return VectorData(a, magnitude) + + +def core_shell_disc(a=1., dim=(32, 32, 32), center=None, radius_core=None, + radius_shell=None, height=None, rate_core_to_shell=0.75): + """Create magnetic core shell disc magnetisation distribution.""" + _log.debug('Calling core_shell_disc') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if radius_core is None: + radius_core = dim[1] // 8 + if radius_shell is None: + radius_shell = dim[1] // 4 + if height is None: + height = np.max((dim[0] // 2, 1)) + mag_shape_core = shapes.disc(dim, center, radius_core, height) + mag_shape_outer = shapes.disc(dim, center, radius_shell, height) + mag_shape_shell = np.logical_xor(mag_shape_outer, mag_shape_core) + magdata = VectorData(a, mc.create_mag_dist_vortex(mag_shape_shell)) * rate_core_to_shell + magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape_core, phi=0, theta=0)) + return magdata + + +def core_shell_sphere(a=1., dim=(32, 32, 32), center=None, radius_core=None, + radius_shell=None, rate_core_to_shell=0.75): + """Create magnetic core shell sphere magnetisation distribution.""" + _log.debug('Calling core_shell_sphere') + if center is None: + center = (dim[0] // 2, dim[1] // 2, dim[2] // 2) + if radius_core is None: + radius_core = dim[1] // 8 + if radius_shell is None: + radius_shell = dim[1] // 4 + mag_shape_sphere = shapes.sphere(dim, center, radius_shell) + mag_shape_disc = shapes.disc(dim, center, radius_core, height=dim[0]) + mag_shape_core = np.logical_and(mag_shape_sphere, mag_shape_disc) + mag_shape_shell = np.logical_and(mag_shape_sphere, np.logical_not(mag_shape_core)) + magdata = VectorData(a, mc.create_mag_dist_vortex(mag_shape_shell)) * rate_core_to_shell + magdata += VectorData(a, mc.create_mag_dist_homog(mag_shape_core, phi=0, theta=0)) + return magdata diff --git a/pyramid/magcreator/magcreator.py b/pyramid/magcreator/magcreator.py index ab868613b18128d520c589e9828f64425fa4895c..1168b34c23fdc02cdf64ea5dcb864f02d1594d5a 100644 --- a/pyramid/magcreator/magcreator.py +++ b/pyramid/magcreator/magcreator.py @@ -1,285 +1,285 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Create simple magnetic distributions. - -The :mod:`~.magcreator` module is responsible for the creation of simple distributions of -magnetic moments. In the :mod:`~.shapes` module, you can find several general shapes for the -3-dimensional volume that should be magnetized (e.g. slabs, spheres, discs or single pixels). -These shapes are then used as input for the creating functions (or you could specify the -volume yourself as a 3-dimensional boolean matrix or a matrix with values in the range from 0 to 1, -which modifies the magnetization amplitude). The specified volume can either be magnetized -homogeneously with the :func:`~.create_mag_dist_homog` function by specifying the magnetization -direction, or in a vortex state with the :func:`~.create_mag_dist_vortex` by specifying the vortex -center. - -""" - -import logging - -import numpy as np -from numpy import pi - -__all__ = ['create_mag_dist_homog', 'create_mag_dist_vortex', 'create_mag_dist_source', - 'create_mag_dist_smooth_vortex'] -_log = logging.getLogger(__name__) - - -def create_mag_dist_homog(mag_shape, phi, theta=pi / 2): - """Create a 3-dimensional magnetic distribution of a homogeneously magnetized object. - - Parameters - ---------- - mag_shape : :class:`~numpy.ndarray` (N=3) - The magnetic shapes (see :mod:`.~shapes` for examples). - phi : float - The azimuthal angle, describing the direction of the magnetized object. - theta : float, optional - The polar angle, describing the direction of the magnetized object. - The default is pi/2, which means, the z-component is zero. - - Returns - ------- - amplitude : tuple (N=3) of :class:`~numpy.ndarray` (N=3) - The magnetic distribution as a tuple of the 3 components in - `x`-, `y`- and `z`-direction on the 3-dimensional grid. - - """ - _log.debug('Calling create_mag_dist_homog') - dim = np.shape(mag_shape) - assert len(dim) == 3, 'Magnetic shapes must describe 3-dimensional distributions!' - z_mag = np.ones(dim) * np.cos(theta) * mag_shape - y_mag = np.ones(dim) * np.sin(theta) * np.sin(phi) * mag_shape - x_mag = np.ones(dim) * np.sin(theta) * np.cos(phi) * mag_shape - return np.array([x_mag, y_mag, z_mag]) - - -def create_mag_dist_vortex(mag_shape, center=None, axis='z'): - """Create a 3-dimensional magnetic distribution of a homogeneous magnetized object. - - Parameters - ---------- - mag_shape : :class:`~numpy.ndarray` (N=3) - The magnetic shapes (see :mod:`.~shapes`` for examples). - center : tuple (N=2 or N=3), optional - The vortex center, given in 2D `(v, u)` or 3D `(z, y, x)`, where the perpendicular axis - is is discarded. Is set to the center of the field of view if not specified. The vortex - center has to be between two pixels. - axis : {'z', '-z', 'y', '-y', 'x', '-x'}, optional - The orientation of the vortex axis. The default is 'z'. Negative values invert the vortex - orientation. - - Returns - ------- - amplitude : tuple (N=3) of :class:`~numpy.ndarray` (N=3) - The magnetic distribution as a tuple of the 3 components in - `x`-, `y`- and `z`-direction on the 3-dimensional grid. - - Notes - ----- - To avoid singularities, the vortex center should lie between the pixel centers (which - reside at coordinates with _.5 at the end), i.e. integer values should be used as center - coordinates (e.g. coordinate 1 lies between the first and the second pixel). - - """ - _log.debug('Calling create_mag_dist_vortex') - dim = mag_shape.shape - assert len(dim) == 3, 'Magnetic shapes must describe 3-dimensional distributions!' - assert center is None or len(center) in {2, 3}, \ - 'Vortex center has to be defined in 3D or 2D or not at all!' - if center is None: - center = (dim[1] / 2, dim[2] / 2) - sign = -1 if '-' in axis else 1 - if axis in ('z', '-z'): - if len(center) == 3: # if a 3D-center is given, just take the x and y components - center = (center[1], center[2]) - u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! - v = np.linspace(-center[0], dim[1] - 1 - center[0], dim[1]) + 0.5 # pixel center! - uu, vv = np.meshgrid(u, v) - phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=0) - phi = np.tile(phi, (dim[0], 1, 1)) - z_mag = np.zeros(dim) - y_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign - x_mag = np.ones(dim) * -np.cos(phi) * mag_shape * sign - elif axis in ('y', '-y'): - if len(center) == 3: # if a 3D-center is given, just take the x and z components - center = (center[0], center[2]) - u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! - v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! - uu, vv = np.meshgrid(u, v) - phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=1) - phi = np.tile(phi, (1, dim[1], 1)) - z_mag = np.ones(dim) * np.sin(phi) * mag_shape * sign - y_mag = np.zeros(dim) - x_mag = np.ones(dim) * np.cos(phi) * mag_shape * sign - elif axis in ('x', '-x'): - if len(center) == 3: # if a 3D-center is given, just take the z and y components - center = (center[0], center[1]) - u = np.linspace(-center[1], dim[1] - 1 - center[1], dim[1]) + 0.5 # pixel center! - v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! - uu, vv = np.meshgrid(u, v) - phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=2) - phi = np.tile(phi, (1, 1, dim[2])) - z_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign - y_mag = np.ones(dim) * -np.cos(phi) * mag_shape * sign - x_mag = np.zeros(dim) - else: - raise ValueError('{} is not a valid argument (use x, -x, y, -y, z or -z)'.format(axis)) - return np.array([x_mag, y_mag, z_mag]) - - -def create_mag_dist_smooth_vortex(mag_shape, center=None, vort_r=None, axis='z'): - """Create a 3-dimensional magnetic distribution of a homogeneous magnetized object. - - Parameters - ---------- - mag_shape : :class:`~numpy.ndarray` (N=3) - The magnetic shapes (see :mod:`.~shapes`` for examples). - center : tuple (N=2 or N=3), optional - The vortex center, given in 2D `(v, u)` or 3D `(z, y, x)`, where the perpendicular axis - is is discarded. Is set to the center of the field of view if not specified. The vortex - center has to be between two pixels. - axis : {'z', '-z', 'y', '-y', 'x', '-x'}, optional - The orientation of the vortex axis. The default is 'z'. Negative values invert the vortex - orientation. - - Returns - ------- - amplitude : tuple (N=3) of :class:`~numpy.ndarray` (N=3) - The magnetic distribution as a tuple of the 3 components in - `x`-, `y`- and `z`-direction on the 3-dimensional grid. - - Notes - ----- - To avoid singularities, the vortex center should lie between the pixel centers (which - reside at coordinates with _.5 at the end), i.e. integer values should be used as center - coordinates (e.g. coordinate 1 lies between the first and the second pixel). - - """ - - def core(r): - """Function describing the smooth vortex core.""" - return 1 - 2/np.pi * np.arcsin(np.tanh(np.pi*r/vort_r)) - - _log.debug('Calling create_mag_dist_vortex') - dim = mag_shape.shape - assert len(dim) == 3, 'Magnetic shapes must describe 3-dimensional distributions!' - assert center is None or len(center) in {2, 3}, \ - 'Vortex center has to be defined in 3D or 2D or not at all!' - if center is None: - center = (dim[1] / 2, dim[2] / 2) - sign = -1 if '-' in axis else 1 - if axis in ('z', '-z'): - if len(center) == 3: # if a 3D-center is given, just take the x and y components - center = (center[1], center[2]) - u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! - v = np.linspace(-center[0], dim[1] - 1 - center[0], dim[1]) + 0.5 # pixel center! - uu, vv = np.meshgrid(u, v) - rr = np.hypot(uu, vv)[None, :, :] - phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=0) - phi = np.tile(phi, (dim[0], 1, 1)) - z_mag = np.ones(dim) * mag_shape * sign * core(rr) - y_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) - x_mag = np.ones(dim) * -np.cos(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) - elif axis in ('y', '-y'): - if len(center) == 3: # if a 3D-center is given, just take the x and z components - center = (center[0], center[2]) - u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! - v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! - uu, vv = np.meshgrid(u, v) - rr = np.hypot(uu, vv)[:, None, :] - phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=1) - phi = np.tile(phi, (1, dim[1], 1)) - z_mag = np.ones(dim) * np.sin(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) - y_mag = np.ones(dim) * mag_shape * sign * core(rr) - x_mag = np.ones(dim) * np.cos(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) - elif axis in ('x', '-x'): - if len(center) == 3: # if a 3D-center is given, just take the z and y components - center = (center[0], center[1]) - u = np.linspace(-center[1], dim[1] - 1 - center[1], dim[1]) + 0.5 # pixel center! - v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! - uu, vv = np.meshgrid(u, v) - rr = np.hypot(uu, vv)[:, :, None] - phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=2) - phi = np.tile(phi, (1, 1, dim[2])) - z_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) - y_mag = np.ones(dim) * -np.cos(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) - x_mag = np.ones(dim) * mag_shape * sign * core(rr) - else: - raise ValueError('{} is not a valid argument (use x, -x, y, -y, z or -z)'.format(axis)) - return np.array([x_mag, y_mag, z_mag]) - - -def create_mag_dist_source(mag_shape, center=None, axis='z'): - """Create a 3-dimensional magnetic distribution of a homogeneous magnetized object. - - Parameters - ---------- - mag_shape : :class:`~numpy.ndarray` (N=3) - The magnetic shapes (see :mod:`.~shapes`` for examples). - center : tuple (N=2 or N=3), optional - The source center, given in 2D `(v, u)` or 3D `(z, y, x)`, where the perpendicular axis - is is discarded. Is set to the center of the field of view if not specified. - The source center has to be between two pixels. - axis : {'z', '-z', 'y', '-y', 'x', '-x'}, optional - The orientation of the source axis. The default is 'z'. Negative values invert the source - to a sink. - - Returns - ------- - amplitude : tuple (N=3) of :class:`~numpy.ndarray` (N=3) - The magnetic distribution as a tuple of the 3 components in - `x`-, `y`- and `z`-direction on the 3-dimensional grid. - - Notes - ----- - To avoid singularities, the source center should lie between the pixel centers (which - reside at coordinates with _.5 at the end), i.e. integer values should be used as center - coordinates (e.g. coordinate 1 lies between the first and the second pixel). - - """ - _log.debug('Calling create_mag_dist_vortex') - dim = mag_shape.shape - assert len(dim) == 3, 'Magnetic shapes must describe 3-dimensional distributions!' - assert center is None or len(center) in {2, 3}, \ - 'Vortex center has to be defined in 3D or 2D or not at all!' - if center is None: - center = (dim[1] / 2, dim[2] / 2) - sign = -1 if '-' in axis else 1 - if axis in ('z', '-z'): - if len(center) == 3: # if a 3D-center is given, just take the x and y components - center = (center[1], center[2]) - u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! - v = np.linspace(-center[0], dim[1] - 1 - center[0], dim[1]) + 0.5 # pixel center! - uu, vv = np.meshgrid(u, v) - phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=0) - phi = np.tile(phi, (dim[0], 1, 1)) - z_mag = np.zeros(dim) - y_mag = np.ones(dim) * np.cos(phi) * mag_shape * sign - x_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign - elif axis in ('y', '-y'): - if len(center) == 3: # if a 3D-center is given, just take the x and z components - center = (center[0], center[2]) - u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! - v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! - uu, vv = np.meshgrid(u, v) - phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=1) - phi = np.tile(phi, (1, dim[1], 1)) - z_mag = np.ones(dim) * np.cos(phi) * mag_shape * sign - y_mag = np.zeros(dim) - x_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign - elif axis in ('x', '-x'): - if len(center) == 3: # if a 3D-center is given, just take the z and y components - center = (center[0], center[1]) - u = np.linspace(-center[1], dim[1] - 1 - center[1], dim[1]) + 0.5 # pixel center! - v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! - uu, vv = np.meshgrid(u, v) - phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=2) - phi = np.tile(phi, (1, 1, dim[2])) - z_mag = np.ones(dim) * np.cos(phi) * mag_shape * sign - y_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign - x_mag = np.zeros(dim) - else: - raise ValueError('{} is not a valid argument (use x, -x, y, -y, z or -z)'.format(axis)) - return np.array([x_mag, y_mag, z_mag]) +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Create simple magnetic distributions. + +The :mod:`~.magcreator` module is responsible for the creation of simple distributions of +magnetic moments. In the :mod:`~.shapes` module, you can find several general shapes for the +3-dimensional volume that should be magnetized (e.g. slabs, spheres, discs or single pixels). +These shapes are then used as input for the creating functions (or you could specify the +volume yourself as a 3-dimensional boolean matrix or a matrix with values in the range from 0 to 1, +which modifies the magnetization amplitude). The specified volume can either be magnetized +homogeneously with the :func:`~.create_mag_dist_homog` function by specifying the magnetization +direction, or in a vortex state with the :func:`~.create_mag_dist_vortex` by specifying the vortex +center. + +""" + +import logging + +import numpy as np +from numpy import pi + +__all__ = ['create_mag_dist_homog', 'create_mag_dist_vortex', 'create_mag_dist_source', + 'create_mag_dist_smooth_vortex'] +_log = logging.getLogger(__name__) + + +def create_mag_dist_homog(mag_shape, phi, theta=pi / 2): + """Create a 3-dimensional magnetic distribution of a homogeneously magnetized object. + + Parameters + ---------- + mag_shape : :class:`~numpy.ndarray` (N=3) + The magnetic shapes (see :mod:`.~shapes` for examples). + phi : float + The azimuthal angle, describing the direction of the magnetized object. + theta : float, optional + The polar angle, describing the direction of the magnetized object. + The default is pi/2, which means, the z-component is zero. + + Returns + ------- + amplitude : tuple (N=3) of :class:`~numpy.ndarray` (N=3) + The magnetic distribution as a tuple of the 3 components in + `x`-, `y`- and `z`-direction on the 3-dimensional grid. + + """ + _log.debug('Calling create_mag_dist_homog') + dim = np.shape(mag_shape) + assert len(dim) == 3, 'Magnetic shapes must describe 3-dimensional distributions!' + z_mag = np.ones(dim) * np.cos(theta) * mag_shape + y_mag = np.ones(dim) * np.sin(theta) * np.sin(phi) * mag_shape + x_mag = np.ones(dim) * np.sin(theta) * np.cos(phi) * mag_shape + return np.array([x_mag, y_mag, z_mag]) + + +def create_mag_dist_vortex(mag_shape, center=None, axis='z'): + """Create a 3-dimensional magnetic distribution of a homogeneous magnetized object. + + Parameters + ---------- + mag_shape : :class:`~numpy.ndarray` (N=3) + The magnetic shapes (see :mod:`.~shapes`` for examples). + center : tuple (N=2 or N=3), optional + The vortex center, given in 2D `(v, u)` or 3D `(z, y, x)`, where the perpendicular axis + is is discarded. Is set to the center of the field of view if not specified. The vortex + center has to be between two pixels. + axis : {'z', '-z', 'y', '-y', 'x', '-x'}, optional + The orientation of the vortex axis. The default is 'z'. Negative values invert the vortex + orientation. + + Returns + ------- + amplitude : tuple (N=3) of :class:`~numpy.ndarray` (N=3) + The magnetic distribution as a tuple of the 3 components in + `x`-, `y`- and `z`-direction on the 3-dimensional grid. + + Notes + ----- + To avoid singularities, the vortex center should lie between the pixel centers (which + reside at coordinates with _.5 at the end), i.e. integer values should be used as center + coordinates (e.g. coordinate 1 lies between the first and the second pixel). + + """ + _log.debug('Calling create_mag_dist_vortex') + dim = mag_shape.shape + assert len(dim) == 3, 'Magnetic shapes must describe 3-dimensional distributions!' + assert center is None or len(center) in {2, 3}, \ + 'Vortex center has to be defined in 3D or 2D or not at all!' + if center is None: + center = (dim[1] / 2, dim[2] / 2) + sign = -1 if '-' in axis else 1 + if axis in ('z', '-z'): + if len(center) == 3: # if a 3D-center is given, just take the x and y components + center = (center[1], center[2]) + u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! + v = np.linspace(-center[0], dim[1] - 1 - center[0], dim[1]) + 0.5 # pixel center! + uu, vv = np.meshgrid(u, v) + phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=0) + phi = np.tile(phi, (dim[0], 1, 1)) + z_mag = np.zeros(dim) + y_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign + x_mag = np.ones(dim) * -np.cos(phi) * mag_shape * sign + elif axis in ('y', '-y'): + if len(center) == 3: # if a 3D-center is given, just take the x and z components + center = (center[0], center[2]) + u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! + v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! + uu, vv = np.meshgrid(u, v) + phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=1) + phi = np.tile(phi, (1, dim[1], 1)) + z_mag = np.ones(dim) * np.sin(phi) * mag_shape * sign + y_mag = np.zeros(dim) + x_mag = np.ones(dim) * np.cos(phi) * mag_shape * sign + elif axis in ('x', '-x'): + if len(center) == 3: # if a 3D-center is given, just take the z and y components + center = (center[0], center[1]) + u = np.linspace(-center[1], dim[1] - 1 - center[1], dim[1]) + 0.5 # pixel center! + v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! + uu, vv = np.meshgrid(u, v) + phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=2) + phi = np.tile(phi, (1, 1, dim[2])) + z_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign + y_mag = np.ones(dim) * -np.cos(phi) * mag_shape * sign + x_mag = np.zeros(dim) + else: + raise ValueError('{} is not a valid argument (use x, -x, y, -y, z or -z)'.format(axis)) + return np.array([x_mag, y_mag, z_mag]) + + +def create_mag_dist_smooth_vortex(mag_shape, center=None, vort_r=None, axis='z'): + """Create a 3-dimensional magnetic distribution of a homogeneous magnetized object. + + Parameters + ---------- + mag_shape : :class:`~numpy.ndarray` (N=3) + The magnetic shapes (see :mod:`.~shapes`` for examples). + center : tuple (N=2 or N=3), optional + The vortex center, given in 2D `(v, u)` or 3D `(z, y, x)`, where the perpendicular axis + is is discarded. Is set to the center of the field of view if not specified. The vortex + center has to be between two pixels. + axis : {'z', '-z', 'y', '-y', 'x', '-x'}, optional + The orientation of the vortex axis. The default is 'z'. Negative values invert the vortex + orientation. + + Returns + ------- + amplitude : tuple (N=3) of :class:`~numpy.ndarray` (N=3) + The magnetic distribution as a tuple of the 3 components in + `x`-, `y`- and `z`-direction on the 3-dimensional grid. + + Notes + ----- + To avoid singularities, the vortex center should lie between the pixel centers (which + reside at coordinates with _.5 at the end), i.e. integer values should be used as center + coordinates (e.g. coordinate 1 lies between the first and the second pixel). + + """ + + def core(r): + """Function describing the smooth vortex core.""" + return 1 - 2/np.pi * np.arcsin(np.tanh(np.pi*r/vort_r)) + + _log.debug('Calling create_mag_dist_vortex') + dim = mag_shape.shape + assert len(dim) == 3, 'Magnetic shapes must describe 3-dimensional distributions!' + assert center is None or len(center) in {2, 3}, \ + 'Vortex center has to be defined in 3D or 2D or not at all!' + if center is None: + center = (dim[1] / 2, dim[2] / 2) + sign = -1 if '-' in axis else 1 + if axis in ('z', '-z'): + if len(center) == 3: # if a 3D-center is given, just take the x and y components + center = (center[1], center[2]) + u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! + v = np.linspace(-center[0], dim[1] - 1 - center[0], dim[1]) + 0.5 # pixel center! + uu, vv = np.meshgrid(u, v) + rr = np.hypot(uu, vv)[None, :, :] + phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=0) + phi = np.tile(phi, (dim[0], 1, 1)) + z_mag = np.ones(dim) * mag_shape * sign * core(rr) + y_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) + x_mag = np.ones(dim) * -np.cos(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) + elif axis in ('y', '-y'): + if len(center) == 3: # if a 3D-center is given, just take the x and z components + center = (center[0], center[2]) + u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! + v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! + uu, vv = np.meshgrid(u, v) + rr = np.hypot(uu, vv)[:, None, :] + phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=1) + phi = np.tile(phi, (1, dim[1], 1)) + z_mag = np.ones(dim) * np.sin(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) + y_mag = np.ones(dim) * mag_shape * sign * core(rr) + x_mag = np.ones(dim) * np.cos(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) + elif axis in ('x', '-x'): + if len(center) == 3: # if a 3D-center is given, just take the z and y components + center = (center[0], center[1]) + u = np.linspace(-center[1], dim[1] - 1 - center[1], dim[1]) + 0.5 # pixel center! + v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! + uu, vv = np.meshgrid(u, v) + rr = np.hypot(uu, vv)[:, :, None] + phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=2) + phi = np.tile(phi, (1, 1, dim[2])) + z_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) + y_mag = np.ones(dim) * -np.cos(phi) * mag_shape * sign * np.sqrt(1 - core(rr)) + x_mag = np.ones(dim) * mag_shape * sign * core(rr) + else: + raise ValueError('{} is not a valid argument (use x, -x, y, -y, z or -z)'.format(axis)) + return np.array([x_mag, y_mag, z_mag]) + + +def create_mag_dist_source(mag_shape, center=None, axis='z'): + """Create a 3-dimensional magnetic distribution of a homogeneous magnetized object. + + Parameters + ---------- + mag_shape : :class:`~numpy.ndarray` (N=3) + The magnetic shapes (see :mod:`.~shapes`` for examples). + center : tuple (N=2 or N=3), optional + The source center, given in 2D `(v, u)` or 3D `(z, y, x)`, where the perpendicular axis + is is discarded. Is set to the center of the field of view if not specified. + The source center has to be between two pixels. + axis : {'z', '-z', 'y', '-y', 'x', '-x'}, optional + The orientation of the source axis. The default is 'z'. Negative values invert the source + to a sink. + + Returns + ------- + amplitude : tuple (N=3) of :class:`~numpy.ndarray` (N=3) + The magnetic distribution as a tuple of the 3 components in + `x`-, `y`- and `z`-direction on the 3-dimensional grid. + + Notes + ----- + To avoid singularities, the source center should lie between the pixel centers (which + reside at coordinates with _.5 at the end), i.e. integer values should be used as center + coordinates (e.g. coordinate 1 lies between the first and the second pixel). + + """ + _log.debug('Calling create_mag_dist_vortex') + dim = mag_shape.shape + assert len(dim) == 3, 'Magnetic shapes must describe 3-dimensional distributions!' + assert center is None or len(center) in {2, 3}, \ + 'Vortex center has to be defined in 3D or 2D or not at all!' + if center is None: + center = (dim[1] / 2, dim[2] / 2) + sign = -1 if '-' in axis else 1 + if axis in ('z', '-z'): + if len(center) == 3: # if a 3D-center is given, just take the x and y components + center = (center[1], center[2]) + u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! + v = np.linspace(-center[0], dim[1] - 1 - center[0], dim[1]) + 0.5 # pixel center! + uu, vv = np.meshgrid(u, v) + phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=0) + phi = np.tile(phi, (dim[0], 1, 1)) + z_mag = np.zeros(dim) + y_mag = np.ones(dim) * np.cos(phi) * mag_shape * sign + x_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign + elif axis in ('y', '-y'): + if len(center) == 3: # if a 3D-center is given, just take the x and z components + center = (center[0], center[2]) + u = np.linspace(-center[1], dim[2] - 1 - center[1], dim[2]) + 0.5 # pixel center! + v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! + uu, vv = np.meshgrid(u, v) + phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=1) + phi = np.tile(phi, (1, dim[1], 1)) + z_mag = np.ones(dim) * np.cos(phi) * mag_shape * sign + y_mag = np.zeros(dim) + x_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign + elif axis in ('x', '-x'): + if len(center) == 3: # if a 3D-center is given, just take the z and y components + center = (center[0], center[1]) + u = np.linspace(-center[1], dim[1] - 1 - center[1], dim[1]) + 0.5 # pixel center! + v = np.linspace(-center[0], dim[0] - 1 - center[0], dim[0]) + 0.5 # pixel center! + uu, vv = np.meshgrid(u, v) + phi = np.expand_dims(np.arctan2(vv, uu) - pi / 2, axis=2) + phi = np.tile(phi, (1, 1, dim[2])) + z_mag = np.ones(dim) * np.cos(phi) * mag_shape * sign + y_mag = np.ones(dim) * -np.sin(phi) * mag_shape * sign + x_mag = np.zeros(dim) + else: + raise ValueError('{} is not a valid argument (use x, -x, y, -y, z or -z)'.format(axis)) + return np.array([x_mag, y_mag, z_mag]) diff --git a/pyramid/magcreator/shapes.py b/pyramid/magcreator/shapes.py index ca8e5b3bcbe682b701a546154faff595a71f05c3..120b3319b7b0ca416771510e1956f4cb7f791672 100644 --- a/pyramid/magcreator/shapes.py +++ b/pyramid/magcreator/shapes.py @@ -1,251 +1,251 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Provide simple shapes. - -This module is a collection of some methods that return a 3-dimensional -matrix that represents the field volume and consists of boolean values. -This matrix is used in the functions of the :mod:`~.magcreator` module to create -:class:`~pyramid.fielddata.VectorData` objects which store the field information. - -""" - -import logging - -import numpy as np - -__all__ = ['slab', 'disc', 'ellipse', 'ellipsoid', 'sphere', 'filament', 'pixel'] -_log = logging.getLogger(__name__) - - -def slab(dim, center, width): - """Create the shape of a slab. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - center : tuple (N=3) - The center of the slab in pixel coordinates `(z, y, x)`. - width : tuple (N=3) - The width of the slab in pixel coordinates `(z, y, x)`. - - Returns - ------- - shape : :class:`~numpy.ndarray` (N=3) - The shape as a 3D-array. - - """ - _log.debug('Calling slab') - assert np.shape(dim) == (3,), 'Parameter dim has to be a tuple of length 3!' - assert np.shape(center) == (3,), 'Parameter center has to be a tuple of length 3!' - assert np.shape(width) == (3,), 'Parameter width has to be a tuple of length 3!' - zz, yy, xx = np.indices(dim) + 0.5 - xx_shape = np.where(abs(xx - center[2]) <= width[2] / 2, True, False) - yy_shape = np.where(abs(yy - center[1]) <= width[1] / 2, True, False) - zz_shape = np.where(abs(zz - center[0]) <= width[0] / 2, True, False) - return np.logical_and(np.logical_and(xx_shape, yy_shape), zz_shape) - - -def disc(dim, center, radius, height, axis='z'): - """Create the shape of a cylindrical disc in x-, y-, or z-direction. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - center : tuple (N=3) - The center of the disc in pixel coordinates `(z, y, x)`. - radius : float - The radius of the disc in pixel coordinates. - height : float - The height of the disc in pixel coordinates. - axis : {'z', 'y', 'x'}, optional - The orientation of the disc axis. The default is 'z'. - - Returns - ------- - shape : :class:`~numpy.ndarray` (N=3) - The shape as a 3D-array. - - """ - _log.debug('Calling disc') - assert np.shape(dim) == (3,), 'Parameter dim has to be a a tuple of length 3!' - assert np.shape(center) == (3,), 'Parameter center has to be a a tuple of length 3!' - assert radius > 0 and np.shape(radius) == (), 'Radius has to be a positive scalar value!' - assert height > 0 and np.shape(height) == (), 'Height has to be a positive scalar value!' - assert axis in {'z', 'y', 'x'}, 'Axis has to be x, y or z (as a string)!' - zz, yy, xx = np.indices(dim) + 0.5 - xx -= center[2] - yy -= center[1] - zz -= center[0] - if axis == 'z': - uu, vv, ww = xx, yy, zz - elif axis == 'y': - uu, vv, ww = zz, xx, yy - elif axis == 'x': - uu, vv, ww = yy, zz, xx - else: - raise ValueError('{} is not a valid argument (use x, y or z)'.format(axis)) - return np.logical_and(np.where(np.hypot(uu, vv) <= radius, True, False), - np.where(abs(ww) <= height / 2, True, False)) - - -def ellipse(dim, center, width, height, axis='z'): - """Create the shape of an elliptical cylinder in x-, y-, or z-direction. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - center : tuple (N=3) - The center of the ellipse in pixel coordinates `(z, y, x)`. - width : tuple (N=2) - Length of the two axes of the ellipse. - height : float - The height of the ellipse in pixel coordinates. - axis : {'z', 'y', 'x'}, optional - The orientation of the ellipse axis. The default is 'z'. - - Returns - ------- - shape : :class:`~numpy.ndarray` (N=3) - The shape as a 3D-array. - - """ - _log.debug('Calling ellipse') - assert np.shape(dim) == (3,), 'Parameter dim has to be a a tuple of length 3!' - assert np.shape(center) == (3,), 'Parameter center has to be a a tuple of length 3!' - assert np.shape(width) == (2,), 'Parameter width has to be a a tuple of length 2!' - assert height > 0 and np.shape(height) == (), 'Height has to be a positive scalar value!' - assert axis in {'z', 'y', 'x'}, 'Axis has to be x, y or z (as a string)!' - zz, yy, xx = np.indices(dim) + 0.5 - xx -= center[2] - yy -= center[1] - zz -= center[0] - if axis == 'z': - uu, vv, ww = xx, yy, zz - elif axis == 'y': - uu, vv, ww = xx, zz, yy - elif axis == 'x': - uu, vv, ww = yy, zz, xx - else: - raise ValueError('{} is not a valid argument (use x, y or z)'.format(axis)) - distance = np.hypot(uu / (width[1] / 2), vv / (width[0] / 2)) - return np.logical_and(np.where(distance <= 1, True, False), - np.where(abs(ww) <= height / 2, True, False)) - - -def ellipsoid(dim, center, width): - """Create the shape of an ellipsoid. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - center : tuple (N=3) - The center of the ellipsoid in pixel coordinates `(z, y, x)`. - width : tuple (N=3) - The width of the ellipsoid `(z, y, x)`. - - Returns - ------- - shape : :class:`~numpy.ndarray` (N=3) - The shape as a 3D-array. - - """ - _log.debug('Calling ellipsoid') - assert np.shape(dim) == (3,), 'Parameter dim has to be a a tuple of length 3!' - assert np.shape(center) == (3,), 'Parameter center has to be a a tuple of length 3!' - assert np.shape(width) == (3,), 'Parameter width has to be a a tuple of length 3!' - zz, yy, xx = np.indices(dim) + 0.5 - distance = np.sqrt(((xx - center[2]) / (width[2] / 2)) ** 2 - + ((yy - center[1]) / (width[1] / 2)) ** 2 - + ((zz - center[0]) / (width[0] / 2)) ** 2) - return np.where(distance <= 1, True, False) - - -def sphere(dim, center, radius): - """Create the shape of a sphere. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - center : tuple (N=3) - The center of the sphere in pixel coordinates `(z, y, x)`. - radius : float - The radius of the sphere in pixel coordinates. - - Returns - ------- - shape : :class:`~numpy.ndarray` (N=3) - The shape as a 3D-array. - - """ - _log.debug('Calling sphere') - assert np.shape(dim) == (3,), 'Parameter dim has to be a a tuple of length 3!' - assert np.shape(center) == (3,), 'Parameter center has to be a a tuple of length 3!' - assert radius > 0 and np.shape(radius) == (), 'Radius has to be a positive scalar value!' - zz, yy, xx = np.indices(dim) + 0.5 - distance = np.sqrt((xx - center[2]) ** 2 + (yy - center[1]) ** 2 + (zz - center[0]) ** 2) - return np.where(distance <= radius, True, False) - - -def filament(dim, pos, axis='y'): - """Create the shape of a filament. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - pos : tuple (N=2) - The position of the filament in pixel coordinates `(coord1, coord2)`. - `coord1` and `coord2` stand for the two axes, which are perpendicular to `axis`. For - the default case (`axis = y`), it is `(coord1, coord2) = (z, x)`. - axis : {'y', 'x', 'z'}, optional - The orientation of the filament axis. The default is 'y'. - - Returns - ------- - shape : :class:`~numpy.ndarray` (N=3) - The shape as a 3D-array. - - """ - _log.debug('Calling filament') - assert np.shape(dim) == (3,), 'Parameter dim has to be a tuple of length 3!' - assert np.shape(pos) == (2,), 'Parameter pos has to be a tuple of length 2!' - assert axis in {'z', 'y', 'x'}, 'Axis has to be x, y or z (as a string)!' - shape = np.zeros(dim, dtype=bool) - if axis == 'z': - shape[:, pos[0], pos[1]] = True - elif axis == 'y': - shape[pos[0], :, pos[1]] = True - elif axis == 'x': - shape[pos[0], pos[1], :] = True - return shape - - -def pixel(dim, pixel): - """Create the shape of a single pixel. - - Parameters - ---------- - dim : tuple (N=3) - The dimensions of the grid `(z, y, x)`. - pixel : tuple (N=3) - The coordinates of the pixel `(z, y, x)`. - - Returns - ------- - shape : :class:`~numpy.ndarray` (N=3) - The shape as a 3D-array. - - """ - _log.debug('Calling pixel') - assert np.shape(dim) == (3,), 'Parameter dim has to be a tuple of length 3!' - assert np.shape(pixel) == (3,), 'Parameter pixel has to be a tuple of length 3!' - shape = np.zeros(dim, dtype=bool) - shape[pixel] = True - return shape +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Provide simple shapes. + +This module is a collection of some methods that return a 3-dimensional +matrix that represents the field volume and consists of boolean values. +This matrix is used in the functions of the :mod:`~.magcreator` module to create +:class:`~pyramid.fielddata.VectorData` objects which store the field information. + +""" + +import logging + +import numpy as np + +__all__ = ['slab', 'disc', 'ellipse', 'ellipsoid', 'sphere', 'filament', 'pixel'] +_log = logging.getLogger(__name__) + + +def slab(dim, center, width): + """Create the shape of a slab. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + center : tuple (N=3) + The center of the slab in pixel coordinates `(z, y, x)`. + width : tuple (N=3) + The width of the slab in pixel coordinates `(z, y, x)`. + + Returns + ------- + shape : :class:`~numpy.ndarray` (N=3) + The shape as a 3D-array. + + """ + _log.debug('Calling slab') + assert np.shape(dim) == (3,), 'Parameter dim has to be a tuple of length 3!' + assert np.shape(center) == (3,), 'Parameter center has to be a tuple of length 3!' + assert np.shape(width) == (3,), 'Parameter width has to be a tuple of length 3!' + zz, yy, xx = np.indices(dim) + 0.5 + xx_shape = np.where(abs(xx - center[2]) <= width[2] / 2, True, False) + yy_shape = np.where(abs(yy - center[1]) <= width[1] / 2, True, False) + zz_shape = np.where(abs(zz - center[0]) <= width[0] / 2, True, False) + return np.logical_and(np.logical_and(xx_shape, yy_shape), zz_shape) + + +def disc(dim, center, radius, height, axis='z'): + """Create the shape of a cylindrical disc in x-, y-, or z-direction. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + center : tuple (N=3) + The center of the disc in pixel coordinates `(z, y, x)`. + radius : float + The radius of the disc in pixel coordinates. + height : float + The height of the disc in pixel coordinates. + axis : {'z', 'y', 'x'}, optional + The orientation of the disc axis. The default is 'z'. + + Returns + ------- + shape : :class:`~numpy.ndarray` (N=3) + The shape as a 3D-array. + + """ + _log.debug('Calling disc') + assert np.shape(dim) == (3,), 'Parameter dim has to be a a tuple of length 3!' + assert np.shape(center) == (3,), 'Parameter center has to be a a tuple of length 3!' + assert radius > 0 and np.shape(radius) == (), 'Radius has to be a positive scalar value!' + assert height > 0 and np.shape(height) == (), 'Height has to be a positive scalar value!' + assert axis in {'z', 'y', 'x'}, 'Axis has to be x, y or z (as a string)!' + zz, yy, xx = np.indices(dim) + 0.5 + xx -= center[2] + yy -= center[1] + zz -= center[0] + if axis == 'z': + uu, vv, ww = xx, yy, zz + elif axis == 'y': + uu, vv, ww = zz, xx, yy + elif axis == 'x': + uu, vv, ww = yy, zz, xx + else: + raise ValueError('{} is not a valid argument (use x, y or z)'.format(axis)) + return np.logical_and(np.where(np.hypot(uu, vv) <= radius, True, False), + np.where(abs(ww) <= height / 2, True, False)) + + +def ellipse(dim, center, width, height, axis='z'): + """Create the shape of an elliptical cylinder in x-, y-, or z-direction. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + center : tuple (N=3) + The center of the ellipse in pixel coordinates `(z, y, x)`. + width : tuple (N=2) + Length of the two axes of the ellipse. + height : float + The height of the ellipse in pixel coordinates. + axis : {'z', 'y', 'x'}, optional + The orientation of the ellipse axis. The default is 'z'. + + Returns + ------- + shape : :class:`~numpy.ndarray` (N=3) + The shape as a 3D-array. + + """ + _log.debug('Calling ellipse') + assert np.shape(dim) == (3,), 'Parameter dim has to be a a tuple of length 3!' + assert np.shape(center) == (3,), 'Parameter center has to be a a tuple of length 3!' + assert np.shape(width) == (2,), 'Parameter width has to be a a tuple of length 2!' + assert height > 0 and np.shape(height) == (), 'Height has to be a positive scalar value!' + assert axis in {'z', 'y', 'x'}, 'Axis has to be x, y or z (as a string)!' + zz, yy, xx = np.indices(dim) + 0.5 + xx -= center[2] + yy -= center[1] + zz -= center[0] + if axis == 'z': + uu, vv, ww = xx, yy, zz + elif axis == 'y': + uu, vv, ww = xx, zz, yy + elif axis == 'x': + uu, vv, ww = yy, zz, xx + else: + raise ValueError('{} is not a valid argument (use x, y or z)'.format(axis)) + distance = np.hypot(uu / (width[1] / 2), vv / (width[0] / 2)) + return np.logical_and(np.where(distance <= 1, True, False), + np.where(abs(ww) <= height / 2, True, False)) + + +def ellipsoid(dim, center, width): + """Create the shape of an ellipsoid. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + center : tuple (N=3) + The center of the ellipsoid in pixel coordinates `(z, y, x)`. + width : tuple (N=3) + The width of the ellipsoid `(z, y, x)`. + + Returns + ------- + shape : :class:`~numpy.ndarray` (N=3) + The shape as a 3D-array. + + """ + _log.debug('Calling ellipsoid') + assert np.shape(dim) == (3,), 'Parameter dim has to be a a tuple of length 3!' + assert np.shape(center) == (3,), 'Parameter center has to be a a tuple of length 3!' + assert np.shape(width) == (3,), 'Parameter width has to be a a tuple of length 3!' + zz, yy, xx = np.indices(dim) + 0.5 + distance = np.sqrt(((xx - center[2]) / (width[2] / 2)) ** 2 + + ((yy - center[1]) / (width[1] / 2)) ** 2 + + ((zz - center[0]) / (width[0] / 2)) ** 2) + return np.where(distance <= 1, True, False) + + +def sphere(dim, center, radius): + """Create the shape of a sphere. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + center : tuple (N=3) + The center of the sphere in pixel coordinates `(z, y, x)`. + radius : float + The radius of the sphere in pixel coordinates. + + Returns + ------- + shape : :class:`~numpy.ndarray` (N=3) + The shape as a 3D-array. + + """ + _log.debug('Calling sphere') + assert np.shape(dim) == (3,), 'Parameter dim has to be a a tuple of length 3!' + assert np.shape(center) == (3,), 'Parameter center has to be a a tuple of length 3!' + assert radius > 0 and np.shape(radius) == (), 'Radius has to be a positive scalar value!' + zz, yy, xx = np.indices(dim) + 0.5 + distance = np.sqrt((xx - center[2]) ** 2 + (yy - center[1]) ** 2 + (zz - center[0]) ** 2) + return np.where(distance <= radius, True, False) + + +def filament(dim, pos, axis='y'): + """Create the shape of a filament. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + pos : tuple (N=2) + The position of the filament in pixel coordinates `(coord1, coord2)`. + `coord1` and `coord2` stand for the two axes, which are perpendicular to `axis`. For + the default case (`axis = y`), it is `(coord1, coord2) = (z, x)`. + axis : {'y', 'x', 'z'}, optional + The orientation of the filament axis. The default is 'y'. + + Returns + ------- + shape : :class:`~numpy.ndarray` (N=3) + The shape as a 3D-array. + + """ + _log.debug('Calling filament') + assert np.shape(dim) == (3,), 'Parameter dim has to be a tuple of length 3!' + assert np.shape(pos) == (2,), 'Parameter pos has to be a tuple of length 2!' + assert axis in {'z', 'y', 'x'}, 'Axis has to be x, y or z (as a string)!' + shape = np.zeros(dim, dtype=bool) + if axis == 'z': + shape[:, pos[0], pos[1]] = True + elif axis == 'y': + shape[pos[0], :, pos[1]] = True + elif axis == 'x': + shape[pos[0], pos[1], :] = True + return shape + + +def pixel(dim, pixel): + """Create the shape of a single pixel. + + Parameters + ---------- + dim : tuple (N=3) + The dimensions of the grid `(z, y, x)`. + pixel : tuple (N=3) + The coordinates of the pixel `(z, y, x)`. + + Returns + ------- + shape : :class:`~numpy.ndarray` (N=3) + The shape as a 3D-array. + + """ + _log.debug('Calling pixel') + assert np.shape(dim) == (3,), 'Parameter dim has to be a tuple of length 3!' + assert np.shape(pixel) == (3,), 'Parameter pixel has to be a tuple of length 3!' + shape = np.zeros(dim, dtype=bool) + shape[pixel] = True + return shape diff --git a/pyramid/phasemap.py b/pyramid/phasemap.py index 2c68f4182eecb415ad46495da4467ca1c103cd43..c443c782c7c12caebf2178749062218adbc968e9 100644 --- a/pyramid/phasemap.py +++ b/pyramid/phasemap.py @@ -1,938 +1,938 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""This module provides the :class:`~.PhaseMap` class for storing phase map data.""" - -import logging - -from numbers import Number - -import numpy as np - -from PIL import Image - -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib.colors import LinearSegmentedColormap -from matplotlib.ticker import MaxNLocator - -from mpl_toolkits.mplot3d import Axes3D - -from scipy import ndimage - -import warnings - -from . import colors -from . import plottools - -__all__ = ['PhaseMap'] - - -class PhaseMap(object): - """Class for storing phase map data. - - Represents 2-dimensional phase maps. The phase information itself is stored as a 2-dimensional - matrix in `phase`, but can also be accessed as a vector via `phase_vec`. :class:`~.PhaseMap` - objects support negation, arithmetic operators (``+``, ``-``, ``*``) and their augmented - counterparts (``+=``, ``-=``, ``*=``), with numbers and other :class:`~.PhaseMap` - objects, if their dimensions and grid spacings match. It is possible to load data from HDF5 - or textfiles or to save the data in these formats. Methods for plotting the phase or a - corresponding holographic contour map are provided. Holographic contour maps are created by - taking the cosine of the (optionally amplified) phase and encoding the direction of the - 2-dimensional gradient via color. The directional encoding can be seen by using the - :func:`~.make_color_wheel` function. Use the :func:`~.plot_combined` function to plot the - phase map and the holographic contour map next to each other. - - Attributes - ---------- - a: float - The grid spacing in nm. - phase: :class:`~numpy.ndarray` (N=2) - Array containing the phase shift. - mask: :class:`~numpy.ndarray` (boolean, N=2, optional) - Mask which determines the projected magnetization distribution, gotten from MIP images or - otherwise acquired. Defaults to an array of ones (all pixels are considered). - confidence: :class:`~numpy.ndarray` (N=2, optional) - Confidence array which determines the trust of specific regions of the phasemap. A value - of 1 means the pixel is trustworthy, a value of 0 means it is not. Defaults to an array of - ones (full trust for all pixels). Can be used for the construction of Se_inv. - - """ - - _log = logging.getLogger(__name__) - - UNITDICT = {u'rad': 1E0, - u'mrad': 1E3, - u'µrad': 1E6, - u'nrad': 1E9, - u'1/rad': 1E0, - u'1/mrad': 1E-3, - u'1/µrad': 1E-6, - u'1/nrad': 1E-9} - - @property - def a(self): - """Grid spacing in nm.""" - return self._a - - @a.setter - def a(self, a): - assert isinstance(a, Number), 'Grid spacing has to be a number!' - assert a >= 0, 'Grid spacing has to be a positive number!' - self._a = float(a) - - @property - def dim_uv(self): - """Dimensions of the grid.""" - return self._dim_uv - - @property - def phase(self): - """Array containing the phase shift.""" - return self._phase - - @phase.setter - def phase(self, phase): - assert isinstance(phase, np.ndarray), 'Phase has to be a numpy array!' - assert len(phase.shape) == 2, 'Phase has to be 2-dimensional, not {}!'.format(phase.shape) - self._phase = phase.astype(dtype=np.float32) - self._dim_uv = phase.shape - - @property - def phase_vec(self): - """Vector containing the phase shift.""" - return self.phase.ravel() - - @phase_vec.setter - def phase_vec(self, phase_vec): - assert isinstance(phase_vec, np.ndarray), 'Vector has to be a numpy array!' - assert np.size(phase_vec) == np.prod(self.dim_uv), 'Vector size has to match phase!' - self.phase = phase_vec.reshape(self.dim_uv) - - @property - def mask(self): - """Mask which determines the projected magnetization distribution""" - return self._mask - - @mask.setter - def mask(self, mask): - if mask is not None: - assert mask.shape == self.phase.shape, 'Mask and phase dimensions must match!!' - else: - mask = np.ones_like(self.phase, dtype=bool) - self._mask = mask.astype(np.bool) - - @property - def confidence(self): - """Confidence array which determines the trust of specific regions of the phasemap.""" - return self._confidence - - @confidence.setter - def confidence(self, confidence): - if confidence is not None: - assert confidence.shape == self.phase.shape, \ - 'Confidence and phase dimensions must match!' - confidence = confidence.astype(dtype=np.float32) - confidence /= confidence.max() # Normalise! - else: - confidence = np.ones_like(self.phase, dtype=np.float32) - self._confidence = confidence - - def __init__(self, a, phase, mask=None, confidence=None): - self._log.debug('Calling __init__') - self.a = a - self.phase = phase - self.mask = mask - self.confidence = confidence - self._log.debug('Created ' + str(self)) - - def __repr__(self): - self._log.debug('Calling __repr__') - return '%s(a=%r, phase=%r, mask=%r, confidence=%r)' % \ - (self.__class__, self.a, self.phase, self.mask, self.confidence) - - def __str__(self): - self._log.debug('Calling __str__') - return 'PhaseMap(a=%s, dim_uv=%s, mask=%s)' % (self.a, self.dim_uv, not np.all(self.mask)) - - def __neg__(self): # -self - self._log.debug('Calling __neg__') - return PhaseMap(self.a, -self.phase, self.mask, self.confidence) - - def __add__(self, other): # self + other - self._log.debug('Calling __add__') - assert isinstance(other, (PhaseMap, Number)), \ - 'Only PhaseMap objects and scalar numbers (as offsets) can be added/subtracted!' - if isinstance(other, PhaseMap): - self._log.debug('Adding two PhaseMap objects') - assert other.a == self.a, 'Added phase has to have the same grid spacing!' - assert other.phase.shape == self.dim_uv, \ - 'Added field has to have the same dimensions!' - mask_comb = np.logical_or(self.mask, other.mask) # masks combine - conf_comb = (self.confidence + other.confidence) / 2 # confidence averaged! - return PhaseMap(self.a, self.phase + other.phase, mask_comb, conf_comb) - else: # other is a Number - self._log.debug('Adding an offset') - return PhaseMap(self.a, self.phase + other, self.mask, self.confidence) - - def __sub__(self, other): # self - other - self._log.debug('Calling __sub__') - return self.__add__(-other) - - def __mul__(self, other): # self * other - self._log.debug('Calling __mul__') - assert (isinstance(other, Number) or - (isinstance(other, np.ndarray) and other.shape == self.dim_uv)), \ - 'PhaseMap objects can only be multiplied by scalar numbers or fitting arrays!' - return PhaseMap(self.a, self.phase * other, self.mask, self.confidence) - - def __truediv__(self, other): # self / other - self._log.debug('Calling __truediv__') - assert (isinstance(other, Number) or - (isinstance(other, np.ndarray) and other.shape == self.dim_uv)), \ - 'PhaseMap objects can only be divided by scalar numbers or fitting arrays!' - return PhaseMap(self.a, self.phase / other, self.mask, self.confidence) - - def __floordiv__(self, other): # self // other - self._log.debug('Calling __floordiv__') - assert (isinstance(other, Number) or - (isinstance(other, np.ndarray) and other.shape == self.dim_uv)), \ - 'PhaseMap objects can only be divided by scalar numbers or fitting arrays!' - return PhaseMap(self.a, self.phase // other, self.mask, self.confidence) - - def __radd__(self, other): # other + self - self._log.debug('Calling __radd__') - return self.__add__(other) - - def __rsub__(self, other): # other - self - self._log.debug('Calling __rsub__') - return -self.__sub__(other) - - def __rmul__(self, other): # other * self - self._log.debug('Calling __rmul__') - return self.__mul__(other) - - def __iadd__(self, other): # self += other - self._log.debug('Calling __iadd__') - return self.__add__(other) - - def __isub__(self, other): # self -= other - self._log.debug('Calling __isub__') - return self.__sub__(other) - - def __imul__(self, other): # self *= other - self._log.debug('Calling __imul__') - return self.__mul__(other) - - def __itruediv__(self, other): # self /= other - self._log.debug('Calling __itruediv__') - return self.__truediv__(other) - - def __ifloordiv__(self, other): # self //= other - self._log.debug('Calling __ifloordiv__') - return self.__floordiv__(other) - - def __getitem__(self, item): - return PhaseMap(self.a, self.phase[item], self.mask[item], self.confidence[item]) - - def __array__(self, dtype=None): # Used for numpy ufuncs, together with __array_wrap__! - if dtype: - return self.phase.astype(dtype) - else: - return self.phase - - def __array_wrap__(self, array, _=None): # _ catches the context, which is not used. - return PhaseMap(self.a, array, self.mask, self.confidence) - - def copy(self): - """Returns a copy of the :class:`~.PhaseMap` object - - Returns - ------- - phasemap: :class:`~.PhaseMap` - A copy of the :class:`~.PhaseMap`. - - """ - self._log.debug('Calling copy') - return PhaseMap(self.a, self.phase.copy(), self.mask.copy(), - self.confidence.copy()) - - def scale_down(self, n=1): - """Scale down the phase map by averaging over two pixels along each axis. - - Parameters - ---------- - n : int, optional - Number of times the phase map is scaled down. The default is 1. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions and grid spacing accordingly. - Only possible, if each axis length is a power of 2! - - """ - self._log.debug('Calling scale_down') - assert n > 0 and isinstance(n, int), 'n must be a positive integer!' - self.a *= 2 ** n - for t in range(n): - # Pad if necessary: - pv, pu = (0, self.dim_uv[0] % 2), (0, self.dim_uv[1] % 2) - if pv != 0 or pu != 0: - self.pad((pv, pu), mode='edge') - # Create coarser grid for the magnetization: - dim_uv = self.dim_uv - self.phase = self.phase.reshape((dim_uv[0] // 2, 2, - dim_uv[1] // 2, 2)).mean(axis=(3, 1)) - mask = self.mask.reshape(dim_uv[0] // 2, 2, dim_uv[1] // 2, 2) - self.mask = mask[:, 0, :, 0] & mask[:, 1, :, 0] & mask[:, 0, :, 1] & mask[:, 1, :, 1] - self.confidence = self.confidence.reshape(dim_uv[0] // 2, 2, - dim_uv[1] // 2, 2).mean(axis=(3, 1)) - - def scale_up(self, n=1, order=0): - """Scale up the phase map using spline interpolation of the requested order. - - Parameters - ---------- - n : int, optional - Power of 2 with which the grid is scaled. Default is 1, which means every axis is - increased by a factor of ``2**1 = 2``. - order : int, optional - The order of the spline interpolation, which has to be in the range between 0 and 5 - and defaults to 0. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions and grid spacing accordingly. - - """ - self._log.debug('Calling scale_up') - assert n > 0 and isinstance(n, int), 'n must be a positive integer!' - assert 5 > order >= 0 and isinstance(order, int), \ - 'order must be a positive integer between 0 and 5!' - self.a /= 2 ** n - self.phase = ndimage.zoom(self.phase, zoom=2 ** n, order=order) - self.mask = ndimage.zoom(self.mask, zoom=2 ** n, order=0) - self.confidence = ndimage.zoom(self.confidence, zoom=2 ** n, order=order) - - def pad(self, pad_values, mode='constant', masked=False, **kwds): - """Pad the current phase map with zeros for each individual axis. - - Parameters - ---------- - pad_values : tuple of int - Number of zeros which should be padded. Provided as a tuple where each entry - corresponds to an axis. An entry can be one int (same padding for both sides) or again - a tuple which specifies the pad values for both sides of the corresponding axis. - mode: string or function - A string values or a user supplied function. ‘constant’ pads with zeros. ‘edge’ pads - with the edge values of array. See the numpy pad function for an in depth guide. - masked: boolean, optional - Determines if the padded areas should be masked or not. `True` creates a 'buffer - zone' for the magnetization distribution in the reconstruction. Default is `False` - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions accordingly. - The confidence of the padded areas is set to zero! - - """ - self._log.debug('Calling pad') - assert len(pad_values) == 2, 'Pad values for each dimension have to be provided!' - pval = np.zeros(4, dtype=np.int) - for i, values in enumerate(pad_values): - assert np.shape(values) in [(), (2,)], 'Only one or two values per axis can be given!' - pval[2 * i:2 * (i + 1)] = values - self.phase = np.pad(self.phase, ((pval[0], pval[1]), - (pval[2], pval[3])), mode=mode, **kwds) - self.confidence = np.pad(self.confidence, ((pval[0], pval[1]), - (pval[2], pval[3])), mode=mode, **kwds) - if masked: - mask_kwds = {'mode': 'constant', 'constant_values': True} - else: - mask_kwds = {'mode': mode} - self.mask = np.pad(self.mask, ((pval[0], pval[1]), (pval[2], pval[3])), **mask_kwds) - - def crop(self, crop_values): - """Pad the current phase map with zeros for each individual axis. - - Parameters - ---------- - crop_values : tuple of int - Number of zeros which should be cropped. Provided as a tuple where each entry - corresponds to an axis. An entry can be one int (same cropping for both sides) or again - a tuple which specifies the crop values for both sides of the corresponding axis. - - Returns - ------- - None - - Notes - ----- - Acts in place and changes dimensions accordingly. - - """ - self._log.debug('Calling crop') - assert len(crop_values) == 2, 'Crop values for each dimension have to be provided!' - cv = np.zeros(4, dtype=np.int) - for i, values in enumerate(crop_values): - assert np.shape(values) in [(), (2,)], 'Only one or two values per axis can be given!' - cv[2 * i:2 * (i + 1)] = values - cv *= np.resize([1, -1], len(cv)) - cv = np.where(cv == 0, None, cv) - self.phase = self.phase[cv[0]:cv[1], cv[2]:cv[3]] - self.mask = self.mask[cv[0]:cv[1], cv[2]:cv[3]] - self.confidence = self.confidence[cv[0]:cv[1], cv[2]:cv[3]] - - def flip(self, axis='u'): - """Flip/mirror the phase map around the specified axis. - - Parameters - ---------- - axis: {'u', 'v'}, optional - The axis around which the phase map is flipped. - - Returns - ------- - phasemap_flip: :class:`~.PhaseMap` - A flipped copy of the :class:`~.PhaseMap` object. - - """ - self._log.debug('Calling flip') - if axis == 'u': - return PhaseMap(self.a, np.flipud(self.phase), np.flipud(self.mask), - np.flipud(self.confidence)) - if axis == 'v': - return PhaseMap(self.a, np.fliplr(self.phase), np.fliplr(self.mask), - np.fliplr(self.confidence)) - else: - raise ValueError("Wrong input! 'u', 'v' allowed!") - - def rotate(self, angle): - """Rotate the phase map (right hand rotation). - - Parameters - ---------- - angle: float - The angle around which the phase map is rotated. - - Returns - ------- - phasemap_rot: :class:`~.PhaseMap` - A rotated copy of the :class:`~.PhaseMap` object. - - """ - self._log.debug('Calling rotate') - phase_rot = ndimage.rotate(self.phase, angle, reshape=False) - mask_rot = ndimage.rotate(self.mask, angle, reshape=False, order=0) - conf_rot = ndimage.rotate(self.confidence, angle, reshape=False) - return PhaseMap(self.a, phase_rot, mask_rot, conf_rot) - - def shift(self, shift): - """Shift the phase map (subpixel accuracy). - - Parameters - ---------- - shift : float or sequence, optional - The shift along the axes. If a float, shift is the same for each axis. - If a sequence, shift should contain one value for each axis. - - Returns - ------- - phasemap_shift: :class:`~.PhaseMap` - A shifted copy of the :class:`~.PhaseMap` object. - - """ - self._log.debug('Calling shift') - phase_rot = ndimage.shift(self.phase, shift, mode='constant', cval=0) - mask_rot = ndimage.shift(self.mask, shift, mode='constant', cval=False, order=0) - conf_rot = ndimage.shift(self.confidence, shift, mode='constant', cval=0) - return PhaseMap(self.a, phase_rot, mask_rot, conf_rot) - - @classmethod - def from_signal(cls, signal): - """Convert a :class:`~hyperspy.signals.Image` object to a :class:`~.PhaseMap` object. - - Parameters - ---------- - signal: :class:`~hyperspy.signals.Image` - The :class:`~hyperspy.signals.Image` object which should be converted to a PhaseMap. - - Returns - ------- - phasemap: :class:`~.PhaseMap` - A :class:`~.PhaseMap` object containing the loaded data. - - Notes - ----- - This method recquires the hyperspy package! - - """ - cls._log.debug('Calling from_signal') - # Extract phase: - phase = signal.data - # Extract properties: - a = signal.axes_manager.signal_axes[0].scale - try: - mask = signal.metadata.Signal.mask - confidence = signal.metadata.Signal.confidence - except AttributeError: - mask = None - confidence = None - return cls(a, phase, mask, confidence) - - def to_signal(self): - """Convert :class:`~.PhaseMap` data into a HyperSpy Signal. - - Returns - ------- - signal: :class:`~hyperspy.signals.Signal2D` - Representation of the :class:`~.PhaseMap` object as a HyperSpy Signal. - - Notes - ----- - This method recquires the hyperspy package! - - """ - self._log.debug('Calling to_signal') - try: # Try importing HyperSpy: - # noinspection PyUnresolvedReferences - import hyperspy.api as hs - except ImportError: - self._log.error('This method recquires the hyperspy package!') - return - # Create signal: - signal = hs.signals.Signal2D(self.phase) - # Set axes: - signal.axes_manager.signal_axes[0].name = 'x-axis' - signal.axes_manager.signal_axes[0].units = 'nm' - signal.axes_manager.signal_axes[0].scale = self.a - signal.axes_manager.signal_axes[1].name = 'y-axis' - signal.axes_manager.signal_axes[1].units = 'nm' - signal.axes_manager.signal_axes[1].scale = self.a - # Set metadata: - signal.metadata.Signal.title = 'PhaseMap' - signal.metadata.Signal.unit = 'rad' - signal.metadata.Signal.mask = self.mask - signal.metadata.Signal.confidence = self.confidence - # Create and return signal: - return signal - - def save(self, filename, save_mask=False, save_conf=False, pyramid_format=True, **kwargs): - """Saves the phasemap in the specified format. - - The function gets the format from the extension: - - hdf5 for HDF5. - - rpl for Ripple (useful to export to Digital Micrograph). - - unf for SEMPER unf binary format. - - txt format. - - Many image formats such as png, tiff, jpeg... - - If no extension is provided, 'hdf5' is used. Most formats are - saved with the HyperSpy package (internally the phasemap is first - converted to a HyperSpy Signal. - - Each format accepts a different set of parameters. For details - see the specific format documentation. - - Parameters - ---------- - filename: str, optional - Name of the file which the phasemap is saved into. The extension - determines the saving procedure. - save_mask: boolean, optional - If True, the `mask` is saved, too. For all formats, except HDF5, a separate file will - be created. HDF5 always saves the `mask` in the metadata, independent of this flag. The - default is False. - save_conf: boolean, optional - If True, the `confidence` is saved, too. For all formats, except HDF5, a separate file - will be created. HDF5 always saves the `confidence` in the metadata, independent of - this flag. The default is False - pyramid_format: boolean, optional - Only used for saving to '.txt' files. If this is True, the grid spacing is saved - in an appropriate header. Otherwise just the phase is written with the - corresponding `kwargs`. - - """ - from .file_io.io_phasemap import save_phasemap - save_phasemap(self, filename, save_mask, save_conf, pyramid_format, **kwargs) - - def plot_phase(self, unit='auto', vmin=None, vmax=None, sigma_clip=None, symmetric=True, - show_mask=True, show_conf=True, norm=None, cbar=True, # specific to plot_phase! - cmap=None, interpolation='none', axis=None, figsize=None, **kwargs): - """Display the phasemap as a colormesh. - - Parameters - ---------- - unit: {'rad', 'mrad', 'µrad', '1/rad', '1/mrad', '1/µrad'}, optional - The plotting unit of the phase map. The phase is scaled accordingly before plotting. - Inverse radians should be used for gain maps! - vmin : float, optional - Minimum value used for determining the plot limits. If not set, it will be - determined by the minimum of the phase directly. - vmax : float, optional - Maximum value used for determining the plot limits. If not set, it will be - determined by the minimum of the phase directly. - sigma_clip : int, optional - If this is not `None`, the values outside `sigma_clip` times the standard deviation - will be clipped for the calculation of the plotting `limit`. - symmetric : boolean, optional - If True (default), a zero symmetric colormap is assumed and a zero value (which - will always be present) will be set to the central color color of the colormap. - show_mask : bool, optional - A switch determining if the mask should be plotted or not. Default is True. - show_conf : float, optional - A switch determining if the confidence should be plotted or not. Default is True. - norm : :class:`~matplotlib.colors.Normalize` or subclass, optional - Norm, which is used to determine the colors to encode the phase information. - cbar : bool, optional - If True (default), a colorbar will be plotted. - cmap : string, optional - The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. - interpolation : {'none, 'bilinear', 'cubic', 'nearest'}, optional - Defines the interpolation method for the holographic contour map. - No interpolation is used in the default case. - axis : :class:`~matplotlib.axes.AxesSubplot`, optional - Axis on which the graph is plotted. Creates a new figure if none is specified. - figsize : tuple of floats (N=2) - Size of the plot figure. - - Returns - ------- - axis, cbar: :class:`~matplotlib.axes.AxesSubplot` - The axis on which the graph is plotted. - - Notes - ----- - Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. - - """ - self._log.debug('Calling plot_phase') - a = self.a - if figsize is None: - figsize = plottools.FIGSIZE_DEFAULT - # Take units into consideration: - if unit == 'auto': # Try to automatically determine unit (recommended): - for key, value in self.UNITDICT.items(): - if not key.startswith('1/'): - order = np.floor(np.log10(np.abs(self.phase).max() * value)) - if -1 <= order < 2: - unit = key - if unit == 'auto': # No fitting unit was found: - unit = 'rad' - # Scale phase and make last check if order is okay: - phase = self.phase * self.UNITDICT[unit] - order = np.floor(np.log10(np.abs(phase).max())) - if order > 2 or order < -6: # Display would look bad - unit = '{} x 1E{:g}'.format(unit, order) - phase /= 10 ** order - # Calculate limits if necessary (not necessary if both limits are already set): - if vmin is None and vmax is None: - phase_l = phase - # Clip non-trustworthy regions for the limit calculation: - if show_conf: - phase_trust = np.where(self.confidence > 0.9, phase_l, np.nan) - phase_min, phase_max = np.nanmin(phase_trust), np.nanmax(phase_trust) - phase_l = np.clip(phase_l, phase_min, phase_max) - # Cut outlier beyond a certain sigma-margin: - if sigma_clip is not None: - outlier = np.abs(phase_l - np.mean(phase_l)) < sigma_clip * np.std(phase_l) - phase_sigma = np.where(outlier, phase_l, np.nan) - phase_min, phase_max = np.nanmin(phase_sigma), np.nanmax(phase_sigma) - phase_l = np.clip(phase_l, phase_min, phase_max) - # Calculate the limits if necessary (zero has to be present!): - if vmin is None: - vmin = np.min(phase_l) - if vmax is None: - vmax = np.max(phase_l) - # Configure colormap, to fix white to zero if colormap is symmetric: - if symmetric: - if cmap is None: - cmap = plt.get_cmap('RdBu') - elif isinstance(cmap, str): # Get colormap if given as string: - cmap = plt.get_cmap(cmap) - vmin, vmax = np.min([vmin, -0]), np.max([0, vmax]) # Ensure zero is present! - limit = np.max(np.abs([vmin, vmax])) - start = (vmin + limit) / (2 * limit) - end = (vmax + limit) / (2 * limit) - cmap_colors = cmap(np.linspace(start, end, 256)) - cmap = LinearSegmentedColormap.from_list('Symmetric', cmap_colors) - # If no axis is specified, a new figure is created: - if axis is None: - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1) - tight = True - else: - tight = False - axis.set_aspect('equal') - # Plot the phasemap: - im = axis.imshow(phase, cmap=cmap, vmin=vmin, vmax=vmax, interpolation=interpolation, - norm=norm, origin='lower', extent=(0, self.dim_uv[1], 0, self.dim_uv[0])) - if show_mask or show_conf: - vv, uu = np.indices(self.dim_uv) + 0.5 - if show_conf and not np.all(self.confidence == 1.0): - colormap = colors.cmaps['transparent_confidence'] - axis.imshow(self.confidence, cmap=colormap, interpolation=interpolation, - origin='lower', extent=(0, self.dim_uv[1], 0, self.dim_uv[0])) - if show_mask and not np.all(self.mask): # Plot mask if desired and not trivial! - axis.contour(uu, vv, self.mask, levels=[0.5], colors='k', linestyles='dotted', - linewidths=2) - # Determine colorbar title: - cbar_label = kwargs.pop('cbar_label', None) - cbar_mappable = None - if cbar: - cbar_mappable = im - if cbar_label is None: - if unit.startswith('1/'): - cbar_name = 'gain' - else: - cbar_name = 'phase' - if mpl.rcParams['text.usetex'] and 'µ' in unit: # Make sure µ works in latex: - mpl.rc('text.latex', preamble=r'\usepackage{txfonts},\usepackage{lmodern}') - unit = unit.replace('µ', '$\muup$') # Upright µ! - cbar_label = u'{} [{}]'.format(cbar_name, unit) - # Return formatted axis: - return plottools.format_axis(axis, sampling=a, cbar_mappable=cbar_mappable, - cbar_label=cbar_label, tight_layout=tight, **kwargs) - - def plot_holo(self, gain='auto', # specific to plot_holo! - cmap=None, interpolation='none', axis=None, figsize=None, **kwargs): - """Display the color coded holography image. - - Parameters - ---------- - gain : float or 'auto', optional - The gain factor for determining the number of contour lines. The default is 'auto', - which means that the gain will be determined automatically to look pretty. - cmap : string, optional - The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. - interpolation : {'none, 'bilinear', 'cubic', 'nearest'}, optional - Defines the interpolation method for the holographic contour map. - No interpolation is used in the default case. - axis : :class:`~matplotlib.axes.AxesSubplot`, optional - Axis on which the graph is plotted. Creates a new figure if none is specified. - figsize : tuple of floats (N=2) - Size of the plot figure. - - Returns - ------- - axis: :class:`~matplotlib.axes.AxesSubplot` - The axis on which the graph is plotted. - - Notes - ----- - Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. - - """ - self._log.debug('Calling plot_holo') - a = self.a - if figsize is None: - figsize = plottools.FIGSIZE_DEFAULT - # Calculate gain if 'auto' is selected: - if gain == 'auto': - gain = 4 * 2 * np.pi / (np.abs(self.phase).max() + 1E-30) - gain = round(gain, -int(np.floor(np.log10(abs(gain))))) - # Calculate the holography image intensity: - holo = np.cos(gain * self.phase) - holo += 1 # Shift to positive values - holo /= 2 # Rescale to [0, 1] - # Calculate the phase gradients: - # B = rot(A) --> B_x = grad_y(A_z), B_y = -grad_x(A_z); phi_m ~ -int(A_z) - # sign switch --> B_x = -grad_y(phi_m), B_y = grad_x(phi_m) - grad_x, grad_y = np.gradient(self.phase, self.a, self.a) - # Clip outliers: - sigma_clip = 2 - outlier_x = np.abs(grad_x - np.mean(grad_x)) < sigma_clip * np.std(grad_x) - grad_x_sigma = np.where(outlier_x, grad_x, np.nan) - grad_x_min, grad_x_max = np.nanmin(grad_x_sigma), np.nanmax(grad_x_sigma) - grad_x = np.clip(grad_x, grad_x_min, grad_x_max) - outlier_y = np.abs(grad_y - np.mean(grad_y)) < sigma_clip * np.std(grad_y) - grad_y_sigma = np.where(outlier_y, grad_y, np.nan) - grad_y_min, grad_y_max = np.nanmin(grad_y_sigma), np.nanmax(grad_y_sigma) - grad_y = np.clip(grad_y, grad_y_min, grad_y_max) - # Calculate colors: - if cmap is None: - cmap = colors.CMAP_CIRCULAR_DEFAULT - vector = np.asarray((grad_x, -grad_y, np.zeros_like(grad_x))) - rgb = cmap.rgb_from_vector(vector) - rgb = (holo.T * rgb.T).T.astype(np.uint8) - holo_image = Image.fromarray(rgb) - # If no axis is specified, a new figure is created: - if axis is None: - fig = plt.figure(figsize=figsize) - axis = fig.add_subplot(1, 1, 1) - tight = True - else: - tight = False - axis.set_aspect('equal') - # Plot the image and set axes: - axis.imshow(holo_image, origin='lower', interpolation=interpolation, - extent=(0, self.dim_uv[1], 0, self.dim_uv[0])) - note = kwargs.pop('note', None) - if note is None: - note = 'gain: {:g}'.format(gain) - stroke = kwargs.pop('stroke', 'k') # Default for holo is white with black outline! - return plottools.format_axis(axis, sampling=a, note=note, tight_layout=tight, - stroke=stroke, **kwargs) - - def plot_combined(self, title='', phase_title='', holo_title='', figsize=None, **kwargs): - """Display the phase map and the resulting color coded holography image in one plot. - - Parameters - ---------- - title : string, optional - The super title of the plot. The default is 'Combined Plot'. - phase_title : string, optional - The title of the phase map. - holo_title : string, optional - The title of the holographic contour map - figsize : tuple of floats (N=2) - Size of the plot figure. - - Returns - ------- - phase_axis, holo_axis: :class:`~matplotlib.axes.AxesSubplot` - The axes on which the graphs are plotted. - - Notes - ----- - Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. - - """ - self._log.debug('Calling plot_combined') - # Create combined plot and set title: - if figsize is None: - figsize = (plottools.FIGSIZE_DEFAULT[0]*2 + 1, plottools.FIGSIZE_DEFAULT[1]) - fig = plt.figure(figsize=figsize) - fig.suptitle(title, fontsize=20) - # Only phase is annotated, holo will show gain: - note = kwargs.pop('note', None) - # Plot holography image: - holo_axis = fig.add_subplot(1, 2, 1, aspect='equal') - self.plot_holo(axis=holo_axis, title=holo_title, note=None, **kwargs) - # Plot phase map: - phase_axis = fig.add_subplot(1, 2, 2, aspect='equal') - self.plot_phase(axis=phase_axis, title=phase_title, note=note, **kwargs) - # Tighten layout if axis was created here: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - plt.tight_layout() - # Return the plotting axes: - return phase_axis, holo_axis - - def plot_phase_with_hist(self, bins='auto', unit='rad', - title='', phase_title='', hist_title='', figsize=None, **kwargs): - """Display the phase map and a histogram of the phase values of all pixels. - - Parameters - ---------- - bins : int or sequence of scalars or str, optional - Bin argument that goes to the matplotlib.hist function (more documentation there). - The default is 'auto', which tries to pick something nice. - unit: {'rad', 'mrad', 'µrad', '1/rad', '1/mrad', '1/µrad'}, optional - The plotting unit of the phase map. The phase is scaled accordingly before plotting. - Inverse radians should be used for gain maps! - title : string, optional - The super title of the plot. The default is 'Combined Plot'. - phase_title : string, optional - The title of the phase map. - hist_title : string, optional - The title of the histogram. - figsize : tuple of floats (N=2) - Size of the plot figure. - - Returns - ------- - phase_axis, holo_axis: :class:`~matplotlib.axes.AxesSubplot` - The axes on which the graphs are plotted. - - Notes - ----- - Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. - - """ - self._log.debug('Calling plot_phase_with_hist') - # Create combined plot and set title: - if figsize is None: - figsize = (plottools.FIGSIZE_DEFAULT[0]*2 + 1, plottools.FIGSIZE_DEFAULT[1]) - fig = plt.figure(figsize=figsize) - fig.suptitle(title, fontsize=20) - # Plot histogram: - hist_axis = fig.add_subplot(1, 2, 1) - vec = self.phase_vec * self.UNITDICT[unit] # Take units into consideration: - vec *= np.where(self.confidence > 0.5, 1, 0).ravel() # Discard low confidence points! - hist_axis.hist(vec, bins=bins, histtype='stepfilled', color='g') - # Format histogram: - x0, x1 = hist_axis.get_xlim() - y0, y1 = hist_axis.get_ylim() - hist_axis.set(aspect=np.abs(x1 - x0) / np.abs(y1 - y0) * 0.94) # Last value because cbar! - fontsize = kwargs.get('fontsize', 16) - hist_axis.tick_params(axis='both', which='major', labelsize=fontsize) - hist_axis.set_title(hist_title, fontsize=fontsize) - hist_axis.set_xlabel('phase [{}]'.format(unit), fontsize=fontsize) - hist_axis.set_ylabel('count', fontsize=fontsize) - # Plot phase map: - phase_axis = fig.add_subplot(1, 2, 2, aspect=1) - self.plot_phase(unit=unit, axis=phase_axis, title=phase_title, **kwargs) - # Tighten layout: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - plt.tight_layout() - # Return the plotting axes: - return phase_axis, hist_axis - - def plot_phase3d(self, title='Phase Map', unit='rad', cmap='RdBu'): - """Display the phasemap as a 3D surface with contourplots. - - Parameters - ---------- - title : string, optional - The title of the plot. The default is 'Phase Map'. - unit: {'rad', 'mrad', 'µrad'}, optional - The plotting unit of the phase map. The phase is scaled accordingly before plotting. - cmap : string, optional - The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. - The default is 'RdBu'. - - Returns - ------- - axis: :class:`~matplotlib.axes.AxesSubplot` - The axis on which the graph is plotted. - - """ - self._log.debug('Calling plot_phase3d') - # Take units into consideration: - phase = self.phase * self.UNITDICT[unit] - # Create figure and axis: - fig = plt.figure() - axis = Axes3D(fig) - # Plot surface and contours: - vv, uu = np.indices(self.dim_uv) - axis.plot_surface(uu, vv, phase, rstride=4, cstride=4, alpha=0.7, cmap=cmap, - linewidth=0, antialiased=False) - axis.contourf(uu, vv, phase, 15, zdir='z', offset=np.min(phase), cmap=cmap) - axis.set_title(title) - axis.view_init(45, -135) - axis.set_xlabel('u-axis [px]') - axis.set_ylabel('v-axis [px]') - axis.set_zlabel('phase shift [{}]'.format(unit)) - if self.dim_uv[0] >= self.dim_uv[1]: - u_bin, v_bin = np.max((2, np.floor(9 * self.dim_uv[1] / self.dim_uv[0]))), 9 - else: - u_bin, v_bin = 9, np.max((2, np.floor(9 * self.dim_uv[0] / self.dim_uv[1]))) - axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) - # Return plotting axis: - return axis +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides the :class:`~.PhaseMap` class for storing phase map data.""" + +import logging + +from numbers import Number + +import numpy as np + +from PIL import Image + +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap +from matplotlib.ticker import MaxNLocator + +from mpl_toolkits.mplot3d import Axes3D + +from scipy import ndimage + +import warnings + +from . import colors +from . import plottools + +__all__ = ['PhaseMap'] + + +class PhaseMap(object): + """Class for storing phase map data. + + Represents 2-dimensional phase maps. The phase information itself is stored as a 2-dimensional + matrix in `phase`, but can also be accessed as a vector via `phase_vec`. :class:`~.PhaseMap` + objects support negation, arithmetic operators (``+``, ``-``, ``*``) and their augmented + counterparts (``+=``, ``-=``, ``*=``), with numbers and other :class:`~.PhaseMap` + objects, if their dimensions and grid spacings match. It is possible to load data from HDF5 + or textfiles or to save the data in these formats. Methods for plotting the phase or a + corresponding holographic contour map are provided. Holographic contour maps are created by + taking the cosine of the (optionally amplified) phase and encoding the direction of the + 2-dimensional gradient via color. The directional encoding can be seen by using the + :func:`~.make_color_wheel` function. Use the :func:`~.plot_combined` function to plot the + phase map and the holographic contour map next to each other. + + Attributes + ---------- + a: float + The grid spacing in nm. + phase: :class:`~numpy.ndarray` (N=2) + Array containing the phase shift. + mask: :class:`~numpy.ndarray` (boolean, N=2, optional) + Mask which determines the projected magnetization distribution, gotten from MIP images or + otherwise acquired. Defaults to an array of ones (all pixels are considered). + confidence: :class:`~numpy.ndarray` (N=2, optional) + Confidence array which determines the trust of specific regions of the phasemap. A value + of 1 means the pixel is trustworthy, a value of 0 means it is not. Defaults to an array of + ones (full trust for all pixels). Can be used for the construction of Se_inv. + + """ + + _log = logging.getLogger(__name__) + + UNITDICT = {u'rad': 1E0, + u'mrad': 1E3, + u'µrad': 1E6, + u'nrad': 1E9, + u'1/rad': 1E0, + u'1/mrad': 1E-3, + u'1/µrad': 1E-6, + u'1/nrad': 1E-9} + + @property + def a(self): + """Grid spacing in nm.""" + return self._a + + @a.setter + def a(self, a): + assert isinstance(a, Number), 'Grid spacing has to be a number!' + assert a >= 0, 'Grid spacing has to be a positive number!' + self._a = float(a) + + @property + def dim_uv(self): + """Dimensions of the grid.""" + return self._dim_uv + + @property + def phase(self): + """Array containing the phase shift.""" + return self._phase + + @phase.setter + def phase(self, phase): + assert isinstance(phase, np.ndarray), 'Phase has to be a numpy array!' + assert len(phase.shape) == 2, 'Phase has to be 2-dimensional, not {}!'.format(phase.shape) + self._phase = phase.astype(dtype=np.float32) + self._dim_uv = phase.shape + + @property + def phase_vec(self): + """Vector containing the phase shift.""" + return self.phase.ravel() + + @phase_vec.setter + def phase_vec(self, phase_vec): + assert isinstance(phase_vec, np.ndarray), 'Vector has to be a numpy array!' + assert np.size(phase_vec) == np.prod(self.dim_uv), 'Vector size has to match phase!' + self.phase = phase_vec.reshape(self.dim_uv) + + @property + def mask(self): + """Mask which determines the projected magnetization distribution""" + return self._mask + + @mask.setter + def mask(self, mask): + if mask is not None: + assert mask.shape == self.phase.shape, 'Mask and phase dimensions must match!!' + else: + mask = np.ones_like(self.phase, dtype=bool) + self._mask = mask.astype(np.bool) + + @property + def confidence(self): + """Confidence array which determines the trust of specific regions of the phasemap.""" + return self._confidence + + @confidence.setter + def confidence(self, confidence): + if confidence is not None: + assert confidence.shape == self.phase.shape, \ + 'Confidence and phase dimensions must match!' + confidence = confidence.astype(dtype=np.float32) + confidence /= confidence.max() # Normalise! + else: + confidence = np.ones_like(self.phase, dtype=np.float32) + self._confidence = confidence + + def __init__(self, a, phase, mask=None, confidence=None): + self._log.debug('Calling __init__') + self.a = a + self.phase = phase + self.mask = mask + self.confidence = confidence + self._log.debug('Created ' + str(self)) + + def __repr__(self): + self._log.debug('Calling __repr__') + return '%s(a=%r, phase=%r, mask=%r, confidence=%r)' % \ + (self.__class__, self.a, self.phase, self.mask, self.confidence) + + def __str__(self): + self._log.debug('Calling __str__') + return 'PhaseMap(a=%s, dim_uv=%s, mask=%s)' % (self.a, self.dim_uv, not np.all(self.mask)) + + def __neg__(self): # -self + self._log.debug('Calling __neg__') + return PhaseMap(self.a, -self.phase, self.mask, self.confidence) + + def __add__(self, other): # self + other + self._log.debug('Calling __add__') + assert isinstance(other, (PhaseMap, Number)), \ + 'Only PhaseMap objects and scalar numbers (as offsets) can be added/subtracted!' + if isinstance(other, PhaseMap): + self._log.debug('Adding two PhaseMap objects') + assert other.a == self.a, 'Added phase has to have the same grid spacing!' + assert other.phase.shape == self.dim_uv, \ + 'Added field has to have the same dimensions!' + mask_comb = np.logical_or(self.mask, other.mask) # masks combine + conf_comb = (self.confidence + other.confidence) / 2 # confidence averaged! + return PhaseMap(self.a, self.phase + other.phase, mask_comb, conf_comb) + else: # other is a Number + self._log.debug('Adding an offset') + return PhaseMap(self.a, self.phase + other, self.mask, self.confidence) + + def __sub__(self, other): # self - other + self._log.debug('Calling __sub__') + return self.__add__(-other) + + def __mul__(self, other): # self * other + self._log.debug('Calling __mul__') + assert (isinstance(other, Number) or + (isinstance(other, np.ndarray) and other.shape == self.dim_uv)), \ + 'PhaseMap objects can only be multiplied by scalar numbers or fitting arrays!' + return PhaseMap(self.a, self.phase * other, self.mask, self.confidence) + + def __truediv__(self, other): # self / other + self._log.debug('Calling __truediv__') + assert (isinstance(other, Number) or + (isinstance(other, np.ndarray) and other.shape == self.dim_uv)), \ + 'PhaseMap objects can only be divided by scalar numbers or fitting arrays!' + return PhaseMap(self.a, self.phase / other, self.mask, self.confidence) + + def __floordiv__(self, other): # self // other + self._log.debug('Calling __floordiv__') + assert (isinstance(other, Number) or + (isinstance(other, np.ndarray) and other.shape == self.dim_uv)), \ + 'PhaseMap objects can only be divided by scalar numbers or fitting arrays!' + return PhaseMap(self.a, self.phase // other, self.mask, self.confidence) + + def __radd__(self, other): # other + self + self._log.debug('Calling __radd__') + return self.__add__(other) + + def __rsub__(self, other): # other - self + self._log.debug('Calling __rsub__') + return -self.__sub__(other) + + def __rmul__(self, other): # other * self + self._log.debug('Calling __rmul__') + return self.__mul__(other) + + def __iadd__(self, other): # self += other + self._log.debug('Calling __iadd__') + return self.__add__(other) + + def __isub__(self, other): # self -= other + self._log.debug('Calling __isub__') + return self.__sub__(other) + + def __imul__(self, other): # self *= other + self._log.debug('Calling __imul__') + return self.__mul__(other) + + def __itruediv__(self, other): # self /= other + self._log.debug('Calling __itruediv__') + return self.__truediv__(other) + + def __ifloordiv__(self, other): # self //= other + self._log.debug('Calling __ifloordiv__') + return self.__floordiv__(other) + + def __getitem__(self, item): + return PhaseMap(self.a, self.phase[item], self.mask[item], self.confidence[item]) + + def __array__(self, dtype=None): # Used for numpy ufuncs, together with __array_wrap__! + if dtype: + return self.phase.astype(dtype) + else: + return self.phase + + def __array_wrap__(self, array, _=None): # _ catches the context, which is not used. + return PhaseMap(self.a, array, self.mask, self.confidence) + + def copy(self): + """Returns a copy of the :class:`~.PhaseMap` object + + Returns + ------- + phasemap: :class:`~.PhaseMap` + A copy of the :class:`~.PhaseMap`. + + """ + self._log.debug('Calling copy') + return PhaseMap(self.a, self.phase.copy(), self.mask.copy(), + self.confidence.copy()) + + def scale_down(self, n=1): + """Scale down the phase map by averaging over two pixels along each axis. + + Parameters + ---------- + n : int, optional + Number of times the phase map is scaled down. The default is 1. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions and grid spacing accordingly. + Only possible, if each axis length is a power of 2! + + """ + self._log.debug('Calling scale_down') + assert n > 0 and isinstance(n, int), 'n must be a positive integer!' + self.a *= 2 ** n + for t in range(n): + # Pad if necessary: + pv, pu = (0, self.dim_uv[0] % 2), (0, self.dim_uv[1] % 2) + if pv != 0 or pu != 0: + self.pad((pv, pu), mode='edge') + # Create coarser grid for the magnetization: + dim_uv = self.dim_uv + self.phase = self.phase.reshape((dim_uv[0] // 2, 2, + dim_uv[1] // 2, 2)).mean(axis=(3, 1)) + mask = self.mask.reshape(dim_uv[0] // 2, 2, dim_uv[1] // 2, 2) + self.mask = mask[:, 0, :, 0] & mask[:, 1, :, 0] & mask[:, 0, :, 1] & mask[:, 1, :, 1] + self.confidence = self.confidence.reshape(dim_uv[0] // 2, 2, + dim_uv[1] // 2, 2).mean(axis=(3, 1)) + + def scale_up(self, n=1, order=0): + """Scale up the phase map using spline interpolation of the requested order. + + Parameters + ---------- + n : int, optional + Power of 2 with which the grid is scaled. Default is 1, which means every axis is + increased by a factor of ``2**1 = 2``. + order : int, optional + The order of the spline interpolation, which has to be in the range between 0 and 5 + and defaults to 0. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions and grid spacing accordingly. + + """ + self._log.debug('Calling scale_up') + assert n > 0 and isinstance(n, int), 'n must be a positive integer!' + assert 5 > order >= 0 and isinstance(order, int), \ + 'order must be a positive integer between 0 and 5!' + self.a /= 2 ** n + self.phase = ndimage.zoom(self.phase, zoom=2 ** n, order=order) + self.mask = ndimage.zoom(self.mask, zoom=2 ** n, order=0) + self.confidence = ndimage.zoom(self.confidence, zoom=2 ** n, order=order) + + def pad(self, pad_values, mode='constant', masked=False, **kwds): + """Pad the current phase map with zeros for each individual axis. + + Parameters + ---------- + pad_values : tuple of int + Number of zeros which should be padded. Provided as a tuple where each entry + corresponds to an axis. An entry can be one int (same padding for both sides) or again + a tuple which specifies the pad values for both sides of the corresponding axis. + mode: string or function + A string values or a user supplied function. ‘constant’ pads with zeros. ‘edge’ pads + with the edge values of array. See the numpy pad function for an in depth guide. + masked: boolean, optional + Determines if the padded areas should be masked or not. `True` creates a 'buffer + zone' for the magnetization distribution in the reconstruction. Default is `False` + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions accordingly. + The confidence of the padded areas is set to zero! + + """ + self._log.debug('Calling pad') + assert len(pad_values) == 2, 'Pad values for each dimension have to be provided!' + pval = np.zeros(4, dtype=np.int) + for i, values in enumerate(pad_values): + assert np.shape(values) in [(), (2,)], 'Only one or two values per axis can be given!' + pval[2 * i:2 * (i + 1)] = values + self.phase = np.pad(self.phase, ((pval[0], pval[1]), + (pval[2], pval[3])), mode=mode, **kwds) + self.confidence = np.pad(self.confidence, ((pval[0], pval[1]), + (pval[2], pval[3])), mode=mode, **kwds) + if masked: + mask_kwds = {'mode': 'constant', 'constant_values': True} + else: + mask_kwds = {'mode': mode} + self.mask = np.pad(self.mask, ((pval[0], pval[1]), (pval[2], pval[3])), **mask_kwds) + + def crop(self, crop_values): + """Pad the current phase map with zeros for each individual axis. + + Parameters + ---------- + crop_values : tuple of int + Number of zeros which should be cropped. Provided as a tuple where each entry + corresponds to an axis. An entry can be one int (same cropping for both sides) or again + a tuple which specifies the crop values for both sides of the corresponding axis. + + Returns + ------- + None + + Notes + ----- + Acts in place and changes dimensions accordingly. + + """ + self._log.debug('Calling crop') + assert len(crop_values) == 2, 'Crop values for each dimension have to be provided!' + cv = np.zeros(4, dtype=np.int) + for i, values in enumerate(crop_values): + assert np.shape(values) in [(), (2,)], 'Only one or two values per axis can be given!' + cv[2 * i:2 * (i + 1)] = values + cv *= np.resize([1, -1], len(cv)) + cv = np.where(cv == 0, None, cv) + self.phase = self.phase[cv[0]:cv[1], cv[2]:cv[3]] + self.mask = self.mask[cv[0]:cv[1], cv[2]:cv[3]] + self.confidence = self.confidence[cv[0]:cv[1], cv[2]:cv[3]] + + def flip(self, axis='u'): + """Flip/mirror the phase map around the specified axis. + + Parameters + ---------- + axis: {'u', 'v'}, optional + The axis around which the phase map is flipped. + + Returns + ------- + phasemap_flip: :class:`~.PhaseMap` + A flipped copy of the :class:`~.PhaseMap` object. + + """ + self._log.debug('Calling flip') + if axis == 'u': + return PhaseMap(self.a, np.flipud(self.phase), np.flipud(self.mask), + np.flipud(self.confidence)) + if axis == 'v': + return PhaseMap(self.a, np.fliplr(self.phase), np.fliplr(self.mask), + np.fliplr(self.confidence)) + else: + raise ValueError("Wrong input! 'u', 'v' allowed!") + + def rotate(self, angle): + """Rotate the phase map (right hand rotation). + + Parameters + ---------- + angle: float + The angle around which the phase map is rotated. + + Returns + ------- + phasemap_rot: :class:`~.PhaseMap` + A rotated copy of the :class:`~.PhaseMap` object. + + """ + self._log.debug('Calling rotate') + phase_rot = ndimage.rotate(self.phase, angle, reshape=False) + mask_rot = ndimage.rotate(self.mask, angle, reshape=False, order=0) + conf_rot = ndimage.rotate(self.confidence, angle, reshape=False) + return PhaseMap(self.a, phase_rot, mask_rot, conf_rot) + + def shift(self, shift): + """Shift the phase map (subpixel accuracy). + + Parameters + ---------- + shift : float or sequence, optional + The shift along the axes. If a float, shift is the same for each axis. + If a sequence, shift should contain one value for each axis. + + Returns + ------- + phasemap_shift: :class:`~.PhaseMap` + A shifted copy of the :class:`~.PhaseMap` object. + + """ + self._log.debug('Calling shift') + phase_rot = ndimage.shift(self.phase, shift, mode='constant', cval=0) + mask_rot = ndimage.shift(self.mask, shift, mode='constant', cval=False, order=0) + conf_rot = ndimage.shift(self.confidence, shift, mode='constant', cval=0) + return PhaseMap(self.a, phase_rot, mask_rot, conf_rot) + + @classmethod + def from_signal(cls, signal): + """Convert a :class:`~hyperspy.signals.Image` object to a :class:`~.PhaseMap` object. + + Parameters + ---------- + signal: :class:`~hyperspy.signals.Image` + The :class:`~hyperspy.signals.Image` object which should be converted to a PhaseMap. + + Returns + ------- + phasemap: :class:`~.PhaseMap` + A :class:`~.PhaseMap` object containing the loaded data. + + Notes + ----- + This method recquires the hyperspy package! + + """ + cls._log.debug('Calling from_signal') + # Extract phase: + phase = signal.data + # Extract properties: + a = signal.axes_manager.signal_axes[0].scale + try: + mask = signal.metadata.Signal.mask + confidence = signal.metadata.Signal.confidence + except AttributeError: + mask = None + confidence = None + return cls(a, phase, mask, confidence) + + def to_signal(self): + """Convert :class:`~.PhaseMap` data into a HyperSpy Signal. + + Returns + ------- + signal: :class:`~hyperspy.signals.Signal2D` + Representation of the :class:`~.PhaseMap` object as a HyperSpy Signal. + + Notes + ----- + This method recquires the hyperspy package! + + """ + self._log.debug('Calling to_signal') + try: # Try importing HyperSpy: + # noinspection PyUnresolvedReferences + import hyperspy.api as hs + except ImportError: + self._log.error('This method recquires the hyperspy package!') + return + # Create signal: + signal = hs.signals.Signal2D(self.phase) + # Set axes: + signal.axes_manager.signal_axes[0].name = 'x-axis' + signal.axes_manager.signal_axes[0].units = 'nm' + signal.axes_manager.signal_axes[0].scale = self.a + signal.axes_manager.signal_axes[1].name = 'y-axis' + signal.axes_manager.signal_axes[1].units = 'nm' + signal.axes_manager.signal_axes[1].scale = self.a + # Set metadata: + signal.metadata.Signal.title = 'PhaseMap' + signal.metadata.Signal.unit = 'rad' + signal.metadata.Signal.mask = self.mask + signal.metadata.Signal.confidence = self.confidence + # Create and return signal: + return signal + + def save(self, filename, save_mask=False, save_conf=False, pyramid_format=True, **kwargs): + """Saves the phasemap in the specified format. + + The function gets the format from the extension: + - hdf5 for HDF5. + - rpl for Ripple (useful to export to Digital Micrograph). + - unf for SEMPER unf binary format. + - txt format. + - Many image formats such as png, tiff, jpeg... + + If no extension is provided, 'hdf5' is used. Most formats are + saved with the HyperSpy package (internally the phasemap is first + converted to a HyperSpy Signal. + + Each format accepts a different set of parameters. For details + see the specific format documentation. + + Parameters + ---------- + filename: str, optional + Name of the file which the phasemap is saved into. The extension + determines the saving procedure. + save_mask: boolean, optional + If True, the `mask` is saved, too. For all formats, except HDF5, a separate file will + be created. HDF5 always saves the `mask` in the metadata, independent of this flag. The + default is False. + save_conf: boolean, optional + If True, the `confidence` is saved, too. For all formats, except HDF5, a separate file + will be created. HDF5 always saves the `confidence` in the metadata, independent of + this flag. The default is False + pyramid_format: boolean, optional + Only used for saving to '.txt' files. If this is True, the grid spacing is saved + in an appropriate header. Otherwise just the phase is written with the + corresponding `kwargs`. + + """ + from .file_io.io_phasemap import save_phasemap + save_phasemap(self, filename, save_mask, save_conf, pyramid_format, **kwargs) + + def plot_phase(self, unit='auto', vmin=None, vmax=None, sigma_clip=None, symmetric=True, + show_mask=True, show_conf=True, norm=None, cbar=True, # specific to plot_phase! + cmap=None, interpolation='none', axis=None, figsize=None, **kwargs): + """Display the phasemap as a colormesh. + + Parameters + ---------- + unit: {'rad', 'mrad', 'µrad', '1/rad', '1/mrad', '1/µrad'}, optional + The plotting unit of the phase map. The phase is scaled accordingly before plotting. + Inverse radians should be used for gain maps! + vmin : float, optional + Minimum value used for determining the plot limits. If not set, it will be + determined by the minimum of the phase directly. + vmax : float, optional + Maximum value used for determining the plot limits. If not set, it will be + determined by the minimum of the phase directly. + sigma_clip : int, optional + If this is not `None`, the values outside `sigma_clip` times the standard deviation + will be clipped for the calculation of the plotting `limit`. + symmetric : boolean, optional + If True (default), a zero symmetric colormap is assumed and a zero value (which + will always be present) will be set to the central color color of the colormap. + show_mask : bool, optional + A switch determining if the mask should be plotted or not. Default is True. + show_conf : float, optional + A switch determining if the confidence should be plotted or not. Default is True. + norm : :class:`~matplotlib.colors.Normalize` or subclass, optional + Norm, which is used to determine the colors to encode the phase information. + cbar : bool, optional + If True (default), a colorbar will be plotted. + cmap : string, optional + The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. + interpolation : {'none, 'bilinear', 'cubic', 'nearest'}, optional + Defines the interpolation method for the holographic contour map. + No interpolation is used in the default case. + axis : :class:`~matplotlib.axes.AxesSubplot`, optional + Axis on which the graph is plotted. Creates a new figure if none is specified. + figsize : tuple of floats (N=2) + Size of the plot figure. + + Returns + ------- + axis, cbar: :class:`~matplotlib.axes.AxesSubplot` + The axis on which the graph is plotted. + + Notes + ----- + Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. + + """ + self._log.debug('Calling plot_phase') + a = self.a + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT + # Take units into consideration: + if unit == 'auto': # Try to automatically determine unit (recommended): + for key, value in self.UNITDICT.items(): + if not key.startswith('1/'): + order = np.floor(np.log10(np.abs(self.phase).max() * value)) + if -1 <= order < 2: + unit = key + if unit == 'auto': # No fitting unit was found: + unit = 'rad' + # Scale phase and make last check if order is okay: + phase = self.phase * self.UNITDICT[unit] + order = np.floor(np.log10(np.abs(phase).max())) + if order > 2 or order < -6: # Display would look bad + unit = '{} x 1E{:g}'.format(unit, order) + phase /= 10 ** order + # Calculate limits if necessary (not necessary if both limits are already set): + if vmin is None and vmax is None: + phase_l = phase + # Clip non-trustworthy regions for the limit calculation: + if show_conf: + phase_trust = np.where(self.confidence > 0.9, phase_l, np.nan) + phase_min, phase_max = np.nanmin(phase_trust), np.nanmax(phase_trust) + phase_l = np.clip(phase_l, phase_min, phase_max) + # Cut outlier beyond a certain sigma-margin: + if sigma_clip is not None: + outlier = np.abs(phase_l - np.mean(phase_l)) < sigma_clip * np.std(phase_l) + phase_sigma = np.where(outlier, phase_l, np.nan) + phase_min, phase_max = np.nanmin(phase_sigma), np.nanmax(phase_sigma) + phase_l = np.clip(phase_l, phase_min, phase_max) + # Calculate the limits if necessary (zero has to be present!): + if vmin is None: + vmin = np.min(phase_l) + if vmax is None: + vmax = np.max(phase_l) + # Configure colormap, to fix white to zero if colormap is symmetric: + if symmetric: + if cmap is None: + cmap = plt.get_cmap('RdBu') + elif isinstance(cmap, str): # Get colormap if given as string: + cmap = plt.get_cmap(cmap) + vmin, vmax = np.min([vmin, -0]), np.max([0, vmax]) # Ensure zero is present! + limit = np.max(np.abs([vmin, vmax])) + start = (vmin + limit) / (2 * limit) + end = (vmax + limit) / (2 * limit) + cmap_colors = cmap(np.linspace(start, end, 256)) + cmap = LinearSegmentedColormap.from_list('Symmetric', cmap_colors) + # If no axis is specified, a new figure is created: + if axis is None: + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1) + tight = True + else: + tight = False + axis.set_aspect('equal') + # Plot the phasemap: + im = axis.imshow(phase, cmap=cmap, vmin=vmin, vmax=vmax, interpolation=interpolation, + norm=norm, origin='lower', extent=(0, self.dim_uv[1], 0, self.dim_uv[0])) + if show_mask or show_conf: + vv, uu = np.indices(self.dim_uv) + 0.5 + if show_conf and not np.all(self.confidence == 1.0): + colormap = colors.cmaps['transparent_confidence'] + axis.imshow(self.confidence, cmap=colormap, interpolation=interpolation, + origin='lower', extent=(0, self.dim_uv[1], 0, self.dim_uv[0])) + if show_mask and not np.all(self.mask): # Plot mask if desired and not trivial! + axis.contour(uu, vv, self.mask, levels=[0.5], colors='k', linestyles='dotted', + linewidths=2) + # Determine colorbar title: + cbar_label = kwargs.pop('cbar_label', None) + cbar_mappable = None + if cbar: + cbar_mappable = im + if cbar_label is None: + if unit.startswith('1/'): + cbar_name = 'gain' + else: + cbar_name = 'phase' + if mpl.rcParams['text.usetex'] and 'µ' in unit: # Make sure µ works in latex: + mpl.rc('text.latex', preamble=r'\usepackage{txfonts},\usepackage{lmodern}') + unit = unit.replace('µ', '$\muup$') # Upright µ! + cbar_label = u'{} [{}]'.format(cbar_name, unit) + # Return formatted axis: + return plottools.format_axis(axis, sampling=a, cbar_mappable=cbar_mappable, + cbar_label=cbar_label, tight_layout=tight, **kwargs) + + def plot_holo(self, gain='auto', # specific to plot_holo! + cmap=None, interpolation='none', axis=None, figsize=None, **kwargs): + """Display the color coded holography image. + + Parameters + ---------- + gain : float or 'auto', optional + The gain factor for determining the number of contour lines. The default is 'auto', + which means that the gain will be determined automatically to look pretty. + cmap : string, optional + The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. + interpolation : {'none, 'bilinear', 'cubic', 'nearest'}, optional + Defines the interpolation method for the holographic contour map. + No interpolation is used in the default case. + axis : :class:`~matplotlib.axes.AxesSubplot`, optional + Axis on which the graph is plotted. Creates a new figure if none is specified. + figsize : tuple of floats (N=2) + Size of the plot figure. + + Returns + ------- + axis: :class:`~matplotlib.axes.AxesSubplot` + The axis on which the graph is plotted. + + Notes + ----- + Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. + + """ + self._log.debug('Calling plot_holo') + a = self.a + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT + # Calculate gain if 'auto' is selected: + if gain == 'auto': + gain = 4 * 2 * np.pi / (np.abs(self.phase).max() + 1E-30) + gain = round(gain, -int(np.floor(np.log10(abs(gain))))) + # Calculate the holography image intensity: + holo = np.cos(gain * self.phase) + holo += 1 # Shift to positive values + holo /= 2 # Rescale to [0, 1] + # Calculate the phase gradients: + # B = rot(A) --> B_x = grad_y(A_z), B_y = -grad_x(A_z); phi_m ~ -int(A_z) + # sign switch --> B_x = -grad_y(phi_m), B_y = grad_x(phi_m) + grad_x, grad_y = np.gradient(self.phase, self.a, self.a) + # Clip outliers: + sigma_clip = 2 + outlier_x = np.abs(grad_x - np.mean(grad_x)) < sigma_clip * np.std(grad_x) + grad_x_sigma = np.where(outlier_x, grad_x, np.nan) + grad_x_min, grad_x_max = np.nanmin(grad_x_sigma), np.nanmax(grad_x_sigma) + grad_x = np.clip(grad_x, grad_x_min, grad_x_max) + outlier_y = np.abs(grad_y - np.mean(grad_y)) < sigma_clip * np.std(grad_y) + grad_y_sigma = np.where(outlier_y, grad_y, np.nan) + grad_y_min, grad_y_max = np.nanmin(grad_y_sigma), np.nanmax(grad_y_sigma) + grad_y = np.clip(grad_y, grad_y_min, grad_y_max) + # Calculate colors: + if cmap is None: + cmap = colors.CMAP_CIRCULAR_DEFAULT + vector = np.asarray((grad_x, -grad_y, np.zeros_like(grad_x))) + rgb = cmap.rgb_from_vector(vector) + rgb = (holo.T * rgb.T).T.astype(np.uint8) + holo_image = Image.fromarray(rgb) + # If no axis is specified, a new figure is created: + if axis is None: + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1) + tight = True + else: + tight = False + axis.set_aspect('equal') + # Plot the image and set axes: + axis.imshow(holo_image, origin='lower', interpolation=interpolation, + extent=(0, self.dim_uv[1], 0, self.dim_uv[0])) + note = kwargs.pop('note', None) + if note is None: + note = 'gain: {:g}'.format(gain) + stroke = kwargs.pop('stroke', 'k') # Default for holo is white with black outline! + return plottools.format_axis(axis, sampling=a, note=note, tight_layout=tight, + stroke=stroke, **kwargs) + + def plot_combined(self, title='', phase_title='', holo_title='', figsize=None, **kwargs): + """Display the phase map and the resulting color coded holography image in one plot. + + Parameters + ---------- + title : string, optional + The super title of the plot. The default is 'Combined Plot'. + phase_title : string, optional + The title of the phase map. + holo_title : string, optional + The title of the holographic contour map + figsize : tuple of floats (N=2) + Size of the plot figure. + + Returns + ------- + phase_axis, holo_axis: :class:`~matplotlib.axes.AxesSubplot` + The axes on which the graphs are plotted. + + Notes + ----- + Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. + + """ + self._log.debug('Calling plot_combined') + # Create combined plot and set title: + if figsize is None: + figsize = (plottools.FIGSIZE_DEFAULT[0]*2 + 1, plottools.FIGSIZE_DEFAULT[1]) + fig = plt.figure(figsize=figsize) + fig.suptitle(title, fontsize=20) + # Only phase is annotated, holo will show gain: + note = kwargs.pop('note', None) + # Plot holography image: + holo_axis = fig.add_subplot(1, 2, 1, aspect='equal') + self.plot_holo(axis=holo_axis, title=holo_title, note=None, **kwargs) + # Plot phase map: + phase_axis = fig.add_subplot(1, 2, 2, aspect='equal') + self.plot_phase(axis=phase_axis, title=phase_title, note=note, **kwargs) + # Tighten layout if axis was created here: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + plt.tight_layout() + # Return the plotting axes: + return phase_axis, holo_axis + + def plot_phase_with_hist(self, bins='auto', unit='rad', + title='', phase_title='', hist_title='', figsize=None, **kwargs): + """Display the phase map and a histogram of the phase values of all pixels. + + Parameters + ---------- + bins : int or sequence of scalars or str, optional + Bin argument that goes to the matplotlib.hist function (more documentation there). + The default is 'auto', which tries to pick something nice. + unit: {'rad', 'mrad', 'µrad', '1/rad', '1/mrad', '1/µrad'}, optional + The plotting unit of the phase map. The phase is scaled accordingly before plotting. + Inverse radians should be used for gain maps! + title : string, optional + The super title of the plot. The default is 'Combined Plot'. + phase_title : string, optional + The title of the phase map. + hist_title : string, optional + The title of the histogram. + figsize : tuple of floats (N=2) + Size of the plot figure. + + Returns + ------- + phase_axis, holo_axis: :class:`~matplotlib.axes.AxesSubplot` + The axes on which the graphs are plotted. + + Notes + ----- + Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. + + """ + self._log.debug('Calling plot_phase_with_hist') + # Create combined plot and set title: + if figsize is None: + figsize = (plottools.FIGSIZE_DEFAULT[0]*2 + 1, plottools.FIGSIZE_DEFAULT[1]) + fig = plt.figure(figsize=figsize) + fig.suptitle(title, fontsize=20) + # Plot histogram: + hist_axis = fig.add_subplot(1, 2, 1) + vec = self.phase_vec * self.UNITDICT[unit] # Take units into consideration: + vec *= np.where(self.confidence > 0.5, 1, 0).ravel() # Discard low confidence points! + hist_axis.hist(vec, bins=bins, histtype='stepfilled', color='g') + # Format histogram: + x0, x1 = hist_axis.get_xlim() + y0, y1 = hist_axis.get_ylim() + hist_axis.set(aspect=np.abs(x1 - x0) / np.abs(y1 - y0) * 0.94) # Last value because cbar! + fontsize = kwargs.get('fontsize', 16) + hist_axis.tick_params(axis='both', which='major', labelsize=fontsize) + hist_axis.set_title(hist_title, fontsize=fontsize) + hist_axis.set_xlabel('phase [{}]'.format(unit), fontsize=fontsize) + hist_axis.set_ylabel('count', fontsize=fontsize) + # Plot phase map: + phase_axis = fig.add_subplot(1, 2, 2, aspect=1) + self.plot_phase(unit=unit, axis=phase_axis, title=phase_title, **kwargs) + # Tighten layout: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + plt.tight_layout() + # Return the plotting axes: + return phase_axis, hist_axis + + def plot_phase3d(self, title='Phase Map', unit='rad', cmap='RdBu'): + """Display the phasemap as a 3D surface with contourplots. + + Parameters + ---------- + title : string, optional + The title of the plot. The default is 'Phase Map'. + unit: {'rad', 'mrad', 'µrad'}, optional + The plotting unit of the phase map. The phase is scaled accordingly before plotting. + cmap : string, optional + The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. + The default is 'RdBu'. + + Returns + ------- + axis: :class:`~matplotlib.axes.AxesSubplot` + The axis on which the graph is plotted. + + """ + self._log.debug('Calling plot_phase3d') + # Take units into consideration: + phase = self.phase * self.UNITDICT[unit] + # Create figure and axis: + fig = plt.figure() + axis = Axes3D(fig) + # Plot surface and contours: + vv, uu = np.indices(self.dim_uv) + axis.plot_surface(uu, vv, phase, rstride=4, cstride=4, alpha=0.7, cmap=cmap, + linewidth=0, antialiased=False) + axis.contourf(uu, vv, phase, 15, zdir='z', offset=np.min(phase), cmap=cmap) + axis.set_title(title) + axis.view_init(45, -135) + axis.set_xlabel('u-axis [px]') + axis.set_ylabel('v-axis [px]') + axis.set_zlabel('phase shift [{}]'.format(unit)) + if self.dim_uv[0] >= self.dim_uv[1]: + u_bin, v_bin = np.max((2, np.floor(9 * self.dim_uv[1] / self.dim_uv[0]))), 9 + else: + u_bin, v_bin = 9, np.max((2, np.floor(9 * self.dim_uv[0] / self.dim_uv[1]))) + axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) + axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) + # Return plotting axis: + return axis diff --git a/pyramid/plottools.py b/pyramid/plottools.py index 99b23bc4570c53287a19b803d66556e1521ca4ca..e656283df493f867f7a9d70b431ab551a4a49949 100644 --- a/pyramid/plottools.py +++ b/pyramid/plottools.py @@ -1,317 +1,322 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# Adapted from mpl_toolkits.axes_grid2 -"""This module provides the useful plotting utilities.""" - -import numpy as np - -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib.offsetbox import AnchoredOffsetbox, AuxTransformBox, VPacker, TextArea -from matplotlib.transforms import blended_transform_factory, IdentityTransform -from matplotlib.patches import Rectangle -from matplotlib import patheffects -from matplotlib.ticker import MaxNLocator, FuncFormatter - -from mpl_toolkits.axes_grid1 import make_axes_locatable - -import warnings - -from . import colors - -__all__ = ['format_axis', 'pretty_plots', 'add_scalebar', - 'add_annotation', 'add_colorwheel', 'add_cbar'] - -FIGSIZE_DEFAULT = (6.7, 5) -FONTSIZE_DEFAULT = 20 -STROKE_DEFAULT = None - - -def pretty_plots(figsize=None, fontsize=None, stroke=None): - """Set IPython formats (for interactive and PDF output) and set pretty matplotlib font.""" - from IPython.display import set_matplotlib_formats - set_matplotlib_formats('png', 'pdf') # png for interactive, pdf, for PDF output! - mpl.rcParams['mathtext.fontset'] = 'stix' # Mathtext in $...$! - mpl.rcParams['font.family'] = 'STIXGeneral' # Set normal text to look the same! - mpl.rcParams['figure.max_open_warning'] = 0 # Disable Max Open Figure warning! - if figsize is not None: - global FIGSIZE_DEFAULT - FIGSIZE_DEFAULT = figsize - if fontsize is not None: - global FONTSIZE_DEFAULT - FONTSIZE_DEFAULT = fontsize - global STROKE_DEFAULT - STROKE_DEFAULT = stroke - - -def add_scalebar(axis, sampling=1, fontsize=None, stroke=None): - """Add a scalebar to the axis. - - Parameters - ---------- - axis : :class:`~matplotlib.axes.AxesSubplot` - Axis to which the scalebar is added. - sampling : float, optional - The grid spacing in nm. If not given, 1 nm is assumed. - fontsize : int, optional - The fontsize which should be used for the label. Default is 16. - stroke : None or color, optional - If not None, a stroke will be applied to the text, e.g. to make it more visible. - - Returns - ------- - aoffbox : :class:`~matplotlib.offsetbox.AnchoredOffsetbox` - The box containing the scalebar. - - """ - if fontsize is None: - fontsize = FONTSIZE_DEFAULT - if stroke is None: - stroke = STROKE_DEFAULT - # Transform that scales the width along the data and leaves height constant at 8 pt (text): - transform = blended_transform_factory(axis.transData, IdentityTransform()) - # Transform axis borders (1, 1) to data borders to get number of pixels in y and x: - bb0 = axis.transLimits.inverted().transform((0, 0)) - bb1 = axis.transLimits.inverted().transform((1, 1)) - dim_uv = (int(abs(bb1[1] - bb0[1])), int(abs(bb1[0] - bb0[0]))) - # Calculate scale - scale = np.max((dim_uv[1] / 4, 1)) * sampling - thresholds = [1, 5, 10, 50, 100, 500, 1000] - for t in thresholds: # For larger grids (real images), multiples of threshold look better! - if scale > t: - scale = (scale // t) * t - # Set dimensions: - width = scale / sampling # In data coordinate system! - height = 8 # In display coordinate system! - # Set label: - if scale >= 1000: # Use higher order instead! - label = '{:.3g} µm'.format(scale/1000) - else: - label = '{:.3g} nm'.format(scale) - # Create scalebar rectangle: - bars = AuxTransformBox(transform) - bars.add_artist(Rectangle((0, 0), width, height, fc='w', linewidth=1, - clip_box=axis.bbox, clip_on=True)) - # Create text: - txtcolor = 'w' if stroke == 'k' else 'k' - txt = TextArea(label, textprops={'color': txtcolor, 'fontsize': fontsize}) - txt.set_clip_box(axis.bbox) - if stroke is not None: - txt._text.set_path_effects([patheffects.withStroke(linewidth=2, foreground=stroke)]) - # Pack both together an create AnchoredOffsetBox: - bars = VPacker(children=[txt, bars], align="center", pad=0.5, sep=5) - aoffbox = AnchoredOffsetbox(loc=3, pad=0.5, borderpad=0.1, child=bars, frameon=False) - axis.add_artist(aoffbox) - # Return: - return aoffbox - - -def add_annotation(axis, label, stroke=None, fontsize=None): - """Add an annotation to the axis on the upper left corner. - - Parameters - ---------- - axis : :class:`~matplotlib.axes.AxesSubplot` - Axis to which the annotation is added. - label : string - The text of the annotation. - fontsize : int, optional - The fontsize which should be used for the annotation. Default is 16. - stroke : None or color, optional - If not None, a stroke will be applied to the text, e.g. to make it more visible. - - Returns - ------- - aoffbox : :class:`~matplotlib.offsetbox.AnchoredOffsetbox` - The box containing the annotation. - - """ - if fontsize is None: - fontsize = FONTSIZE_DEFAULT - if stroke is None: - stroke = STROKE_DEFAULT - # Create text: - txtcolor = 'w' if stroke == 'k' else 'k' - txt = TextArea(label, textprops={'color': txtcolor, 'fontsize': fontsize}) - txt.set_clip_box(axis.bbox) - if stroke is not None: - txt._text.set_path_effects([patheffects.withStroke(linewidth=2, foreground=stroke)]) - # Pack into and add AnchoredOffsetBox: - aoffbox = AnchoredOffsetbox(loc=2, pad=0.5, borderpad=0.1, child=txt, frameon=False) - axis.add_artist(aoffbox) - return aoffbox - - -def add_colorwheel(axis): - """Add a colorwheel to the axis on the upper right corner. - - Parameters - ---------- - axis : :class:`~matplotlib.axes.AxesSubplot` - Axis to which the colorwheel is added. - - Returns - ------- - axis : :class:`~matplotlib.axes.AxesSubplot` - The same axis which was given to this function is returned. - - """ - from mpl_toolkits.axes_grid.inset_locator import inset_axes - inset_axes = inset_axes(axis, width=0.75, height=0.75, loc=1) - inset_axes.axis('off') - cmap = colors.CMAP_CIRCULAR_DEFAULT - return cmap.plot_colorwheel(size=100, axis=inset_axes, alpha=0, arrows=True) - - -def add_cbar(axis, mappable, label='', fontsize=None): - """Add a colorbar to the right of the given axis. - - Parameters - ---------- - axis : :class:`~matplotlib.axes.AxesSubplot` - Axis to which the colorbar is added. - mappable : mappable pyplot - If this is not None, a colorbar will be plotted on the right side of the axes, - label : string, optional - The label of the colorbar. If not set, no label is used. - fontsize : int, optional - The fontsize which should be used for the label. Default is 16. - - Returns - ------- - cbar : :class:`~matplotlib.Colorbar` - The created colorbar. - - """ - if fontsize is None: - fontsize = FONTSIZE_DEFAULT - divider = make_axes_locatable(axis) - cbar_ax = divider.append_axes('right', size='5%', pad=0.1) - cbar = plt.colorbar(mappable, cax=cbar_ax) - cbar.ax.tick_params(labelsize=fontsize) - # Make sure labels don't stick out of tight bbox: - labels = cbar.ax.get_yticklabels() - delta = 0.03 * (cbar.vmax - cbar.vmin) - lmin = float(labels[0]._text.replace(u'\u2212', '-').strip('$')) # No unicode or latex! - lmax = float(labels[-1]._text.replace(u'\u2212', '-').strip('$')) # No unicode or latex! - redo_max = True if cbar.vmax - lmax < delta else False - redo_min = True if lmin - cbar.vmin < delta else False - mappable.set_clim(cbar.vmin - delta * redo_min, cbar.vmax + delta * redo_max) - # A lot of plotting magic to make plot width consistent (mostly): - cbar.ax.set_yticklabels(labels, ha='right') - renderer = plt.gcf().canvas.get_renderer() - bbox_0 = labels[0].get_window_extent(renderer) - bbox_1 = labels[-1].get_window_extent(renderer) - cbar_pad = np.max((bbox_0.width, bbox_1.width)) + 5 # bit of padding left! - cbar.ax.yaxis.set_tick_params(pad=cbar_pad) - max_txt = plt.text(0, 0, u'\u22120.00', fontsize=fontsize) - bbox_max = max_txt.get_window_extent(renderer) - max_txt.remove() - cbar_pad_max = bbox_max.width + 10 # bit of padding right! - cbar.set_label(label, fontsize=fontsize, labelpad=max(cbar_pad_max - cbar_pad, 0)) - # Set focus back to axis and return cbar: - plt.sca(axis) - return cbar - - -def format_axis(axis, format_axis=True, title='', fontsize=None, stroke=None, scalebar=True, - hideaxes=None, sampling=1, note=None, colorwheel=False, cbar_mappable=None, - cbar_label='', tight_layout=True, **_): - """Format an axis and add a lot of nice features. - - Parameters - ---------- - axis : :class:`~matplotlib.axes.AxesSubplot` - Axis on which the graph is plotted. - format_axis : bool, optional - If False, the formatting will be skipped (the axis is still returned). Default is True. - title : string, optional - The title of the plot. The default is an empty string''. - fontsize : int, optional - The fontsize which should be used for labels and titles. Default is 16. - stroke : None or color, optional - If not None, a stroke will be applied to the text, e.g. to make it more visible. - scalebar : bool, optional - Defines if a scalebar should be plotted in the lower left corner (default: True). Axes - are made invisible. If set to False, the axes are formatted to ook pretty, instead. - hideaxes : True, optional - If True, the axes will be turned invisible. If not specified (None), this is True if a - scalebar is plotted, False otherwise. - sampling : float, optional - The grid spacing in nm. If not given, 1 nm is assumed. - note: string or None, optional - An annotation string which is displayed in the upper left - colorwheel : bool, optional - Defines if a colorwheel should be plotted in the upper right corner (default: False). - cbar_mappable : mappable pyplot or None, optional - If this is not None, a colorbar will be plotted on the right side of the axes, - which uses this mappable object. - cbar_label : string, optional - The label of the colorbar. If `None`, no label is used. - tight_layout : bool, optional - If True, `plt.tight_layout()` is executed after the formatting. Default is True. - - Returns - ------- - axis : :class:`~matplotlib.axes.AxesSubplot` - The same axis which was given to this function is returned. - - """ - if not format_axis: # Skip (sometimes useful if more than one plot is used on the same axis! - return axis - if fontsize is None: - fontsize = FONTSIZE_DEFAULT - if stroke is None: - stroke = STROKE_DEFAULT - # Get dimensions: - bb0 = axis.transLimits.inverted().transform((0, 0)) - bb1 = axis.transLimits.inverted().transform((1, 1)) - dim_uv = (int(abs(bb1[1] - bb0[1])), int(abs(bb1[0] - bb0[0]))) - # Set the title and the axes labels: - axis.set_xlim(0, dim_uv[1]) - axis.set_ylim(0, dim_uv[0]) - axis.set_title(title, fontsize=fontsize) - # Add scalebar: - if scalebar: - add_scalebar(axis, sampling=sampling, fontsize=fontsize, stroke=stroke) - if hideaxes is None: - hideaxes = True - # Determine major tick locations (useful for grid, even if ticks will not be used): - if dim_uv[0] >= dim_uv[1]: - u_bin, v_bin = np.max((2, np.floor(9 * dim_uv[1] / dim_uv[0]))), 9 - else: - u_bin, v_bin = 9, np.max((2, np.floor(9 * dim_uv[0] / dim_uv[1]))) - axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) - # Hide axes label and ticks if wanted: - if hideaxes: - for tic in axis.xaxis.get_major_ticks(): - tic.tick1On = tic.tick2On = False - tic.label1On = tic.label2On = False - for tic in axis.yaxis.get_major_ticks(): - tic.tick1On = tic.tick2On = False - tic.label1On = tic.label2On = False - else: # Set the axes ticks and labels: - axis.set_xlabel('u-axis [nm]', fontsize=fontsize) - axis.set_ylabel('v-axis [nm]', fontsize=fontsize) - axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * sampling))) - axis.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * sampling))) - axis.tick_params(axis='both', which='major', labelsize=fontsize) - # Add annotation: - if note: - add_annotation(axis, label=note, fontsize=fontsize, stroke=stroke) - # Add colorhweel: - if colorwheel: - add_colorwheel(axis) - # Add colorbar: - if cbar_mappable: - # Construct colorbar: - add_cbar(axis, mappable=cbar_mappable, label=cbar_label, fontsize=fontsize) - # Tighten layout if axis was created here: - if tight_layout: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - plt.tight_layout() - # Return plotting axis: - return axis +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# Adapted from mpl_toolkits.axes_grid2 +"""This module provides the useful plotting utilities.""" + +import numpy as np + +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.offsetbox import AnchoredOffsetbox, AuxTransformBox, VPacker, TextArea +from matplotlib.transforms import blended_transform_factory, IdentityTransform +from matplotlib.patches import Rectangle +from matplotlib import patheffects +from matplotlib.ticker import MaxNLocator, FuncFormatter + +from mpl_toolkits.axes_grid1 import make_axes_locatable + +import warnings + +from . import colors + +__all__ = ['format_axis', 'pretty_plots', 'add_scalebar', + 'add_annotation', 'add_colorwheel', 'add_cbar'] + +FIGSIZE_DEFAULT = (6.7, 5) +FONTSIZE_DEFAULT = 20 +STROKE_DEFAULT = None + + +def pretty_plots(figsize=None, fontsize=None, stroke=None): + """Set IPython formats (for interactive and PDF output) and set pretty matplotlib font.""" + from IPython.display import set_matplotlib_formats + set_matplotlib_formats('png', 'pdf') # png for interactive, pdf, for PDF output! + mpl.rcParams['mathtext.fontset'] = 'stix' # Mathtext in $...$! + mpl.rcParams['font.family'] = 'STIXGeneral' # Set normal text to look the same! + mpl.rcParams['figure.max_open_warning'] = 0 # Disable Max Open Figure warning! + if figsize is not None: + global FIGSIZE_DEFAULT + FIGSIZE_DEFAULT = figsize + mpl.rcParams['figure.figsize'] = FIGSIZE_DEFAULT + if fontsize is not None: + global FONTSIZE_DEFAULT + FONTSIZE_DEFAULT = fontsize + mpl.rcParams['xtick.labelsize'] = FONTSIZE_DEFAULT + mpl.rcParams['ytick.labelsize'] = FONTSIZE_DEFAULT + mpl.rcParams['axes.labelsize'] = FONTSIZE_DEFAULT + mpl.rcParams['legend.fontsize'] = FONTSIZE_DEFAULT + global STROKE_DEFAULT + STROKE_DEFAULT = stroke + + +def add_scalebar(axis, sampling=1, fontsize=None, stroke=None): + """Add a scalebar to the axis. + + Parameters + ---------- + axis : :class:`~matplotlib.axes.AxesSubplot` + Axis to which the scalebar is added. + sampling : float, optional + The grid spacing in nm. If not given, 1 nm is assumed. + fontsize : int, optional + The fontsize which should be used for the label. Default is 16. + stroke : None or color, optional + If not None, a stroke will be applied to the text, e.g. to make it more visible. + + Returns + ------- + aoffbox : :class:`~matplotlib.offsetbox.AnchoredOffsetbox` + The box containing the scalebar. + + """ + if fontsize is None: + fontsize = FONTSIZE_DEFAULT + if stroke is None: + stroke = STROKE_DEFAULT + # Transform that scales the width along the data and leaves height constant at 8 pt (text): + transform = blended_transform_factory(axis.transData, IdentityTransform()) + # Transform axis borders (1, 1) to data borders to get number of pixels in y and x: + bb0 = axis.transLimits.inverted().transform((0, 0)) + bb1 = axis.transLimits.inverted().transform((1, 1)) + dim_uv = (int(abs(bb1[1] - bb0[1])), int(abs(bb1[0] - bb0[0]))) + # Calculate scale + scale = np.max((dim_uv[1] / 4, 1)) * sampling + thresholds = [1, 5, 10, 50, 100, 500, 1000] + for t in thresholds: # For larger grids (real images), multiples of threshold look better! + if scale > t: + scale = (scale // t) * t + # Set dimensions: + width = scale / sampling # In data coordinate system! + height = 8 # In display coordinate system! + # Set label: + if scale >= 1000: # Use higher order instead! + label = '{:.3g} µm'.format(scale/1000) + else: + label = '{:.3g} nm'.format(scale) + # Create scalebar rectangle: + bars = AuxTransformBox(transform) + bars.add_artist(Rectangle((0, 0), width, height, fc='w', linewidth=1, + clip_box=axis.bbox, clip_on=True)) + # Create text: + txtcolor = 'w' if stroke == 'k' else 'k' + txt = TextArea(label, textprops={'color': txtcolor, 'fontsize': fontsize}) + txt.set_clip_box(axis.bbox) + if stroke is not None: + txt._text.set_path_effects([patheffects.withStroke(linewidth=2, foreground=stroke)]) + # Pack both together an create AnchoredOffsetBox: + bars = VPacker(children=[txt, bars], align="center", pad=0.5, sep=5) + aoffbox = AnchoredOffsetbox(loc=3, pad=0.5, borderpad=0.1, child=bars, frameon=False) + axis.add_artist(aoffbox) + # Return: + return aoffbox + + +def add_annotation(axis, label, stroke=None, fontsize=None): + """Add an annotation to the axis on the upper left corner. + + Parameters + ---------- + axis : :class:`~matplotlib.axes.AxesSubplot` + Axis to which the annotation is added. + label : string + The text of the annotation. + fontsize : int, optional + The fontsize which should be used for the annotation. Default is 16. + stroke : None or color, optional + If not None, a stroke will be applied to the text, e.g. to make it more visible. + + Returns + ------- + aoffbox : :class:`~matplotlib.offsetbox.AnchoredOffsetbox` + The box containing the annotation. + + """ + if fontsize is None: + fontsize = FONTSIZE_DEFAULT + if stroke is None: + stroke = STROKE_DEFAULT + # Create text: + txtcolor = 'w' if stroke == 'k' else 'k' + txt = TextArea(label, textprops={'color': txtcolor, 'fontsize': fontsize}) + txt.set_clip_box(axis.bbox) + if stroke is not None: + txt._text.set_path_effects([patheffects.withStroke(linewidth=2, foreground=stroke)]) + # Pack into and add AnchoredOffsetBox: + aoffbox = AnchoredOffsetbox(loc=2, pad=0.5, borderpad=0.1, child=txt, frameon=False) + axis.add_artist(aoffbox) + return aoffbox + + +def add_colorwheel(axis): + """Add a colorwheel to the axis on the upper right corner. + + Parameters + ---------- + axis : :class:`~matplotlib.axes.AxesSubplot` + Axis to which the colorwheel is added. + + Returns + ------- + axis : :class:`~matplotlib.axes.AxesSubplot` + The same axis which was given to this function is returned. + + """ + from mpl_toolkits.axes_grid.inset_locator import inset_axes + inset_axes = inset_axes(axis, width=0.75, height=0.75, loc=1) + inset_axes.axis('off') + cmap = colors.CMAP_CIRCULAR_DEFAULT + return cmap.plot_colorwheel(size=100, axis=inset_axes, alpha=0, arrows=True) + + +def add_cbar(axis, mappable, label='', fontsize=None): + """Add a colorbar to the right of the given axis. + + Parameters + ---------- + axis : :class:`~matplotlib.axes.AxesSubplot` + Axis to which the colorbar is added. + mappable : mappable pyplot + If this is not None, a colorbar will be plotted on the right side of the axes, + label : string, optional + The label of the colorbar. If not set, no label is used. + fontsize : int, optional + The fontsize which should be used for the label. Default is 16. + + Returns + ------- + cbar : :class:`~matplotlib.Colorbar` + The created colorbar. + + """ + if fontsize is None: + fontsize = FONTSIZE_DEFAULT + divider = make_axes_locatable(axis) + cbar_ax = divider.append_axes('right', size='5%', pad=0.1) + cbar = plt.colorbar(mappable, cax=cbar_ax) + cbar.ax.tick_params(labelsize=fontsize) + # Make sure labels don't stick out of tight bbox: + labels = cbar.ax.get_yticklabels() + delta = 0.03 * (cbar.vmax - cbar.vmin) + lmin = float(labels[0]._text.replace(u'\u2212', '-').strip('$')) # No unicode or latex! + lmax = float(labels[-1]._text.replace(u'\u2212', '-').strip('$')) # No unicode or latex! + redo_max = True if cbar.vmax - lmax < delta else False + redo_min = True if lmin - cbar.vmin < delta else False + mappable.set_clim(cbar.vmin - delta * redo_min, cbar.vmax + delta * redo_max) + # A lot of plotting magic to make plot width consistent (mostly): + cbar.ax.set_yticklabels(labels, ha='right') + renderer = plt.gcf().canvas.get_renderer() + bbox_0 = labels[0].get_window_extent(renderer) + bbox_1 = labels[-1].get_window_extent(renderer) + cbar_pad = np.max((bbox_0.width, bbox_1.width)) + 5 # bit of padding left! + cbar.ax.yaxis.set_tick_params(pad=cbar_pad) + max_txt = plt.text(0, 0, u'\u22120.00', fontsize=fontsize) + bbox_max = max_txt.get_window_extent(renderer) + max_txt.remove() + cbar_pad_max = bbox_max.width + 10 # bit of padding right! + cbar.set_label(label, fontsize=fontsize, labelpad=max(cbar_pad_max - cbar_pad, 0)) + # Set focus back to axis and return cbar: + plt.sca(axis) + return cbar + + +def format_axis(axis, format_axis=True, title='', fontsize=None, stroke=None, scalebar=True, + hideaxes=None, sampling=1, note=None, colorwheel=False, cbar_mappable=None, + cbar_label='', tight_layout=True, **_): + """Format an axis and add a lot of nice features. + + Parameters + ---------- + axis : :class:`~matplotlib.axes.AxesSubplot` + Axis on which the graph is plotted. + format_axis : bool, optional + If False, the formatting will be skipped (the axis is still returned). Default is True. + title : string, optional + The title of the plot. The default is an empty string''. + fontsize : int, optional + The fontsize which should be used for labels and titles. Default is 16. + stroke : None or color, optional + If not None, a stroke will be applied to the text, e.g. to make it more visible. + scalebar : bool, optional + Defines if a scalebar should be plotted in the lower left corner (default: True). Axes + are made invisible. If set to False, the axes are formatted to ook pretty, instead. + hideaxes : True, optional + If True, the axes will be turned invisible. If not specified (None), this is True if a + scalebar is plotted, False otherwise. + sampling : float, optional + The grid spacing in nm. If not given, 1 nm is assumed. + note: string or None, optional + An annotation string which is displayed in the upper left + colorwheel : bool, optional + Defines if a colorwheel should be plotted in the upper right corner (default: False). + cbar_mappable : mappable pyplot or None, optional + If this is not None, a colorbar will be plotted on the right side of the axes, + which uses this mappable object. + cbar_label : string, optional + The label of the colorbar. If `None`, no label is used. + tight_layout : bool, optional + If True, `plt.tight_layout()` is executed after the formatting. Default is True. + + Returns + ------- + axis : :class:`~matplotlib.axes.AxesSubplot` + The same axis which was given to this function is returned. + + """ + if not format_axis: # Skip (sometimes useful if more than one plot is used on the same axis! + return axis + if fontsize is None: + fontsize = FONTSIZE_DEFAULT + if stroke is None: + stroke = STROKE_DEFAULT + # Get dimensions: + bb0 = axis.transLimits.inverted().transform((0, 0)) + bb1 = axis.transLimits.inverted().transform((1, 1)) + dim_uv = (int(abs(bb1[1] - bb0[1])), int(abs(bb1[0] - bb0[0]))) + # Set the title and the axes labels: + axis.set_xlim(0, dim_uv[1]) + axis.set_ylim(0, dim_uv[0]) + axis.set_title(title, fontsize=fontsize) + # Add scalebar: + if scalebar: + add_scalebar(axis, sampling=sampling, fontsize=fontsize, stroke=stroke) + if hideaxes is None: + hideaxes = True + # Determine major tick locations (useful for grid, even if ticks will not be used): + if dim_uv[0] >= dim_uv[1]: + u_bin, v_bin = np.max((2, np.floor(9 * dim_uv[1] / dim_uv[0]))), 9 + else: + u_bin, v_bin = 9, np.max((2, np.floor(9 * dim_uv[0] / dim_uv[1]))) + axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) + axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) + # Hide axes label and ticks if wanted: + if hideaxes: + for tic in axis.xaxis.get_major_ticks(): + tic.tick1On = tic.tick2On = False + tic.label1On = tic.label2On = False + for tic in axis.yaxis.get_major_ticks(): + tic.tick1On = tic.tick2On = False + tic.label1On = tic.label2On = False + else: # Set the axes ticks and labels: + axis.set_xlabel('u-axis [nm]', fontsize=fontsize) + axis.set_ylabel('v-axis [nm]', fontsize=fontsize) + axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * sampling))) + axis.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * sampling))) + axis.tick_params(axis='both', which='major', labelsize=fontsize) + # Add annotation: + if note: + add_annotation(axis, label=note, fontsize=fontsize, stroke=stroke) + # Add colorhweel: + if colorwheel: + add_colorwheel(axis) + # Add colorbar: + if cbar_mappable: + # Construct colorbar: + add_cbar(axis, mappable=cbar_mappable, label=cbar_label, fontsize=fontsize) + # Tighten layout if axis was created here: + if tight_layout: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + plt.tight_layout() + # Return plotting axis: + return axis diff --git a/pyramid/projector.py b/pyramid/projector.py index 8bbba0158865cecb6d7aa38f98684ba4ed3cd901..eddcebbc4b3cbcee4b25124a707692469b21ace0 100644 --- a/pyramid/projector.py +++ b/pyramid/projector.py @@ -1,709 +1,709 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""This module provides the abstract base class :class:`~.Projector` and concrete subclasses for -projections of vector and scalar fields.""" - -import itertools -import logging - -try: - if type(get_ipython()).__name__ == 'ZMQInteractiveShell': # IPython Notebook! - from tqdm import tqdm_notebook as tqdm - else: # IPython, but not a Notebook (e.g. terminal) - from tqdm import tqdm -except NameError: - from tqdm import tqdm - -import numpy as np -from numpy import pi -from scipy.sparse import coo_matrix, csr_matrix - -from pyramid.fielddata import VectorData, ScalarData -from pyramid.quaternion import Quaternion - -__all__ = ['RotTiltProjector', 'XTiltProjector', 'YTiltProjector', 'SimpleProjector'] - - -class Projector(object): - """Base class representing a projection function. - - The :class:`~.Projector` class represents a projection function for a 3-dimensional - vector- or scalar field onto a 2-dimensional grid. :class:`~.Projector` is an abstract base - class and provides a unified interface which should be subclassed with a custom - :func:`__init__` function, which should call the parent :func:`__init__` method. Concrete - subclasses can be called as a function and take a `vector` as argument which contains the - 3-dimensional field. The output is the projected field, given as a `vector`. Depending on the - length of the input and the given dimensions `dim` at construction time, vector or scalar - projection is choosen intelligently. - - Attributes - ---------- - dim : tuple (N=3) - Dimensions (z, y, x) of the magnetization distribution. - dim_uv : tuple (N=2) - Dimensions (v, u) of the projected grid. - size_3d : int - Number of voxels of the 3-dimensional grid. - size_2d : int - Number of pixels of the 2-dimensional projected grid. - weight : :class:`~scipy.sparse.csr_matrix` (N=2) - The weight matrix containing the weighting coefficients for the 3D to 2D mapping. - coeff : list (N=2) - List containing the six weighting coefficients describing the influence of the 3 components - of a 3-dimensional vector field on the 2 projected components. - m: int - Size of the image space. - n: int - Size of the input space. - sparsity : float - Measures the sparsity of the weighting (not the complete one!), 1 means completely sparse! - - """ - - _log = logging.getLogger(__name__ + '.Projector') - - @property - def sparsity(self): - """The sparsity of the projector weight matrix.""" - return 1. - len(self.weight.data) / np.prod(self.weight.shape) - - def __init__(self, dim, dim_uv, weight, coeff): - self._log.debug('Calling __init__') - self.dim = tuple(dim) - self.dim_uv = tuple(dim_uv) - self.weight = weight - self.coeff = coeff - self.size_2d, self.size_3d = weight.shape - self.n = 3 * np.prod(dim) - self.m = 2 * np.prod(dim_uv) - self._log.debug('Created ' + str(self)) - - def __repr__(self): - self._log.debug('Calling __repr__') - return '%s(dim=%r, dim_uv=%r, weight=%r, coeff=%r)' % \ - (self.__class__, self.dim, self.dim_uv, self.weight, self.coeff) - - def __str__(self): - self._log.debug('Calling __str__') - return 'Projector(dim=%s, dim_uv=%s, coeff=%s)' % (self.dim, self.dim_uv, self.coeff) - - def __call__(self, field_data): - if isinstance(field_data, VectorData): - field_empty = np.zeros((3, 1) + self.dim_uv, dtype=field_data.field.dtype) - field_data_proj = VectorData(field_data.a, field_empty) - field_proj = self.jac_dot(field_data.field_vec).reshape((2,) + self.dim_uv) - field_data_proj.field[0:2, 0, ...] = field_proj - elif isinstance(field_data, ScalarData): - field_empty = np.zeros((1,) + self.dim_uv, dtype=field_data.field.dtype) - field_data_proj = ScalarData(field_data.a, field_empty) - field_proj = self.jac_dot(field_data.field_vec).reshape(self.dim_uv) - field_data_proj.field[0, ...] = field_proj - else: - raise TypeError('Input is neither of type VectorData or ScalarData') - return field_data_proj - - def _vector_field_projection(self, vector): - result = np.zeros(2 * self.size_2d, dtype=vector.dtype) - # Go over all possible component projections (z, y, x) to (u, v): - vec_x, vec_y, vec_z = np.split(vector, 3) - vec_x_weighted = self.weight.dot(vec_x) - vec_y_weighted = self.weight.dot(vec_y) - vec_z_weighted = self.weight.dot(vec_z) - slice_u = slice(0, self.size_2d) - slice_v = slice(self.size_2d, 2 * self.size_2d) - if self.coeff[0][0] != 0: # x to u - result[slice_u] += self.coeff[0][0] * vec_x_weighted - if self.coeff[0][1] != 0: # y to u - result[slice_u] += self.coeff[0][1] * vec_y_weighted - if self.coeff[0][2] != 0: # z to u - result[slice_u] += self.coeff[0][2] * vec_z_weighted - if self.coeff[1][0] != 0: # x to v - result[slice_v] += self.coeff[1][0] * vec_x_weighted - if self.coeff[1][1] != 0: # y to v - result[slice_v] += self.coeff[1][1] * vec_y_weighted - if self.coeff[1][2] != 0: # z to v - result[slice_v] += self.coeff[1][2] * vec_z_weighted - return result - - def _vector_field_projection_T(self, vector): - result = np.zeros(3 * self.size_3d) - # Go over all possible component projections (u, v) to (z, y, x): - vec_u, vec_v = np.split(vector, 2) - vec_u_weighted = self.weight.T.dot(vec_u) - vec_v_weighted = self.weight.T.dot(vec_v) - slice_x = slice(0, self.size_3d) - slice_y = slice(self.size_3d, 2 * self.size_3d) - slice_z = slice(2 * self.size_3d, 3 * self.size_3d) - if self.coeff[0][0] != 0: # u to x - result[slice_x] += self.coeff[0][0] * vec_u_weighted - if self.coeff[0][1] != 0: # u to y - result[slice_y] += self.coeff[0][1] * vec_u_weighted - if self.coeff[0][2] != 0: # u to z - result[slice_z] += self.coeff[0][2] * vec_u_weighted - if self.coeff[1][0] != 0: # v to x - result[slice_x] += self.coeff[1][0] * vec_v_weighted - if self.coeff[1][1] != 0: # v to y - result[slice_y] += self.coeff[1][1] * vec_v_weighted - if self.coeff[1][2] != 0: # v to z - result[slice_z] += self.coeff[1][2] * vec_v_weighted - return result - - def _scalar_field_projection(self, vector): - self._log.debug('Calling _scalar_field_projection') - return np.array(self.weight.dot(vector)) - - def _scalar_field_projection_T(self, vector): - self._log.debug('Calling _scalar_field_projection_T') - return np.array(self.weight.T.dot(vector)) - - def jac_dot(self, vector): - """Multiply a `vector` with the jacobi matrix of this :class:`~.Projector` object. - - Parameters - ---------- - vector : :class:`~numpy.ndarray` (N=1) - Vector containing the field which should be projected. Must have the same or 3 times - the size of `size_3d` of the projector for scalar and vector projection, respectively. - - Returns - ------- - proj_vector : :class:`~numpy.ndarray` (N=1) - Vector containing the projected field of the 2-dimensional grid. The length is - always`size_2d`. - - """ - if len(vector) == 3 * self.size_3d: # mode == 'vector' - return self._vector_field_projection(vector) - elif len(vector) == self.size_3d: # mode == 'scalar' - return self._scalar_field_projection(vector) - else: - raise AssertionError('Vector size has to be suited either for ' - 'vector- or scalar-field-projection!') - - def jac_T_dot(self, vector): - """Multiply a `vector` with the transp. jacobi matrix of this :class:`~.Projector` object. - - Parameters - ---------- - vector : :class:`~numpy.ndarray` (N=1) - Vector containing the field which should be projected. Must have the same or 2 times - the size of `size_2d` of the projector for scalar and vector projection, respectively. - - Returns - ------- - proj_vector : :class:`~numpy.ndarray` (N=1) - Vector containing the multiplication of the input with the transposed jacobi matrix - of the :class:`~.Projector` object. - - """ - if len(vector) == 2 * self.size_2d: # mode == 'vector' - return self._vector_field_projection_T(vector) - elif len(vector) == self.size_2d: # mode == 'scalar' - return self._scalar_field_projection_T(vector) - else: - raise AssertionError('Vector size has to be suited either for ' - 'vector- or scalar-field-projection!') - - def save(self, filename, overwrite=True): - """Saves the projector as an HDF5 file. - - Parameters - ---------- - filename: str - Name of the file which the phasemap is saved into. HDF5 files are supported. - overwrite: bool, optional - If True (default), an existing file will be overwritten, if False, this - (silently!) does nothing. - """ - from .file_io.io_projector import save_projector - save_projector(self, filename, overwrite) - - def get_info(self, verbose): - """Get specific information about the projector as a string. - - Parameters - ---------- - verbose: boolean, optional - If this is true, the text looks prettier (maybe using latex). Default is False for the - use in file names and such. - - Returns - ------- - info : string - Information about the projector as a string, e.g. for the use in plot titles. - - """ - return 'Base projector' - - -class RotTiltProjector(Projector): - """Class representing a projection function with a rotation around z followed by tilt around x. - - The :class:`~.XTiltProjector` class represents a projection function for a 3-dimensional - vector- or scalar field onto a 2-dimensional grid, which is a concrete subclass of - :class:`~.Projector`. - - Attributes - ---------- - dim : tuple (N=3) - Dimensions (z, y, x) of the magnetization distribution. - rotation : float - Angle in `rad` describing the rotation around the z-axis before the tilt is happening. - tilt : float - Angle in `rad` describing the tilt of the beam direction relative to the x-axis. - dim_uv : tuple (N=2), optional - Dimensions (v, u) of the projection. If not set defaults to the (y, x)-dimensions. - subcount : int (optional) - Number of subpixels along one axis. This is used to create the lookup table which uses - a discrete subgrid to estimate the impact point of a voxel onto a pixel and the weight on - all surrounding pixels. Default is 11 (odd numbers provide a symmetric center). - - """ - - _log = logging.getLogger(__name__ + '.RotTiltProjector') - - def __init__(self, dim, rotation, tilt, dim_uv=None, subcount=11, verbose=False): - self._log.debug('Calling __init__') - self.rotation = rotation - self.tilt = tilt - # Determine dimensions: - dim_z, dim_y, dim_x = dim - center = (dim_z / 2., dim_y / 2., dim_x / 2.) - if dim_uv is None: - dim_v = max(dim_x, dim_y) # first rotate around z-axis (take x and y into account) - dim_u = max(dim_v, dim_z) # then tilt around x-axis (now z matters, too) - dim_uv = (dim_v, dim_u) - dim_v, dim_u = dim_uv - # Creating coordinate list of all voxels: - voxels = list(itertools.product(range(dim_z), range(dim_y), range(dim_x))) - # Calculate vectors to voxels relative to rotation center: - voxel_vecs = (np.asarray(voxels) + 0.5 - np.asarray(center)).T - # Create tilt, rotation and combined quaternion, careful: Quaternion(w,x,y,z), not (z,y,x): - quat_x = Quaternion.from_axisangle((1, 0, 0), tilt) # Tilt around x-axis - quat_z = Quaternion.from_axisangle((0, 0, 1), rotation) # Rotate around z-axis - quat = quat_x * quat_z # Combined quaternion (first rotate around z, then tilt around x) - # Calculate impact positions on the projected pixel coordinate grid (flip because quat.): - impacts = np.flipud(quat.matrix[:2, :].dot(np.flipud(voxel_vecs))) # only care for x/y - impacts[1, :] += dim_u / 2. # Shift back to normal indices - impacts[0, :] += dim_v / 2. # Shift back to normal indices - # Calculate equivalence radius: - R = (3 / (4 * np.pi)) ** (1 / 3.) - # Prepare weight matrix calculation: - rows = [] # 2D projection - columns = [] # 3D distribution - data = [] # weights - # Create 4D lookup table (1&2: which neighbour weight?, 3&4: which subpixel is hit?) - weight_lookup = self._create_weight_lookup(subcount, R) - # Go over all voxels: - disable = not verbose - for i, voxel in enumerate(tqdm(voxels, disable=disable, leave=False, - desc='Set up projector')): - column_index = voxel[0] * dim_y * dim_x + voxel[1] * dim_x + voxel[2] - remainder, impact = np.modf(impacts[:, i]) # split index of impact and remainder! - sub_pixel = (remainder * subcount).astype(dtype=np.int) # sub_pixel inside impact px. - # Go over all influenced pixels (impact and neighbours, indices are [0, 1, 2]!): - for px_ind in list(itertools.product(range(3), range(3))): - # Pixel indices influenced by the impact (px_ind-1 to center them around impact): - pixel = (impact + np.array(px_ind) - 1).astype(dtype=np.int) - # Check if pixel is out of bound: - if 0 <= pixel[0] < dim_uv[0] and 0 <= pixel[1] < dim_uv[1]: - # Lookup weight in 4-dimensional lookup table! - weight = weight_lookup[px_ind[0], px_ind[1], sub_pixel[0], sub_pixel[1]] - # Only write into sparse matrix if weight is not zero: - if weight != 0.: - row_index = pixel[0] * dim_u + pixel[1] - columns.append(column_index) - rows.append(row_index) - data.append(weight) - # Calculate weight matrix and coefficients for jacobi matrix: - shape = (np.prod(dim_uv), np.prod(dim)) - weights = csr_matrix(coo_matrix((data, (rows, columns)), shape=shape)) - # Calculate coefficients by rotating unity matrix (unit vectors, (x,y,z)): - coeff = quat.matrix[:2, :].dot(np.eye(3)) - super().__init__(dim, dim_uv, weights, coeff) - self._log.debug('Created ' + str(self)) - - @staticmethod - def _create_weight_lookup(subcount, R): - s = subcount - Rz = R * s # Radius in subgrid units - dim_zoom = (3 * s, 3 * s) # Dimensions of the subgrid, (3, 3) because of neighbour count! - cent_zoom = (np.asarray(dim_zoom) / 2.).astype(dtype=np.int) # Center of the subgrid - y, x = np.indices(dim_zoom) - y -= cent_zoom[0] - x -= cent_zoom[1] - # Calculate projected thickness of an equivalence sphere (normed!): - d = np.where(np.hypot(x, y) <= Rz, Rz ** 2 - x ** 2 - y ** 2, 0) - d = np.sqrt(d) - d /= d.sum() - # Create lookup table (4D): - lookup = np.zeros((3, 3, s, s)) - # Go over all 9 pixels (center and neighbours): - for pixel in list(itertools.product(range(3), range(3))): - pixel_lb = np.array(pixel) * s # Convert to subgrid, hit bottom left of the pixel! - # Go over all subpixels in the center that can be hit: - for sub_pixel in list(itertools.product(range(s), range(s))): - shift = np.array(sub_pixel) - np.array((s // 2, s // 2)) # relative to center! - lb = pixel_lb - shift # Shift summing zone according to hit subpixel! - # Make sure, that the summing zone is in bounds (otherwise correct accordingly): - lb = np.where(lb >= 0, lb, [0, 0]) - tr = np.where(lb < 3 * s, lb + np.array((s, s)), [3 * s, 3 * s]) - # Calculate weight by summing over the summing zone: - weight = d[lb[0]:tr[0], lb[1]:tr[1]].sum() - lookup[pixel[0], pixel[1], sub_pixel[0], sub_pixel[1]] = weight - return lookup - - def get_info(self, verbose=False): - """Get specific information about the projector as a string. - - Parameters - ---------- - verbose: boolean, optional - If this is true, the text looks prettier (maybe using latex). Default is False for the - use in file names and such. - - Returns - ------- - info : string - Information about the projector as a string, e.g. for the use in plot titles. - - """ - theta_ang = int(np.round(self.rotation * 180 / pi)) - phi_ang = int(np.round(self.tilt * 180 / pi)) - if verbose: - return u'$\\theta = {:d}$°, $\phi = {:d}$°'.format(theta_ang, phi_ang) - else: - return u'theta={:d}_phi={:d}°'.format(theta_ang, phi_ang) - - -class XTiltProjector(Projector): - """Class representing a projection function with a tilt around the x-axis. - - The :class:`~.XTiltProjector` class represents a projection function for a 3-dimensional - vector- or scalar field onto a 2-dimensional grid, which is a concrete subclass of - :class:`~.Projector`. - - Attributes - ---------- - dim : tuple (N=3) - Dimensions (z, y, x) of the magnetization distribution. - tilt : float - Angle in `rad` describing the tilt of the beam direction relative to the x-axis. - dim_uv : tuple (N=2), optional - Dimensions (v, u) of the projection. If not set defaults to the (y, x)-dimensions. - - """ - - _log = logging.getLogger(__name__ + '.XTiltProjector') - - def __init__(self, dim, tilt, dim_uv=None, verbose=False): - self._log.debug('Calling __init__') - self.tilt = tilt - # Set starting variables: - # length along projection (proj, z), perpendicular (perp, y) and rotation (rot, x) axis: - dim_proj, dim_perp, dim_rot = dim - if dim_uv is None: - dim_uv = (max(dim_perp, dim_proj), dim_rot) # x-y-plane - dim_v, dim_u = dim_uv # y, x - assert dim_v >= dim_perp and dim_u >= dim_rot, 'Projected dimensions are too small!' - # Creating coordinate list of all voxels (for one slice): - voxels = list(itertools.product(range(dim_proj), range(dim_perp))) # z-y-plane - # Calculate positions along the projected pixel coordinate system: - center = (dim_proj / 2., dim_perp / 2.) - positions = self._get_position(voxels, center, tilt, dim_v) - # Calculate weight-matrix: - r = 1 / np.sqrt(np.pi) # radius of the voxel circle - rho = 0.5 / r - row = [] - col = [] - data = [] - # One slice: - disable = not verbose - for i, voxel in enumerate(tqdm(voxels, disable=disable, leave=False, - desc='Set up projector')): - impacts = self._get_impact(positions[i], r, dim_v) # impact along projected y-axis - voxel_index = voxel[0] * dim_rot * dim_perp + voxel[1] * dim_rot # 0: z, 1: y - for impact in impacts: - impact_index = impact * dim_u + (dim_u - dim_rot) // 2 - distance = np.abs(impact + 0.5 - positions[i]) - delta = distance / r - col.append(voxel_index) - row.append(impact_index) - data.append(self._get_weight(delta, rho)) - # All other slices (along x): - data = np.tile(data, dim_rot) - columns = np.tile(col, dim_rot) - rows = np.tile(row, dim_rot) - addition = np.repeat(np.arange(dim_rot), len(row)) - columns += addition - rows += addition - # Calculate weight matrix and coefficients for jacobi matrix: - shape = (np.prod(dim_uv), np.prod(dim)) - weight = csr_matrix(coo_matrix((data, (rows, columns)), shape=shape)) - coeff = [[1, 0, 0], [0, np.cos(tilt), np.sin(tilt)]] - super().__init__(dim, dim_uv, weight, coeff) - self._log.debug('Created ' + str(self)) - - @staticmethod - def _get_position(points, center, tilt, size): - point_vecs = np.asarray(points) + 0.5 - np.asarray(center) # vectors pointing to points - direc_vec = np.array((np.cos(tilt), -np.sin(tilt))) # vector pointing along projection - distances = np.cross(direc_vec, point_vecs) # here (special case): divisor is one! - distances += size / 2. # Shift to the center of the projection - return distances - - @staticmethod - def _get_impact(pos, r, size): - return [x for x in np.arange(np.floor(pos - r), np.floor(pos + r) + 1, dtype=int) - if 0 <= x < size] - - @staticmethod - def _get_weight(delta, rho): # use circles to represent the voxels - lo, up = delta - rho, delta + rho - # Upper boundary: - if up >= 1: - w_up = 0.5 - else: - w_up = (up * np.sqrt(1 - up ** 2) + np.arctan(up / np.sqrt(1 - up ** 2))) / pi - # Lower boundary: - if lo <= -1: - w_lo = -0.5 - else: - w_lo = (lo * np.sqrt(1 - lo ** 2) + np.arctan(lo / np.sqrt(1 - lo ** 2))) / pi - return w_up - w_lo - - def get_info(self, verbose=False): - """Get specific information about the projector as a string. - - Parameters - ---------- - verbose: boolean, optional - If this is true, the text looks prettier (maybe using latex). Default is False for the - use in file names and such. - - Returns - ------- - info : string - Information about the projector as a string, e.g. for the use in plot titles. - - """ - if verbose: - return u'x-tilt: $\phi = {:d}$°'.format(int(np.round(self.tilt * 180 / pi))) - else: - return u'xtilt_phi={:d}°'.format(int(np.round(self.tilt * 180 / pi))) - - -class YTiltProjector(Projector): - """Class representing a projection function with a tilt around the y-axis. - - The :class:`~.YTiltProjector` class represents a projection function for a 3-dimensional - vector- or scalar field onto a 2-dimensional grid, which is a concrete subclass of - :class:`~.Projector`. - - Attributes - ---------- - dim : tuple (N=3) - Dimensions (z, y, x) of the magnetization distribution. - tilt : float - Angle in `rad` describing the tilt of the beam direction relative to the y-axis. - dim_uv : tuple (N=2), optional - Dimensions (v, u) of the projection. If not set defaults to the (y, x)-dimensions. - - """ - - _log = logging.getLogger(__name__ + '.YTiltProjector') - - def __init__(self, dim, tilt, dim_uv=None, verbose=False): - self._log.debug('Calling __init__') - self.tilt = tilt - # Set starting variables: - # length along projection (proj, z), rotation (rot, y) and perpendicular (perp, x) axis: - dim_proj, dim_rot, dim_perp = dim - if dim_uv is None: - dim_uv = (dim_rot, max(dim_perp, dim_proj)) # x-y-plane - dim_v, dim_u = dim_uv # y, x - assert dim_v >= dim_rot and dim_u >= dim_perp, 'Projected dimensions are too small!' - # Creating coordinate list of all voxels (for one slice): - voxels = list(itertools.product(range(dim_proj), range(dim_perp))) # z-x-plane - # Calculate positions along the projected pixel coordinate system: - center = (dim_proj / 2., dim_perp / 2.) - positions = self._get_position(voxels, center, tilt, dim_u) - # Calculate weight-matrix: - r = 1 / np.sqrt(np.pi) # radius of the voxel circle - rho = 0.5 / r - row = [] - col = [] - data = [] - # One slice: - disable = not verbose - for i, voxel in enumerate(tqdm(voxels, disable=disable, leave=False, - desc='Set up projector')): - impacts = self._get_impact(positions[i], r, dim_u) # impact along projected x-axis - voxel_index = voxel[0] * dim_perp * dim_rot + voxel[1] # 0: z, 1: x - for impact in impacts: - impact_index = impact + (dim_v - dim_rot) // 2 * dim_u - distance = np.abs(impact + 0.5 - positions[i]) - delta = distance / r - col.append(voxel_index) - row.append(impact_index) - data.append(self._get_weight(delta, rho)) - # All other slices (along y): - data = np.tile(data, dim_rot) - columns = np.tile(col, dim_rot) - rows = np.tile(row, dim_rot) - addition = np.repeat(np.arange(dim_rot), len(row)) - columns += addition * dim_perp - rows += addition * dim_u - # Calculate weight matrix and coefficients for jacobi matrix: - shape = (np.prod(dim_uv), np.prod(dim)) - weight = csr_matrix(coo_matrix((data, (rows, columns)), shape=shape)) - coeff = [[np.cos(tilt), 0, np.sin(tilt)], [0, 1, 0]] - super().__init__(dim, dim_uv, weight, coeff) - self._log.debug('Created ' + str(self)) - - @staticmethod - def _get_position(points, center, tilt, size): - point_vecs = np.asarray(points) + 0.5 - np.asarray(center) # vectors pointing to points - direc_vec = np.array((np.cos(tilt), -np.sin(tilt))) # vector pointing along projection - distances = np.cross(direc_vec, point_vecs) # here (special case): divisor is one! - distances += size / 2. # Shift to the center of the projection - return distances - - @staticmethod - def _get_impact(pos, r, size): - return [x for x in np.arange(np.floor(pos - r), np.floor(pos + r) + 1, dtype=int) - if 0 <= x < size] - - @staticmethod - def _get_weight(delta, rho): # use circles to represent the voxels - lo, up = delta - rho, delta + rho - # Upper boundary: - if up >= 1: - w_up = 0.5 - else: - w_up = (up * np.sqrt(1 - up ** 2) + np.arctan(up / np.sqrt(1 - up ** 2))) / pi - # Lower boundary: - if lo <= -1: - w_lo = -0.5 - else: - w_lo = (lo * np.sqrt(1 - lo ** 2) + np.arctan(lo / np.sqrt(1 - lo ** 2))) / pi - return w_up - w_lo - - def get_info(self, verbose=False): - """Get specific information about the projector as a string. - - Parameters - ---------- - verbose: boolean, optional - If this is true, the text looks prettier (maybe using latex). Default is False for the - use in file names and such. - - Returns - ------- - info : string - Information about the projector as a string, e.g. for the use in plot titles. - - """ - if verbose: - return u'y-tilt: $\phi = {:d}$°'.format(int(np.round(self.tilt * 180 / pi))) - else: - return u'ytilt_phi={:d}°'.format(int(np.round(self.tilt * 180 / pi))) - - -class SimpleProjector(Projector): - """Class representing a projection function along one of the major axes. - - The :class:`~.SimpleProjector` class represents a projection function for a 3-dimensional - vector- or scalar field onto a 2-dimensional grid, which is a concrete subclass of - :class:`~.Projector`. - - Attributes - ---------- - dim : tuple (N=3) - Dimensions (z, y, x) of the magnetization distribution. - axis : {'z', 'y', 'x'}, optional - Main axis along which the magnetic distribution is projected (given as a string). Defaults - to the z-axis. - dim_uv : tuple (N=2), optional - Dimensions (v, u) of the projection. If not set it uses the 3D default dimensions. - - """ - - _log = logging.getLogger(__name__ + '.SimpleProjector') - AXIS_DICT = {'z': (0, 1, 2), 'y': (1, 0, 2), 'x': (2, 1, 0)} # (0:z, 1:y, 2:x) -> (proj, v, u) - - # coordinate switch for 'x': u, v --> z, y (not y, z!)! - - def __init__(self, dim, axis='z', dim_uv=None, verbose=False): - self._log.debug('Calling __init__') - assert axis in {'z', 'y', 'x'}, 'Projection axis has to be x, y or z (given as a string)!' - self.axis = axis - proj, v, u = self.AXIS_DICT[axis] - dim_proj, dim_v, dim_u = dim[proj], dim[v], dim[u] - dim_z, dim_y, dim_x = dim - size_2d = dim_u * dim_v - size_3d = np.prod(dim) - data = np.repeat(1, size_3d) # size_3d ones in the matrix (each voxel is projected) - indptr = np.arange(0, size_3d + 1, dim_proj) # each row has dim_proj 1-entries - if axis == 'z': - self._log.debug('Projecting along the z-axis') - coeff = [[1, 0, 0], [0, 1, 0]] - indices = np.array([np.arange(row, size_3d, size_2d) - for row in range(size_2d)]).reshape(-1) - elif axis == 'y': - self._log.debug('Projection along the y-axis') - coeff = [[1, 0, 0], [0, 0, 1]] - indices = np.array( - [np.arange(row % dim_x, dim_x * dim_y, dim_x) + row // dim_x * dim_x * dim_y - for row in range(size_2d)]).reshape(-1) - elif axis == 'x': - self._log.debug('Projection along the x-axis') - coeff = [[0, 0, 1], [0, 1, 0]] # Caution, coordinate switch: u, v --> z, y (not y, z!) - indices = np.array( - [np.arange(dim_x) + (row % dim_z) * dim_x * dim_y + row // dim_z * dim_x - for row in range(size_2d)]).reshape(-1) - else: - raise ValueError('{} is not a valid axis parameter (use x, y or z)!'.format(axis)) - if dim_uv is not None: - indptr = list(indptr) # convert to use insert() and append() - # Calculate padding: - d_v = (np.floor((dim_uv[0] - dim_v) / 2).astype(int), - np.ceil((dim_uv[0] - dim_v) / 2).astype(int)) - d_u = (np.floor((dim_uv[1] - dim_u) / 2).astype(int), - np.ceil((dim_uv[1] - dim_u) / 2).astype(int)) - indptr.extend([indptr[-1]] * d_v[1] * dim_uv[1]) # add empty lines at the end - for i in np.arange(dim_v, 0, -1): # all slices in between - up, lo = i * dim_u, (i - 1) * dim_u # upper / lower slice end - indptr[up:up] = [indptr[up]] * d_u[1] # end of the slice - indptr[lo:lo] = [indptr[lo]] * d_u[0] # start of the slice - indptr = [0] * d_v[0] * dim_uv[1] + indptr # insert empty rows at the beginning - else: # Make sure dim_uv is defined (used for the assertion) - dim_uv = dim_v, dim_u - assert dim_uv[0] >= dim_v and dim_uv[1] >= dim_u, 'Projected dimensions are too small!' - # Create weight-matrix: - shape = (np.prod(dim_uv), np.prod(dim)) - weight = csr_matrix((data, indices, indptr), shape=shape) - super().__init__(dim, dim_uv, weight, coeff) - self._log.debug('Created ' + str(self)) - - def get_info(self, verbose=False): - """Get specific information about the projector as a string. - - Parameters - ---------- - verbose: boolean, optional - If this is true, the text looks prettier (maybe using latex). Default is False for the - use in file names and such. - - Returns - ------- - info : string - Information about the projector as a string, e.g. for the use in plot titles. - - """ - if verbose: - return 'projected along {}-axis'.format(self.axis) - else: - return '{}axis'.format(self.axis) +# -*- coding: utf-8 -*- +# Copyright 2014 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides the abstract base class :class:`~.Projector` and concrete subclasses for +projections of vector and scalar fields.""" + +import itertools +import logging + +try: + if type(get_ipython()).__name__ == 'ZMQInteractiveShell': # IPython Notebook! + from tqdm import tqdm_notebook as tqdm + else: # IPython, but not a Notebook (e.g. terminal) + from tqdm import tqdm +except NameError: + from tqdm import tqdm + +import numpy as np +from numpy import pi +from scipy.sparse import coo_matrix, csr_matrix + +from pyramid.fielddata import VectorData, ScalarData +from pyramid.quaternion import Quaternion + +__all__ = ['RotTiltProjector', 'XTiltProjector', 'YTiltProjector', 'SimpleProjector'] + + +class Projector(object): + """Base class representing a projection function. + + The :class:`~.Projector` class represents a projection function for a 3-dimensional + vector- or scalar field onto a 2-dimensional grid. :class:`~.Projector` is an abstract base + class and provides a unified interface which should be subclassed with a custom + :func:`__init__` function, which should call the parent :func:`__init__` method. Concrete + subclasses can be called as a function and take a `vector` as argument which contains the + 3-dimensional field. The output is the projected field, given as a `vector`. Depending on the + length of the input and the given dimensions `dim` at construction time, vector or scalar + projection is choosen intelligently. + + Attributes + ---------- + dim : tuple (N=3) + Dimensions (z, y, x) of the magnetization distribution. + dim_uv : tuple (N=2) + Dimensions (v, u) of the projected grid. + size_3d : int + Number of voxels of the 3-dimensional grid. + size_2d : int + Number of pixels of the 2-dimensional projected grid. + weight : :class:`~scipy.sparse.csr_matrix` (N=2) + The weight matrix containing the weighting coefficients for the 3D to 2D mapping. + coeff : list (N=2) + List containing the six weighting coefficients describing the influence of the 3 components + of a 3-dimensional vector field on the 2 projected components. + m: int + Size of the image space. + n: int + Size of the input space. + sparsity : float + Measures the sparsity of the weighting (not the complete one!), 1 means completely sparse! + + """ + + _log = logging.getLogger(__name__ + '.Projector') + + @property + def sparsity(self): + """The sparsity of the projector weight matrix.""" + return 1. - len(self.weight.data) / np.prod(self.weight.shape) + + def __init__(self, dim, dim_uv, weight, coeff): + self._log.debug('Calling __init__') + self.dim = tuple(dim) + self.dim_uv = tuple(dim_uv) + self.weight = weight + self.coeff = coeff + self.size_2d, self.size_3d = weight.shape + self.n = 3 * np.prod(dim) + self.m = 2 * np.prod(dim_uv) + self._log.debug('Created ' + str(self)) + + def __repr__(self): + self._log.debug('Calling __repr__') + return '%s(dim=%r, dim_uv=%r, weight=%r, coeff=%r)' % \ + (self.__class__, self.dim, self.dim_uv, self.weight, self.coeff) + + def __str__(self): + self._log.debug('Calling __str__') + return 'Projector(dim=%s, dim_uv=%s, coeff=%s)' % (self.dim, self.dim_uv, self.coeff) + + def __call__(self, field_data): + if isinstance(field_data, VectorData): + field_empty = np.zeros((3, 1) + self.dim_uv, dtype=field_data.field.dtype) + field_data_proj = VectorData(field_data.a, field_empty) + field_proj = self.jac_dot(field_data.field_vec).reshape((2,) + self.dim_uv) + field_data_proj.field[0:2, 0, ...] = field_proj + elif isinstance(field_data, ScalarData): + field_empty = np.zeros((1,) + self.dim_uv, dtype=field_data.field.dtype) + field_data_proj = ScalarData(field_data.a, field_empty) + field_proj = self.jac_dot(field_data.field_vec).reshape(self.dim_uv) + field_data_proj.field[0, ...] = field_proj + else: + raise TypeError('Input is neither of type VectorData or ScalarData') + return field_data_proj + + def _vector_field_projection(self, vector): + result = np.zeros(2 * self.size_2d, dtype=vector.dtype) + # Go over all possible component projections (z, y, x) to (u, v): + vec_x, vec_y, vec_z = np.split(vector, 3) + vec_x_weighted = self.weight.dot(vec_x) + vec_y_weighted = self.weight.dot(vec_y) + vec_z_weighted = self.weight.dot(vec_z) + slice_u = slice(0, self.size_2d) + slice_v = slice(self.size_2d, 2 * self.size_2d) + if self.coeff[0][0] != 0: # x to u + result[slice_u] += self.coeff[0][0] * vec_x_weighted + if self.coeff[0][1] != 0: # y to u + result[slice_u] += self.coeff[0][1] * vec_y_weighted + if self.coeff[0][2] != 0: # z to u + result[slice_u] += self.coeff[0][2] * vec_z_weighted + if self.coeff[1][0] != 0: # x to v + result[slice_v] += self.coeff[1][0] * vec_x_weighted + if self.coeff[1][1] != 0: # y to v + result[slice_v] += self.coeff[1][1] * vec_y_weighted + if self.coeff[1][2] != 0: # z to v + result[slice_v] += self.coeff[1][2] * vec_z_weighted + return result + + def _vector_field_projection_T(self, vector): + result = np.zeros(3 * self.size_3d) + # Go over all possible component projections (u, v) to (z, y, x): + vec_u, vec_v = np.split(vector, 2) + vec_u_weighted = self.weight.T.dot(vec_u) + vec_v_weighted = self.weight.T.dot(vec_v) + slice_x = slice(0, self.size_3d) + slice_y = slice(self.size_3d, 2 * self.size_3d) + slice_z = slice(2 * self.size_3d, 3 * self.size_3d) + if self.coeff[0][0] != 0: # u to x + result[slice_x] += self.coeff[0][0] * vec_u_weighted + if self.coeff[0][1] != 0: # u to y + result[slice_y] += self.coeff[0][1] * vec_u_weighted + if self.coeff[0][2] != 0: # u to z + result[slice_z] += self.coeff[0][2] * vec_u_weighted + if self.coeff[1][0] != 0: # v to x + result[slice_x] += self.coeff[1][0] * vec_v_weighted + if self.coeff[1][1] != 0: # v to y + result[slice_y] += self.coeff[1][1] * vec_v_weighted + if self.coeff[1][2] != 0: # v to z + result[slice_z] += self.coeff[1][2] * vec_v_weighted + return result + + def _scalar_field_projection(self, vector): + self._log.debug('Calling _scalar_field_projection') + return np.array(self.weight.dot(vector)) + + def _scalar_field_projection_T(self, vector): + self._log.debug('Calling _scalar_field_projection_T') + return np.array(self.weight.T.dot(vector)) + + def jac_dot(self, vector): + """Multiply a `vector` with the jacobi matrix of this :class:`~.Projector` object. + + Parameters + ---------- + vector : :class:`~numpy.ndarray` (N=1) + Vector containing the field which should be projected. Must have the same or 3 times + the size of `size_3d` of the projector for scalar and vector projection, respectively. + + Returns + ------- + proj_vector : :class:`~numpy.ndarray` (N=1) + Vector containing the projected field of the 2-dimensional grid. The length is + always`size_2d`. + + """ + if len(vector) == 3 * self.size_3d: # mode == 'vector' + return self._vector_field_projection(vector) + elif len(vector) == self.size_3d: # mode == 'scalar' + return self._scalar_field_projection(vector) + else: + raise AssertionError('Vector size has to be suited either for ' + 'vector- or scalar-field-projection!') + + def jac_T_dot(self, vector): + """Multiply a `vector` with the transp. jacobi matrix of this :class:`~.Projector` object. + + Parameters + ---------- + vector : :class:`~numpy.ndarray` (N=1) + Vector containing the field which should be projected. Must have the same or 2 times + the size of `size_2d` of the projector for scalar and vector projection, respectively. + + Returns + ------- + proj_vector : :class:`~numpy.ndarray` (N=1) + Vector containing the multiplication of the input with the transposed jacobi matrix + of the :class:`~.Projector` object. + + """ + if len(vector) == 2 * self.size_2d: # mode == 'vector' + return self._vector_field_projection_T(vector) + elif len(vector) == self.size_2d: # mode == 'scalar' + return self._scalar_field_projection_T(vector) + else: + raise AssertionError('Vector size has to be suited either for ' + 'vector- or scalar-field-projection!') + + def save(self, filename, overwrite=True): + """Saves the projector as an HDF5 file. + + Parameters + ---------- + filename: str + Name of the file which the phasemap is saved into. HDF5 files are supported. + overwrite: bool, optional + If True (default), an existing file will be overwritten, if False, this + (silently!) does nothing. + """ + from .file_io.io_projector import save_projector + save_projector(self, filename, overwrite) + + def get_info(self, verbose): + """Get specific information about the projector as a string. + + Parameters + ---------- + verbose: boolean, optional + If this is true, the text looks prettier (maybe using latex). Default is False for the + use in file names and such. + + Returns + ------- + info : string + Information about the projector as a string, e.g. for the use in plot titles. + + """ + return 'Base projector' + + +class RotTiltProjector(Projector): + """Class representing a projection function with a rotation around z followed by tilt around x. + + The :class:`~.XTiltProjector` class represents a projection function for a 3-dimensional + vector- or scalar field onto a 2-dimensional grid, which is a concrete subclass of + :class:`~.Projector`. + + Attributes + ---------- + dim : tuple (N=3) + Dimensions (z, y, x) of the magnetization distribution. + rotation : float + Angle in `rad` describing the rotation around the z-axis before the tilt is happening. + tilt : float + Angle in `rad` describing the tilt of the beam direction relative to the x-axis. + dim_uv : tuple (N=2), optional + Dimensions (v, u) of the projection. If not set defaults to the (y, x)-dimensions. + subcount : int (optional) + Number of subpixels along one axis. This is used to create the lookup table which uses + a discrete subgrid to estimate the impact point of a voxel onto a pixel and the weight on + all surrounding pixels. Default is 11 (odd numbers provide a symmetric center). + + """ + + _log = logging.getLogger(__name__ + '.RotTiltProjector') + + def __init__(self, dim, rotation, tilt, dim_uv=None, subcount=11, verbose=False): + self._log.debug('Calling __init__') + self.rotation = rotation + self.tilt = tilt + # Determine dimensions: + dim_z, dim_y, dim_x = dim + center = (dim_z / 2., dim_y / 2., dim_x / 2.) + if dim_uv is None: + dim_v = max(dim_x, dim_y) # first rotate around z-axis (take x and y into account) + dim_u = max(dim_v, dim_z) # then tilt around x-axis (now z matters, too) + dim_uv = (dim_v, dim_u) + dim_v, dim_u = dim_uv + # Creating coordinate list of all voxels: + voxels = list(itertools.product(range(dim_z), range(dim_y), range(dim_x))) + # Calculate vectors to voxels relative to rotation center: + voxel_vecs = (np.asarray(voxels) + 0.5 - np.asarray(center)).T + # Create tilt, rotation and combined quaternion, careful: Quaternion(w,x,y,z), not (z,y,x): + quat_x = Quaternion.from_axisangle((1, 0, 0), tilt) # Tilt around x-axis + quat_z = Quaternion.from_axisangle((0, 0, 1), rotation) # Rotate around z-axis + quat = quat_x * quat_z # Combined quaternion (first rotate around z, then tilt around x) + # Calculate impact positions on the projected pixel coordinate grid (flip because quat.): + impacts = np.flipud(quat.matrix[:2, :].dot(np.flipud(voxel_vecs))) # only care for x/y + impacts[1, :] += dim_u / 2. # Shift back to normal indices + impacts[0, :] += dim_v / 2. # Shift back to normal indices + # Calculate equivalence radius: + R = (3 / (4 * np.pi)) ** (1 / 3.) + # Prepare weight matrix calculation: + rows = [] # 2D projection + columns = [] # 3D distribution + data = [] # weights + # Create 4D lookup table (1&2: which neighbour weight?, 3&4: which subpixel is hit?) + weight_lookup = self._create_weight_lookup(subcount, R) + # Go over all voxels: + disable = not verbose + for i, voxel in enumerate(tqdm(voxels, disable=disable, leave=False, + desc='Set up projector')): + column_index = voxel[0] * dim_y * dim_x + voxel[1] * dim_x + voxel[2] + remainder, impact = np.modf(impacts[:, i]) # split index of impact and remainder! + sub_pixel = (remainder * subcount).astype(dtype=np.int) # sub_pixel inside impact px. + # Go over all influenced pixels (impact and neighbours, indices are [0, 1, 2]!): + for px_ind in list(itertools.product(range(3), range(3))): + # Pixel indices influenced by the impact (px_ind-1 to center them around impact): + pixel = (impact + np.array(px_ind) - 1).astype(dtype=np.int) + # Check if pixel is out of bound: + if 0 <= pixel[0] < dim_uv[0] and 0 <= pixel[1] < dim_uv[1]: + # Lookup weight in 4-dimensional lookup table! + weight = weight_lookup[px_ind[0], px_ind[1], sub_pixel[0], sub_pixel[1]] + # Only write into sparse matrix if weight is not zero: + if weight != 0.: + row_index = pixel[0] * dim_u + pixel[1] + columns.append(column_index) + rows.append(row_index) + data.append(weight) + # Calculate weight matrix and coefficients for jacobi matrix: + shape = (np.prod(dim_uv), np.prod(dim)) + weights = csr_matrix(coo_matrix((data, (rows, columns)), shape=shape)) + # Calculate coefficients by rotating unity matrix (unit vectors, (x,y,z)): + coeff = quat.matrix[:2, :].dot(np.eye(3)) + super().__init__(dim, dim_uv, weights, coeff) + self._log.debug('Created ' + str(self)) + + @staticmethod + def _create_weight_lookup(subcount, R): + s = subcount + Rz = R * s # Radius in subgrid units + dim_zoom = (3 * s, 3 * s) # Dimensions of the subgrid, (3, 3) because of neighbour count! + cent_zoom = (np.asarray(dim_zoom) / 2.).astype(dtype=np.int) # Center of the subgrid + y, x = np.indices(dim_zoom) + y -= cent_zoom[0] + x -= cent_zoom[1] + # Calculate projected thickness of an equivalence sphere (normed!): + d = np.where(np.hypot(x, y) <= Rz, Rz ** 2 - x ** 2 - y ** 2, 0) + d = np.sqrt(d) + d /= d.sum() + # Create lookup table (4D): + lookup = np.zeros((3, 3, s, s)) + # Go over all 9 pixels (center and neighbours): + for pixel in list(itertools.product(range(3), range(3))): + pixel_lb = np.array(pixel) * s # Convert to subgrid, hit bottom left of the pixel! + # Go over all subpixels in the center that can be hit: + for sub_pixel in list(itertools.product(range(s), range(s))): + shift = np.array(sub_pixel) - np.array((s // 2, s // 2)) # relative to center! + lb = pixel_lb - shift # Shift summing zone according to hit subpixel! + # Make sure, that the summing zone is in bounds (otherwise correct accordingly): + lb = np.where(lb >= 0, lb, [0, 0]) + tr = np.where(lb < 3 * s, lb + np.array((s, s)), [3 * s, 3 * s]) + # Calculate weight by summing over the summing zone: + weight = d[lb[0]:tr[0], lb[1]:tr[1]].sum() + lookup[pixel[0], pixel[1], sub_pixel[0], sub_pixel[1]] = weight + return lookup + + def get_info(self, verbose=False): + """Get specific information about the projector as a string. + + Parameters + ---------- + verbose: boolean, optional + If this is true, the text looks prettier (maybe using latex). Default is False for the + use in file names and such. + + Returns + ------- + info : string + Information about the projector as a string, e.g. for the use in plot titles. + + """ + theta_ang = int(np.round(self.rotation * 180 / pi)) + phi_ang = int(np.round(self.tilt * 180 / pi)) + if verbose: + return u'$\\theta = {:d}$°, $\phi = {:d}$°'.format(theta_ang, phi_ang) + else: + return u'theta={:d}_phi={:d}°'.format(theta_ang, phi_ang) + + +class XTiltProjector(Projector): + """Class representing a projection function with a tilt around the x-axis. + + The :class:`~.XTiltProjector` class represents a projection function for a 3-dimensional + vector- or scalar field onto a 2-dimensional grid, which is a concrete subclass of + :class:`~.Projector`. + + Attributes + ---------- + dim : tuple (N=3) + Dimensions (z, y, x) of the magnetization distribution. + tilt : float + Angle in `rad` describing the tilt of the beam direction relative to the x-axis. + dim_uv : tuple (N=2), optional + Dimensions (v, u) of the projection. If not set defaults to the (y, x)-dimensions. + + """ + + _log = logging.getLogger(__name__ + '.XTiltProjector') + + def __init__(self, dim, tilt, dim_uv=None, verbose=False): + self._log.debug('Calling __init__') + self.tilt = tilt + # Set starting variables: + # length along projection (proj, z), perpendicular (perp, y) and rotation (rot, x) axis: + dim_proj, dim_perp, dim_rot = dim + if dim_uv is None: + dim_uv = (max(dim_perp, dim_proj), dim_rot) # x-y-plane + dim_v, dim_u = dim_uv # y, x + assert dim_v >= dim_perp and dim_u >= dim_rot, 'Projected dimensions are too small!' + # Creating coordinate list of all voxels (for one slice): + voxels = list(itertools.product(range(dim_proj), range(dim_perp))) # z-y-plane + # Calculate positions along the projected pixel coordinate system: + center = (dim_proj / 2., dim_perp / 2.) + positions = self._get_position(voxels, center, tilt, dim_v) + # Calculate weight-matrix: + r = 1 / np.sqrt(np.pi) # radius of the voxel circle + rho = 0.5 / r + row = [] + col = [] + data = [] + # One slice: + disable = not verbose + for i, voxel in enumerate(tqdm(voxels, disable=disable, leave=False, + desc='Set up projector')): + impacts = self._get_impact(positions[i], r, dim_v) # impact along projected y-axis + voxel_index = voxel[0] * dim_rot * dim_perp + voxel[1] * dim_rot # 0: z, 1: y + for impact in impacts: + impact_index = impact * dim_u + (dim_u - dim_rot) // 2 + distance = np.abs(impact + 0.5 - positions[i]) + delta = distance / r + col.append(voxel_index) + row.append(impact_index) + data.append(self._get_weight(delta, rho)) + # All other slices (along x): + data = np.tile(data, dim_rot) + columns = np.tile(col, dim_rot) + rows = np.tile(row, dim_rot) + addition = np.repeat(np.arange(dim_rot), len(row)) + columns += addition + rows += addition + # Calculate weight matrix and coefficients for jacobi matrix: + shape = (np.prod(dim_uv), np.prod(dim)) + weight = csr_matrix(coo_matrix((data, (rows, columns)), shape=shape)) + coeff = [[1, 0, 0], [0, np.cos(tilt), np.sin(tilt)]] + super().__init__(dim, dim_uv, weight, coeff) + self._log.debug('Created ' + str(self)) + + @staticmethod + def _get_position(points, center, tilt, size): + point_vecs = np.asarray(points) + 0.5 - np.asarray(center) # vectors pointing to points + direc_vec = np.array((np.cos(tilt), -np.sin(tilt))) # vector pointing along projection + distances = np.cross(direc_vec, point_vecs) # here (special case): divisor is one! + distances += size / 2. # Shift to the center of the projection + return distances + + @staticmethod + def _get_impact(pos, r, size): + return [x for x in np.arange(np.floor(pos - r), np.floor(pos + r) + 1, dtype=int) + if 0 <= x < size] + + @staticmethod + def _get_weight(delta, rho): # use circles to represent the voxels + lo, up = delta - rho, delta + rho + # Upper boundary: + if up >= 1: + w_up = 0.5 + else: + w_up = (up * np.sqrt(1 - up ** 2) + np.arctan(up / np.sqrt(1 - up ** 2))) / pi + # Lower boundary: + if lo <= -1: + w_lo = -0.5 + else: + w_lo = (lo * np.sqrt(1 - lo ** 2) + np.arctan(lo / np.sqrt(1 - lo ** 2))) / pi + return w_up - w_lo + + def get_info(self, verbose=False): + """Get specific information about the projector as a string. + + Parameters + ---------- + verbose: boolean, optional + If this is true, the text looks prettier (maybe using latex). Default is False for the + use in file names and such. + + Returns + ------- + info : string + Information about the projector as a string, e.g. for the use in plot titles. + + """ + if verbose: + return u'x-tilt: $\phi = {:d}$°'.format(int(np.round(self.tilt * 180 / pi))) + else: + return u'xtilt_phi={:d}°'.format(int(np.round(self.tilt * 180 / pi))) + + +class YTiltProjector(Projector): + """Class representing a projection function with a tilt around the y-axis. + + The :class:`~.YTiltProjector` class represents a projection function for a 3-dimensional + vector- or scalar field onto a 2-dimensional grid, which is a concrete subclass of + :class:`~.Projector`. + + Attributes + ---------- + dim : tuple (N=3) + Dimensions (z, y, x) of the magnetization distribution. + tilt : float + Angle in `rad` describing the tilt of the beam direction relative to the y-axis. + dim_uv : tuple (N=2), optional + Dimensions (v, u) of the projection. If not set defaults to the (y, x)-dimensions. + + """ + + _log = logging.getLogger(__name__ + '.YTiltProjector') + + def __init__(self, dim, tilt, dim_uv=None, verbose=False): + self._log.debug('Calling __init__') + self.tilt = tilt + # Set starting variables: + # length along projection (proj, z), rotation (rot, y) and perpendicular (perp, x) axis: + dim_proj, dim_rot, dim_perp = dim + if dim_uv is None: + dim_uv = (dim_rot, max(dim_perp, dim_proj)) # x-y-plane + dim_v, dim_u = dim_uv # y, x + assert dim_v >= dim_rot and dim_u >= dim_perp, 'Projected dimensions are too small!' + # Creating coordinate list of all voxels (for one slice): + voxels = list(itertools.product(range(dim_proj), range(dim_perp))) # z-x-plane + # Calculate positions along the projected pixel coordinate system: + center = (dim_proj / 2., dim_perp / 2.) + positions = self._get_position(voxels, center, tilt, dim_u) + # Calculate weight-matrix: + r = 1 / np.sqrt(np.pi) # radius of the voxel circle + rho = 0.5 / r + row = [] + col = [] + data = [] + # One slice: + disable = not verbose + for i, voxel in enumerate(tqdm(voxels, disable=disable, leave=False, + desc='Set up projector')): + impacts = self._get_impact(positions[i], r, dim_u) # impact along projected x-axis + voxel_index = voxel[0] * dim_perp * dim_rot + voxel[1] # 0: z, 1: x + for impact in impacts: + impact_index = impact + (dim_v - dim_rot) // 2 * dim_u + distance = np.abs(impact + 0.5 - positions[i]) + delta = distance / r + col.append(voxel_index) + row.append(impact_index) + data.append(self._get_weight(delta, rho)) + # All other slices (along y): + data = np.tile(data, dim_rot) + columns = np.tile(col, dim_rot) + rows = np.tile(row, dim_rot) + addition = np.repeat(np.arange(dim_rot), len(row)) + columns += addition * dim_perp + rows += addition * dim_u + # Calculate weight matrix and coefficients for jacobi matrix: + shape = (np.prod(dim_uv), np.prod(dim)) + weight = csr_matrix(coo_matrix((data, (rows, columns)), shape=shape)) + coeff = [[np.cos(tilt), 0, np.sin(tilt)], [0, 1, 0]] + super().__init__(dim, dim_uv, weight, coeff) + self._log.debug('Created ' + str(self)) + + @staticmethod + def _get_position(points, center, tilt, size): + point_vecs = np.asarray(points) + 0.5 - np.asarray(center) # vectors pointing to points + direc_vec = np.array((np.cos(tilt), -np.sin(tilt))) # vector pointing along projection + distances = np.cross(direc_vec, point_vecs) # here (special case): divisor is one! + distances += size / 2. # Shift to the center of the projection + return distances + + @staticmethod + def _get_impact(pos, r, size): + return [x for x in np.arange(np.floor(pos - r), np.floor(pos + r) + 1, dtype=int) + if 0 <= x < size] + + @staticmethod + def _get_weight(delta, rho): # use circles to represent the voxels + lo, up = delta - rho, delta + rho + # Upper boundary: + if up >= 1: + w_up = 0.5 + else: + w_up = (up * np.sqrt(1 - up ** 2) + np.arctan(up / np.sqrt(1 - up ** 2))) / pi + # Lower boundary: + if lo <= -1: + w_lo = -0.5 + else: + w_lo = (lo * np.sqrt(1 - lo ** 2) + np.arctan(lo / np.sqrt(1 - lo ** 2))) / pi + return w_up - w_lo + + def get_info(self, verbose=False): + """Get specific information about the projector as a string. + + Parameters + ---------- + verbose: boolean, optional + If this is true, the text looks prettier (maybe using latex). Default is False for the + use in file names and such. + + Returns + ------- + info : string + Information about the projector as a string, e.g. for the use in plot titles. + + """ + if verbose: + return u'y-tilt: $\phi = {:d}$°'.format(int(np.round(self.tilt * 180 / pi))) + else: + return u'ytilt_phi={:d}°'.format(int(np.round(self.tilt * 180 / pi))) + + +class SimpleProjector(Projector): + """Class representing a projection function along one of the major axes. + + The :class:`~.SimpleProjector` class represents a projection function for a 3-dimensional + vector- or scalar field onto a 2-dimensional grid, which is a concrete subclass of + :class:`~.Projector`. + + Attributes + ---------- + dim : tuple (N=3) + Dimensions (z, y, x) of the magnetization distribution. + axis : {'z', 'y', 'x'}, optional + Main axis along which the magnetic distribution is projected (given as a string). Defaults + to the z-axis. + dim_uv : tuple (N=2), optional + Dimensions (v, u) of the projection. If not set it uses the 3D default dimensions. + + """ + + _log = logging.getLogger(__name__ + '.SimpleProjector') + AXIS_DICT = {'z': (0, 1, 2), 'y': (1, 0, 2), 'x': (2, 1, 0)} # (0:z, 1:y, 2:x) -> (proj, v, u) + + # coordinate switch for 'x': u, v --> z, y (not y, z!)! + + def __init__(self, dim, axis='z', dim_uv=None, verbose=False): + self._log.debug('Calling __init__') + assert axis in {'z', 'y', 'x'}, 'Projection axis has to be x, y or z (given as a string)!' + self.axis = axis + proj, v, u = self.AXIS_DICT[axis] + dim_proj, dim_v, dim_u = dim[proj], dim[v], dim[u] + dim_z, dim_y, dim_x = dim + size_2d = dim_u * dim_v + size_3d = np.prod(dim) + data = np.repeat(1, size_3d) # size_3d ones in the matrix (each voxel is projected) + indptr = np.arange(0, size_3d + 1, dim_proj) # each row has dim_proj 1-entries + if axis == 'z': + self._log.debug('Projecting along the z-axis') + coeff = [[1, 0, 0], [0, 1, 0]] + indices = np.array([np.arange(row, size_3d, size_2d) + for row in range(size_2d)]).reshape(-1) + elif axis == 'y': + self._log.debug('Projection along the y-axis') + coeff = [[1, 0, 0], [0, 0, 1]] + indices = np.array( + [np.arange(row % dim_x, dim_x * dim_y, dim_x) + row // dim_x * dim_x * dim_y + for row in range(size_2d)]).reshape(-1) + elif axis == 'x': + self._log.debug('Projection along the x-axis') + coeff = [[0, 0, 1], [0, 1, 0]] # Caution, coordinate switch: u, v --> z, y (not y, z!) + indices = np.array( + [np.arange(dim_x) + (row % dim_z) * dim_x * dim_y + row // dim_z * dim_x + for row in range(size_2d)]).reshape(-1) + else: + raise ValueError('{} is not a valid axis parameter (use x, y or z)!'.format(axis)) + if dim_uv is not None: + indptr = list(indptr) # convert to use insert() and append() + # Calculate padding: + d_v = (np.floor((dim_uv[0] - dim_v) / 2).astype(int), + np.ceil((dim_uv[0] - dim_v) / 2).astype(int)) + d_u = (np.floor((dim_uv[1] - dim_u) / 2).astype(int), + np.ceil((dim_uv[1] - dim_u) / 2).astype(int)) + indptr.extend([indptr[-1]] * d_v[1] * dim_uv[1]) # add empty lines at the end + for i in np.arange(dim_v, 0, -1): # all slices in between + up, lo = i * dim_u, (i - 1) * dim_u # upper / lower slice end + indptr[up:up] = [indptr[up]] * d_u[1] # end of the slice + indptr[lo:lo] = [indptr[lo]] * d_u[0] # start of the slice + indptr = [0] * d_v[0] * dim_uv[1] + indptr # insert empty rows at the beginning + else: # Make sure dim_uv is defined (used for the assertion) + dim_uv = dim_v, dim_u + assert dim_uv[0] >= dim_v and dim_uv[1] >= dim_u, 'Projected dimensions are too small!' + # Create weight-matrix: + shape = (np.prod(dim_uv), np.prod(dim)) + weight = csr_matrix((data, indices, indptr), shape=shape) + super().__init__(dim, dim_uv, weight, coeff) + self._log.debug('Created ' + str(self)) + + def get_info(self, verbose=False): + """Get specific information about the projector as a string. + + Parameters + ---------- + verbose: boolean, optional + If this is true, the text looks prettier (maybe using latex). Default is False for the + use in file names and such. + + Returns + ------- + info : string + Information about the projector as a string, e.g. for the use in plot titles. + + """ + if verbose: + return 'projected along {}-axis'.format(self.axis) + else: + return '{}axis'.format(self.axis) diff --git a/pyramid/quaternion.py b/pyramid/quaternion.py index 1267707683fd6663686002a4534e4b8ef6eb80d0..568e2676af8bf7c117729f3e4a801fadec3d55cd 100644 --- a/pyramid/quaternion.py +++ b/pyramid/quaternion.py @@ -1,140 +1,140 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""This module provides the :class:`~.Quaternion` class which can be used for rotations.""" - -import logging - -import numpy as np - -__all__ = ['Quaternion'] - - -class Quaternion(object): - """Class representing a rotation expressed by a quaternion. - - A quaternion is a four-dimensional description of a rotation which can also be described by - a rotation vector (`v1`, `v2`, `v3`) and a rotation angle :math:`\theta`. The four components - are calculated to: - - .. math:: - - w = \cos(\theta/2) - x = v_1 \cdot \sin(\theta/2) - y = v_2 \cdot \sin(\theta/2) - z = v_3 \cdot \sin(\theta/2) - - Use the :func:`~.from_axisangle` and :func:`~.to_axisangle` to convert to axis-angle - representation and vice versa. Quaternions can be multiplied by other quaternions, which - results in a new rotation or with a vector, which results in a rotated vector. - - Attributes - ---------- - values : float - The four quaternion values `w`, `x`, `y`, `z`. - - """ - - NORM_TOLERANCE = 1E-6 - - _log = logging.getLogger(__name__ + '.Quaternion') - - @property - def conj(self): - """The conjugate of the quaternion, representing a tilt in opposite direction.""" - w, x, y, z = self.values - return Quaternion((w, -x, -y, -z)) - - @property - def matrix(self): - """The rotation matrix representation of the quaternion.""" - w, x, y, z = self.values - return np.array([[1 - 2 * (y ** 2 + z ** 2), 2 * (x * y - w * z), 2 * (x * z + w * y)], - [2 * (x * y + w * z), 1 - 2 * (x ** 2 + z ** 2), 2 * (y * z - w * x)], - [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x ** 2 + y ** 2)]]) - - def __init__(self, values): - self._log.debug('Calling __init__') - self.values = values - self._normalize() - self._log.debug('Created ' + str(self)) - - def __mul__(self, other): # self * other - self._log.debug('Calling __mul__') - if isinstance(other, Quaternion): # Quaternion multiplication - return self.dot_quat(self, other) - elif len(other) == 3: # vector multiplication - q_vec = Quaternion((0,) + tuple(other)) - q = self.dot_quat(self.dot_quat(self, q_vec), self.conj) - return q.values[1:] - - def dot_quat(self, q1, q2): - """Multiply two :class:`~.Quaternion` objects to create a new one (always normalized). - - Parameters - ---------- - q1, q2 : :class:`~.Quaternion` - The quaternion which should be multiplied. - - Returns - ------- - quaternion : :class:`~.Quaternion` - The resulting quaternion. - - """ - self._log.debug('Calling dot_quat') - w1, x1, y1, z1 = q1.values - w2, x2, y2, z2 = q2.values - w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 - x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 - y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 - z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 - return Quaternion((w, x, y, z)) - - def _normalize(self): - self._log.debug('Calling _normalize') - mag2 = np.sum(n ** 2 for n in self.values) - if abs(mag2 - 1.0) > self.NORM_TOLERANCE: - mag = np.sqrt(mag2) - self.values = tuple(n / mag for n in self.values) - - @classmethod - def from_axisangle(cls, vector, theta): - """Create a quaternion from an axis-angle representation - - Parameters - ---------- - vector : :class:`~numpy.ndarray` (N=3) - Vector around which the rotation is executed. - theta : float - Rotation angle. - - Returns - ------- - quaternion : :class:`~.Quaternion` - The resulting quaternion. - - """ - cls._log.debug('Calling from_axisangle') - x, y, z = vector - theta /= 2. - w = np.cos(theta) - x *= np.sin(theta) - y *= np.sin(theta) - z *= np.sin(theta) - return cls((w, x, y, z)) - - def to_axisangle(self): - """Convert the quaternion to axis-angle-representation. - - Returns - ------- - vector, theta : :class:`~numpy.ndarray` (N=3), float - Vector around which the rotation is executed and rotation angle. - - """ - self._log.debug('Calling to_axisangle') - w, x, y, z = self.values - theta = 2.0 * np.arccos(w) - return np.array((x, y, z)), theta +# -*- coding: utf-8 -*- +# Copyright 2014 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides the :class:`~.Quaternion` class which can be used for rotations.""" + +import logging + +import numpy as np + +__all__ = ['Quaternion'] + + +class Quaternion(object): + """Class representing a rotation expressed by a quaternion. + + A quaternion is a four-dimensional description of a rotation which can also be described by + a rotation vector (`v1`, `v2`, `v3`) and a rotation angle :math:`\theta`. The four components + are calculated to: + + .. math:: + + w = \cos(\theta/2) + x = v_1 \cdot \sin(\theta/2) + y = v_2 \cdot \sin(\theta/2) + z = v_3 \cdot \sin(\theta/2) + + Use the :func:`~.from_axisangle` and :func:`~.to_axisangle` to convert to axis-angle + representation and vice versa. Quaternions can be multiplied by other quaternions, which + results in a new rotation or with a vector, which results in a rotated vector. + + Attributes + ---------- + values : float + The four quaternion values `w`, `x`, `y`, `z`. + + """ + + NORM_TOLERANCE = 1E-6 + + _log = logging.getLogger(__name__ + '.Quaternion') + + @property + def conj(self): + """The conjugate of the quaternion, representing a tilt in opposite direction.""" + w, x, y, z = self.values + return Quaternion((w, -x, -y, -z)) + + @property + def matrix(self): + """The rotation matrix representation of the quaternion.""" + w, x, y, z = self.values + return np.array([[1 - 2 * (y ** 2 + z ** 2), 2 * (x * y - w * z), 2 * (x * z + w * y)], + [2 * (x * y + w * z), 1 - 2 * (x ** 2 + z ** 2), 2 * (y * z - w * x)], + [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x ** 2 + y ** 2)]]) + + def __init__(self, values): + self._log.debug('Calling __init__') + self.values = values + self._normalize() + self._log.debug('Created ' + str(self)) + + def __mul__(self, other): # self * other + self._log.debug('Calling __mul__') + if isinstance(other, Quaternion): # Quaternion multiplication + return self.dot_quat(self, other) + elif len(other) == 3: # vector multiplication + q_vec = Quaternion((0,) + tuple(other)) + q = self.dot_quat(self.dot_quat(self, q_vec), self.conj) + return q.values[1:] + + def dot_quat(self, q1, q2): + """Multiply two :class:`~.Quaternion` objects to create a new one (always normalized). + + Parameters + ---------- + q1, q2 : :class:`~.Quaternion` + The quaternion which should be multiplied. + + Returns + ------- + quaternion : :class:`~.Quaternion` + The resulting quaternion. + + """ + self._log.debug('Calling dot_quat') + w1, x1, y1, z1 = q1.values + w2, x2, y2, z2 = q2.values + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + return Quaternion((w, x, y, z)) + + def _normalize(self): + self._log.debug('Calling _normalize') + mag2 = np.sum(n ** 2 for n in self.values) + if abs(mag2 - 1.0) > self.NORM_TOLERANCE: + mag = np.sqrt(mag2) + self.values = tuple(n / mag for n in self.values) + + @classmethod + def from_axisangle(cls, vector, theta): + """Create a quaternion from an axis-angle representation + + Parameters + ---------- + vector : :class:`~numpy.ndarray` (N=3) + Vector around which the rotation is executed. + theta : float + Rotation angle. + + Returns + ------- + quaternion : :class:`~.Quaternion` + The resulting quaternion. + + """ + cls._log.debug('Calling from_axisangle') + x, y, z = vector + theta /= 2. + w = np.cos(theta) + x *= np.sin(theta) + y *= np.sin(theta) + z *= np.sin(theta) + return cls((w, x, y, z)) + + def to_axisangle(self): + """Convert the quaternion to axis-angle-representation. + + Returns + ------- + vector, theta : :class:`~numpy.ndarray` (N=3), float + Vector around which the rotation is executed and rotation angle. + + """ + self._log.debug('Calling to_axisangle') + w, x, y, z = self.values + theta = 2.0 * np.arccos(w) + return np.array((x, y, z)), theta diff --git a/pyramid/ramp.py b/pyramid/ramp.py index d53fbb5d348e424d6a8c2c301884a7fa703eeb10..2d08b8ff9fa81e6ed0ba0135c3ea32cce5d28f8d 100644 --- a/pyramid/ramp.py +++ b/pyramid/ramp.py @@ -1,223 +1,223 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""This module provides the :class:`~.Ramp` class which implements polynomial phase ramps.""" - -import numpy as np - -from pyramid.phasemap import PhaseMap - -__all__ = ['Ramp'] - - -class Ramp(object): - """Class representing a polynomial phase ramp. - - Sometimes additional phase ramps occur in phase maps which do not stem from a magnetization - distribution inside the FOV. This class allows the construction (and via the derivative - functions also the reconstruction) of a polynomial ramp. This class is generally constructed - within the ForwardModel and can be retrieved as its attribute if ramp information should be - accessed. - - Attributes - ---------- - data_set : :class:`~dataset.DataSet` - :class:`~dataset.DataSet` object, which stores all required information calculation. - order : int or None (default) - Polynomial order of the additional phase ramp which will be added to the phase maps. - All ramp parameters have to be at the end of the input vector and are split automatically. - Default is None (no ramps are added). - deg_of_freedom : int - Number of degrees of freedom. This is calculated to ``1 + 2 * order``. There is just one - degree of freedom for a ramp of order zero (offset), every higher order contributes two - degrees of freedom. - param_cache : :class:`numpy.ndarray` (N=2) - Parameter cache which is used to store the polynomial coefficients. Higher coefficients - (one for each degree of freedom) are saved along the first axis, values for different - images along the second axis. - n : int - Size of the input space. Coincides with the numer of entries in `param_cache` and - calculates to ``deg_of_freedom * data_set.count``. - - Notes - ----- - After a reconstruction the relevant polynomial ramp information is stored in the - `param_cache`. If a phasemap with index `i` in the DataSet should be corrected use: - - .. code-block:: python - - phasemap -= ramp(i=0, dof_list=[0, 1, 2]) - - - The optional parameter `dof_list` can be used to specify a list of degrees of freedom which - should be used for the ramp (e.g. `[0]` will just apply the offset, `[0, 1, 2]` will apply - the offset and linear ramps in both directions). - - Fitting polynoms of higher orders than `order = 1` is possible but not recommended, because - features which stem from the magnetization could be covered by the polynom, decreasing the - phase contribution of the magnetization distribution, leading to a false retrieval. - - """ - - def __init__(self, data_set, order=None): - assert order is None or (isinstance(order, int) and order >= 0), \ - 'Order has to be None or a positive integer!' - self.order = order - self.a = data_set.a - self.count = data_set.count - self.dimensions = [projector.dim_uv for projector in data_set.projectors] - self.hook_points = data_set.hook_points - self.deg_of_freedom = (1 + 2 * self.order) if self.order is not None else 0 - self.param_cache = np.zeros((self.deg_of_freedom, self.count)) - self.n = self.deg_of_freedom * self.count # 0 if order is None - - def __call__(self, index, dof_list=None): - if self.order is None: # Do nothing if order is None! - return 0 - else: - if dof_list is None: # if no specific list is supplied! - dof_list = range(self.deg_of_freedom) # use all available degrees of freedom - dim_uv = self.dimensions[index] - phase_ramp = np.zeros(dim_uv) - # Iterate over all degrees of freedom: - for dof in dof_list: - # Add the contribution of the current degree of freedom: - phase_ramp += (self.param_cache[dof][index] * - self.create_poly_mesh(self.a, dof, dim_uv)) - return PhaseMap(self.a, phase_ramp, mask=np.zeros(dim_uv, dtype=np.bool)) - - def jac_dot(self, index): - """Calculate the product of the Jacobi matrix . - - Parameters - ---------- - index : int - Index of the phasemap from the `dataset` for which the phase ramp is calculated. - - Returns - ------- - result_vector : :class:`~numpy.ndarray` (N=1) - Product of the Jacobi matrix (which is not explicitely calculated) with the input - `vector`. Just the ramp contribution is calculated! - - """ - if self.order is None: # Do nothing if order is None! - return 0 - else: - dim_uv = self.dimensions[index] - phase_ramp = np.zeros(dim_uv) - # Iterate over all degrees of freedom: - for dof in range(self.deg_of_freedom): - # Add the contribution of the current degree of freedom: - phase_ramp += (self.param_cache[dof][index] * - self.create_poly_mesh(self.a, dof, dim_uv)) - return np.ravel(phase_ramp) - - def jac_T_dot(self, vector): - """'Calculate the transposed ramp parameters from a given `vector`. - - Parameters - ---------- - vector : :class:`~numpy.ndarray` (N=1) - Vectorized form of all 2D phase maps one after another in one vector. - - Returns - ------- - result_vector : :class:`~numpy.ndarray` (N=1) - Transposed ramp parameters. - - """ - result = [] - hp = self.hook_points - # Iterate over all degrees of freedom: - for dof in range(self.deg_of_freedom): - # Iterate over all projectors: - for i, dim_uv in enumerate(self.dimensions): - sub_vec = vector[hp[i]:hp[i + 1]] - poly_mesh = self.create_poly_mesh(self.a, dof, dim_uv) - # Transposed ramp parameters: summed product of the vector with the poly-mesh: - result.append(np.sum(sub_vec * np.ravel(poly_mesh))) - return result - - def extract_ramp_params(self, x): - """Extract the ramp parameters of an input vector and return the rest. - - Parameters - ---------- - x : :class:`~numpy.ndarray` (N=1) - Input vector which consists of the vectorised magnetization distribution and the ramp - parameters at the end which will be extracted. - - Returns - ------- - result_vector : :class:`~numpy.ndarray` (N=1) - Inpput vector without the extracted ramp parameters. - - Notes - ----- - This method should always be used before a vector `x` is processed if it is known that - ramp parameters are present so that other functions do not have to bother with them - and the :class:`.~ramp` already knows all important parameters for its own functions. - - """ - if self.order is not None: # Do nothing if order is None! - # Split off ramp parameters and fill cache: - x, ramp_params = np.split(x, [-self.n]) - self.param_cache = ramp_params.reshape((self.deg_of_freedom, self.count)) - return x - - @classmethod - def create_poly_mesh(cls, a, deg_of_freedom, dim_uv): - """Create a polynomial mesh for the ramp calculation for a specific degree of freedom. - - Parameters - ---------- - a : float - Grid spacing which should be used for the ramp. - deg_of_freedom : int - Current degree of freedom for which the mesh should be created. 0 corresponds to a - simple offset, 1 corresponds to a linear ramp in u-direction, 2 to a linear ramp in - v-direction and so on. - dim_uv : tuple (N=2) - Dimensions of the 2D mesh that should be created. - - Returns - ------- - result_mesh : :class:`~numpy.ndarray` (N=2) - Polynomial mesh that was created and can be used for further calculations. - - """ - # Determine if u-direction (u_or_v == 1) or v-direction (u_or_v == 0)! - u_or_v = (deg_of_freedom - 1) % 2 - # Determine polynomial order: - order = (deg_of_freedom + 1) // 2 - # Return polynomial mesh: - return (np.indices(dim_uv)[u_or_v] * a) ** order - - @classmethod - def create_ramp(cls, a, dim_uv, params): - """Class method to create an arbitrary polynomial ramp. - - Parameters - ---------- - a : float - Grid spacing which should be used for the ramp. - dim_uv : tuple (N=2) - Dimensions of the 2D mesh that should be created. - params : list - List of ramp parameters. The first entry corresponds to a simple offset, the second - and third correspond to a linear ramp in u- and v-direction, respectively and so on. - - Returns - ------- - phase_ramp : :class:`~pyramid.phasemap.PhaseMap` - The phase ramp as a :class:`~pyramid.phasemap.PhaseMap` object. - - """ - phase_ramp = np.zeros(dim_uv) - dof_list = range(len(params)) - for dof in dof_list: - phase_ramp += params[dof] * cls.create_poly_mesh(a, dof, dim_uv) - # Return the phase ramp as a PhaseMap with empty (!) mask: - return PhaseMap(a, phase_ramp, mask=np.zeros(dim_uv, dtype=np.bool)) +# -*- coding: utf-8 -*- +# Copyright 2014 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides the :class:`~.Ramp` class which implements polynomial phase ramps.""" + +import numpy as np + +from pyramid.phasemap import PhaseMap + +__all__ = ['Ramp'] + + +class Ramp(object): + """Class representing a polynomial phase ramp. + + Sometimes additional phase ramps occur in phase maps which do not stem from a magnetization + distribution inside the FOV. This class allows the construction (and via the derivative + functions also the reconstruction) of a polynomial ramp. This class is generally constructed + within the ForwardModel and can be retrieved as its attribute if ramp information should be + accessed. + + Attributes + ---------- + data_set : :class:`~dataset.DataSet` + :class:`~dataset.DataSet` object, which stores all required information calculation. + order : int or None (default) + Polynomial order of the additional phase ramp which will be added to the phase maps. + All ramp parameters have to be at the end of the input vector and are split automatically. + Default is None (no ramps are added). + deg_of_freedom : int + Number of degrees of freedom. This is calculated to ``1 + 2 * order``. There is just one + degree of freedom for a ramp of order zero (offset), every higher order contributes two + degrees of freedom. + param_cache : :class:`numpy.ndarray` (N=2) + Parameter cache which is used to store the polynomial coefficients. Higher coefficients + (one for each degree of freedom) are saved along the first axis, values for different + images along the second axis. + n : int + Size of the input space. Coincides with the numer of entries in `param_cache` and + calculates to ``deg_of_freedom * data_set.count``. + + Notes + ----- + After a reconstruction the relevant polynomial ramp information is stored in the + `param_cache`. If a phasemap with index `i` in the DataSet should be corrected use: + + .. code-block:: python + + phasemap -= ramp(i=0, dof_list=[0, 1, 2]) + + + The optional parameter `dof_list` can be used to specify a list of degrees of freedom which + should be used for the ramp (e.g. `[0]` will just apply the offset, `[0, 1, 2]` will apply + the offset and linear ramps in both directions). + + Fitting polynoms of higher orders than `order = 1` is possible but not recommended, because + features which stem from the magnetization could be covered by the polynom, decreasing the + phase contribution of the magnetization distribution, leading to a false retrieval. + + """ + + def __init__(self, data_set, order=None): + assert order is None or (isinstance(order, int) and order >= 0), \ + 'Order has to be None or a positive integer!' + self.order = order + self.a = data_set.a + self.count = data_set.count + self.dimensions = [projector.dim_uv for projector in data_set.projectors] + self.hook_points = data_set.hook_points + self.deg_of_freedom = (1 + 2 * self.order) if self.order is not None else 0 + self.param_cache = np.zeros((self.deg_of_freedom, self.count)) + self.n = self.deg_of_freedom * self.count # 0 if order is None + + def __call__(self, index, dof_list=None): + if self.order is None: # Do nothing if order is None! + return 0 + else: + if dof_list is None: # if no specific list is supplied! + dof_list = range(self.deg_of_freedom) # use all available degrees of freedom + dim_uv = self.dimensions[index] + phase_ramp = np.zeros(dim_uv) + # Iterate over all degrees of freedom: + for dof in dof_list: + # Add the contribution of the current degree of freedom: + phase_ramp += (self.param_cache[dof][index] * + self.create_poly_mesh(self.a, dof, dim_uv)) + return PhaseMap(self.a, phase_ramp, mask=np.zeros(dim_uv, dtype=np.bool)) + + def jac_dot(self, index): + """Calculate the product of the Jacobi matrix . + + Parameters + ---------- + index : int + Index of the phasemap from the `dataset` for which the phase ramp is calculated. + + Returns + ------- + result_vector : :class:`~numpy.ndarray` (N=1) + Product of the Jacobi matrix (which is not explicitely calculated) with the input + `vector`. Just the ramp contribution is calculated! + + """ + if self.order is None: # Do nothing if order is None! + return 0 + else: + dim_uv = self.dimensions[index] + phase_ramp = np.zeros(dim_uv) + # Iterate over all degrees of freedom: + for dof in range(self.deg_of_freedom): + # Add the contribution of the current degree of freedom: + phase_ramp += (self.param_cache[dof][index] * + self.create_poly_mesh(self.a, dof, dim_uv)) + return np.ravel(phase_ramp) + + def jac_T_dot(self, vector): + """'Calculate the transposed ramp parameters from a given `vector`. + + Parameters + ---------- + vector : :class:`~numpy.ndarray` (N=1) + Vectorized form of all 2D phase maps one after another in one vector. + + Returns + ------- + result_vector : :class:`~numpy.ndarray` (N=1) + Transposed ramp parameters. + + """ + result = [] + hp = self.hook_points + # Iterate over all degrees of freedom: + for dof in range(self.deg_of_freedom): + # Iterate over all projectors: + for i, dim_uv in enumerate(self.dimensions): + sub_vec = vector[hp[i]:hp[i + 1]] + poly_mesh = self.create_poly_mesh(self.a, dof, dim_uv) + # Transposed ramp parameters: summed product of the vector with the poly-mesh: + result.append(np.sum(sub_vec * np.ravel(poly_mesh))) + return result + + def extract_ramp_params(self, x): + """Extract the ramp parameters of an input vector and return the rest. + + Parameters + ---------- + x : :class:`~numpy.ndarray` (N=1) + Input vector which consists of the vectorised magnetization distribution and the ramp + parameters at the end which will be extracted. + + Returns + ------- + result_vector : :class:`~numpy.ndarray` (N=1) + Inpput vector without the extracted ramp parameters. + + Notes + ----- + This method should always be used before a vector `x` is processed if it is known that + ramp parameters are present so that other functions do not have to bother with them + and the :class:`.~ramp` already knows all important parameters for its own functions. + + """ + if self.order is not None: # Do nothing if order is None! + # Split off ramp parameters and fill cache: + x, ramp_params = np.split(x, [-self.n]) + self.param_cache = ramp_params.reshape((self.deg_of_freedom, self.count)) + return x + + @classmethod + def create_poly_mesh(cls, a, deg_of_freedom, dim_uv): + """Create a polynomial mesh for the ramp calculation for a specific degree of freedom. + + Parameters + ---------- + a : float + Grid spacing which should be used for the ramp. + deg_of_freedom : int + Current degree of freedom for which the mesh should be created. 0 corresponds to a + simple offset, 1 corresponds to a linear ramp in u-direction, 2 to a linear ramp in + v-direction and so on. + dim_uv : tuple (N=2) + Dimensions of the 2D mesh that should be created. + + Returns + ------- + result_mesh : :class:`~numpy.ndarray` (N=2) + Polynomial mesh that was created and can be used for further calculations. + + """ + # Determine if u-direction (u_or_v == 1) or v-direction (u_or_v == 0)! + u_or_v = (deg_of_freedom - 1) % 2 + # Determine polynomial order: + order = (deg_of_freedom + 1) // 2 + # Return polynomial mesh: + return (np.indices(dim_uv)[u_or_v] * a) ** order + + @classmethod + def create_ramp(cls, a, dim_uv, params): + """Class method to create an arbitrary polynomial ramp. + + Parameters + ---------- + a : float + Grid spacing which should be used for the ramp. + dim_uv : tuple (N=2) + Dimensions of the 2D mesh that should be created. + params : list + List of ramp parameters. The first entry corresponds to a simple offset, the second + and third correspond to a linear ramp in u- and v-direction, respectively and so on. + + Returns + ------- + phase_ramp : :class:`~pyramid.phasemap.PhaseMap` + The phase ramp as a :class:`~pyramid.phasemap.PhaseMap` object. + + """ + phase_ramp = np.zeros(dim_uv) + dof_list = range(len(params)) + for dof in dof_list: + phase_ramp += params[dof] * cls.create_poly_mesh(a, dof, dim_uv) + # Return the phase ramp as a PhaseMap with empty (!) mask: + return PhaseMap(a, phase_ramp, mask=np.zeros(dim_uv, dtype=np.bool)) diff --git a/pyramid/reconstruction.py b/pyramid/reconstruction.py index 572f07af89f9fcc75b91a2b5bde04f3c0136773c..cc313727c1bbfa1bf0c6fe69de7a7ddbdec23d2d 100644 --- a/pyramid/reconstruction.py +++ b/pyramid/reconstruction.py @@ -1,184 +1,184 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Reconstruct magnetic distributions from given phasemaps. - -This module reconstructs 3-dimensional magnetic distributions (as -:class:`~pyramid.magdata.VectorData` objects) from a given set of phase maps (represented by -:class:`~pyramid.phasemap.PhaseMap` objects) by using several model based reconstruction algorithms - which use the forward model provided by :mod:`~pyramid.projector` and :mod:`~pyramid.phasemapper` - and a priori knowledge of the distribution. - -""" - -import logging - -import numpy as np - -from pyramid.fielddata import VectorData - -__all__ = ['optimize_linear', 'optimize_nonlin', 'optimize_splitbregman'] -_log = logging.getLogger(__name__) - - -def optimize_linear(costfunction, mag_0=None, ramp_0=None, max_iter=None, verbose=False): - """Reconstruct a three-dimensional magnetic distribution from given phase maps via the - conjugate gradient optimizaion method :func:`~.scipy.sparse.linalg.cg`. - Blazingly fast for l2-based cost functions. - - Parameters - ---------- - costfunction : :class:`~.Costfunction` - A :class:`~.Costfunction` object which implements a specified forward model and - regularisator which is minimized in the optimization process. - mag_0: :class:`~.VectorData` - The starting magnetisation distribution used for the reconstruction. A zero vector will be - used if no VectorData object is specified. - mag_0: :class:`~.Ramp` - The starting ramp for the reconstruction. A zero vector will be - used if no Ramp object is specified. - max_iter : int, optional - The maximum number of iterations for the opimization. - verbose: bool, optional - If set to True, information like a progressbar is displayed during reconstruction. - The default is False. - - Returns - ------- - magdata : :class:`~pyramid.fielddata.VectorData` - The reconstructed magnetic distribution as a :class:`~.VectorData` object. - - """ - import jutil.cg as jcg - from jutil.taketime import TakeTime - _log.debug('Calling optimize_linear') - _log.info('Cost before optimization: {:.3e}'.format(costfunction(np.zeros(costfunction.n)))) - data_set = costfunction.fwd_model.data_set - # Get starting distribution vector x_0: - x_0 = np.empty(costfunction.n) - if mag_0 is not None: - costfunction.fwd_model.magdata = mag_0 - x_0[:data_set.n] = costfunction.fwd_model.magdata.get_vector(mask=data_set.mask) - if ramp_0 is not None: - ramp_vec = ramp_0.param_cache.ravel() - else: - ramp_vec = np.zeros_like(costfunction.fwd_model.ramp.n) - x_0[data_set.n:] = ramp_vec - # Minimize: - with TakeTime('reconstruction time'): - x_opt = jcg.conj_grad_minimize(costfunction, x_0=x_0, max_iter=max_iter, verbose=verbose).x - _log.info('Cost after optimization: {:.3e}'.format(costfunction(x_opt))) - # Cut ramp parameters if necessary (this also saves the final parameters in the ramp class!): - x_opt = costfunction.fwd_model.ramp.extract_ramp_params(x_opt) - # Create and return fitting VectorData object: - mag_opt = VectorData(data_set.a, np.zeros((3,) + data_set.dim)) - mag_opt.set_vector(x_opt, data_set.mask) - return mag_opt - - -def optimize_nonlin(costfunction, first_guess=None): - """Reconstruct a three-dimensional magnetic distribution from given phase maps via - steepest descent method. This is slow, but works best for non l2-regularisators. - - - Parameters - ---------- - costfunction : :class:`~.Costfunction` - A :class:`~.Costfunction` object which implements a specified forward model and - regularisator which is minimized in the optimization process. - first_guess : :class:`~pyramid.fielddata.VectorData` - magnetization to start the non-linear iteration with. - - Returns - ------- - magdata : :class:`~pyramid.fielddata.VectorData` - The reconstructed magnetic distribution as a :class:`~.VectorData` object. - - """ - import jutil.minimizer as jmin - import jutil.norms as jnorms - _log.debug('Calling optimize_nonlin') - data_set = costfunction.fwd_model.data_set - if first_guess is None: - first_guess = VectorData(data_set.a, np.zeros((3,) + data_set.dim)) - - x_0 = first_guess.get_vector(data_set.mask) - assert len(x_0) == costfunction.n, (len(x_0), costfunction.m, costfunction.n) - - p = costfunction.regularisator.p - q = 1. / (1. - (1. / p)) - lq = jnorms.LPPow(q, 1e-20) - - def _preconditioner(_, direc): - direc_p = direc / abs(direc).max() - direc_p = 10 * (1. / q) * lq.jac(direc_p) - return direc_p - - # This Method is semi-best for Lp type problems. Takes forever, though - _log.info('Cost before optimization: {}'.format(costfunction(np.zeros(costfunction.n)))) - result = jmin.minimize( - costfunction, x_0, - method="SteepestDescent", - options={"preconditioner": _preconditioner}, - tol={"max_iteration": 10000}) - x_opt = result.x - _log.info('Cost after optimization: {}'.format(costfunction(x_opt))) - mag_opt = VectorData(data_set.a, np.zeros((3,) + data_set.dim)) - mag_opt.set_vector(x_opt, data_set.mask) - return mag_opt - - -def optimize_splitbregman(costfunction, weight, lam, mu): - """ - Reconstructs magnet distribution from phase image measurements using a split bregman - algorithm with a dedicated TV-l1 norm. Very dedicated, frickle, brittle, and difficult - to get to work, but fastest option available if it works. - - Seems to work for some 2D examples with weight=lam=1 and mu in [1, .., 1e4]. - - Parameters - ---------- - costfunction : :class:`~.Costfunction` - A :class:`~.Costfunction` object which implements a specified forward model and - regularisator which is minimized in the optimization process. - weight : float - Obscure split bregman parameter - lam : float - Cryptic split bregman parameter - mu : float - flabberghasting split bregman paramter - - Returns - ------- - magdata : :class:`~pyramid.fielddata.VectorData` - The reconstructed magnetic distribution as a :class:`~.VectorData` object. - - """ - import jutil.splitbregman as jsb - import jutil.operator as joperator - import jutil.diff as jdiff - _log.debug('Calling optimize_splitbregman') - - # regularisator is actually not necessary, but this makes the cost - # function to that which is supposedly optimized by split bregman. - # Thus cost can be used to verify convergence - fwd_model = costfunction.fwd_model - data_set = fwd_model.data_set - - A = joperator.Function( - (costfunction.m, costfunction.n), - lambda x: fwd_model.jac_dot(None, x), - FT=lambda x: fwd_model.jac_T_dot(None, x)) - D = joperator.VStack([ - jdiff.get_diff_operator(data_set.mask, 0, 3), - jdiff.get_diff_operator(data_set.mask, 1, 3)]) - y = np.asarray(costfunction.y, dtype=np.double) - - x_opt = jsb.split_bregman_2d( - A, D, y, - weight=weight, mu=mu, lambd=lam, max_iter=1000) - - mag_opt = VectorData(data_set.a, np.zeros((3,) + data_set.dim)) - mag_opt.set_vector(x_opt, data_set.mask) - return mag_opt +# -*- coding: utf-8 -*- +# Copyright 2014 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Reconstruct magnetic distributions from given phasemaps. + +This module reconstructs 3-dimensional magnetic distributions (as +:class:`~pyramid.magdata.VectorData` objects) from a given set of phase maps (represented by +:class:`~pyramid.phasemap.PhaseMap` objects) by using several model based reconstruction algorithms + which use the forward model provided by :mod:`~pyramid.projector` and :mod:`~pyramid.phasemapper` + and a priori knowledge of the distribution. + +""" + +import logging + +import numpy as np + +from pyramid.fielddata import VectorData + +__all__ = ['optimize_linear', 'optimize_nonlin', 'optimize_splitbregman'] +_log = logging.getLogger(__name__) + + +def optimize_linear(costfunction, mag_0=None, ramp_0=None, max_iter=None, verbose=False): + """Reconstruct a three-dimensional magnetic distribution from given phase maps via the + conjugate gradient optimizaion method :func:`~.scipy.sparse.linalg.cg`. + Blazingly fast for l2-based cost functions. + + Parameters + ---------- + costfunction : :class:`~.Costfunction` + A :class:`~.Costfunction` object which implements a specified forward model and + regularisator which is minimized in the optimization process. + mag_0: :class:`~.VectorData` + The starting magnetisation distribution used for the reconstruction. A zero vector will be + used if no VectorData object is specified. + mag_0: :class:`~.Ramp` + The starting ramp for the reconstruction. A zero vector will be + used if no Ramp object is specified. + max_iter : int, optional + The maximum number of iterations for the opimization. + verbose: bool, optional + If set to True, information like a progressbar is displayed during reconstruction. + The default is False. + + Returns + ------- + magdata : :class:`~pyramid.fielddata.VectorData` + The reconstructed magnetic distribution as a :class:`~.VectorData` object. + + """ + import jutil.cg as jcg + from jutil.taketime import TakeTime + _log.debug('Calling optimize_linear') + _log.info('Cost before optimization: {:.3e}'.format(costfunction(np.zeros(costfunction.n)))) + data_set = costfunction.fwd_model.data_set + # Get starting distribution vector x_0: + x_0 = np.empty(costfunction.n) + if mag_0 is not None: + costfunction.fwd_model.magdata = mag_0 + x_0[:data_set.n] = costfunction.fwd_model.magdata.get_vector(mask=data_set.mask) + if ramp_0 is not None: + ramp_vec = ramp_0.param_cache.ravel() + else: + ramp_vec = np.zeros_like(costfunction.fwd_model.ramp.n) + x_0[data_set.n:] = ramp_vec + # Minimize: + with TakeTime('reconstruction time'): + x_opt = jcg.conj_grad_minimize(costfunction, x_0=x_0, max_iter=max_iter, verbose=verbose).x + _log.info('Cost after optimization: {:.3e}'.format(costfunction(x_opt))) + # Cut ramp parameters if necessary (this also saves the final parameters in the ramp class!): + x_opt = costfunction.fwd_model.ramp.extract_ramp_params(x_opt) + # Create and return fitting VectorData object: + mag_opt = VectorData(data_set.a, np.zeros((3,) + data_set.dim)) + mag_opt.set_vector(x_opt, data_set.mask) + return mag_opt + + +def optimize_nonlin(costfunction, first_guess=None): + """Reconstruct a three-dimensional magnetic distribution from given phase maps via + steepest descent method. This is slow, but works best for non l2-regularisators. + + + Parameters + ---------- + costfunction : :class:`~.Costfunction` + A :class:`~.Costfunction` object which implements a specified forward model and + regularisator which is minimized in the optimization process. + first_guess : :class:`~pyramid.fielddata.VectorData` + magnetization to start the non-linear iteration with. + + Returns + ------- + magdata : :class:`~pyramid.fielddata.VectorData` + The reconstructed magnetic distribution as a :class:`~.VectorData` object. + + """ + import jutil.minimizer as jmin + import jutil.norms as jnorms + _log.debug('Calling optimize_nonlin') + data_set = costfunction.fwd_model.data_set + if first_guess is None: + first_guess = VectorData(data_set.a, np.zeros((3,) + data_set.dim)) + + x_0 = first_guess.get_vector(data_set.mask) + assert len(x_0) == costfunction.n, (len(x_0), costfunction.m, costfunction.n) + + p = costfunction.regularisator.p + q = 1. / (1. - (1. / p)) + lq = jnorms.LPPow(q, 1e-20) + + def _preconditioner(_, direc): + direc_p = direc / abs(direc).max() + direc_p = 10 * (1. / q) * lq.jac(direc_p) + return direc_p + + # This Method is semi-best for Lp type problems. Takes forever, though + _log.info('Cost before optimization: {}'.format(costfunction(np.zeros(costfunction.n)))) + result = jmin.minimize( + costfunction, x_0, + method="SteepestDescent", + options={"preconditioner": _preconditioner}, + tol={"max_iteration": 10000}) + x_opt = result.x + _log.info('Cost after optimization: {}'.format(costfunction(x_opt))) + mag_opt = VectorData(data_set.a, np.zeros((3,) + data_set.dim)) + mag_opt.set_vector(x_opt, data_set.mask) + return mag_opt + + +def optimize_splitbregman(costfunction, weight, lam, mu): + """ + Reconstructs magnet distribution from phase image measurements using a split bregman + algorithm with a dedicated TV-l1 norm. Very dedicated, frickle, brittle, and difficult + to get to work, but fastest option available if it works. + + Seems to work for some 2D examples with weight=lam=1 and mu in [1, .., 1e4]. + + Parameters + ---------- + costfunction : :class:`~.Costfunction` + A :class:`~.Costfunction` object which implements a specified forward model and + regularisator which is minimized in the optimization process. + weight : float + Obscure split bregman parameter + lam : float + Cryptic split bregman parameter + mu : float + flabberghasting split bregman paramter + + Returns + ------- + magdata : :class:`~pyramid.fielddata.VectorData` + The reconstructed magnetic distribution as a :class:`~.VectorData` object. + + """ + import jutil.splitbregman as jsb + import jutil.operator as joperator + import jutil.diff as jdiff + _log.debug('Calling optimize_splitbregman') + + # regularisator is actually not necessary, but this makes the cost + # function to that which is supposedly optimized by split bregman. + # Thus cost can be used to verify convergence + fwd_model = costfunction.fwd_model + data_set = fwd_model.data_set + + A = joperator.Function( + (costfunction.m, costfunction.n), + lambda x: fwd_model.jac_dot(None, x), + FT=lambda x: fwd_model.jac_T_dot(None, x)) + D = joperator.VStack([ + jdiff.get_diff_operator(data_set.mask, 0, 3), + jdiff.get_diff_operator(data_set.mask, 1, 3)]) + y = np.asarray(costfunction.y, dtype=np.double) + + x_opt = jsb.split_bregman_2d( + A, D, y, + weight=weight, mu=mu, lambd=lam, max_iter=1000) + + mag_opt = VectorData(data_set.a, np.zeros((3,) + data_set.dim)) + mag_opt.set_vector(x_opt, data_set.mask) + return mag_opt diff --git a/pyramid/regularisator.py b/pyramid/regularisator.py index 335cd36659c902769d4166f28f369606b13d05fc..dc351a138d38af0b45baa5ce4f52ac811f1ade7d 100644 --- a/pyramid/regularisator.py +++ b/pyramid/regularisator.py @@ -1,378 +1,378 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""This module provides the :class:`~.Regularisator` class which represents a regularisation term -which adds additional constraints to a costfunction to minimize.""" - -import abc -import logging - -import numpy as np -from scipy import sparse - -import jutil.diff as jdiff -import jutil.norms as jnorm - -__all__ = ['NoneRegularisator', 'ZeroOrderRegularisator', 'FirstOrderRegularisator', - 'ComboRegularisator'] - - -class Regularisator(object, metaclass=abc.ABCMeta): - """Class for providing a regularisation term which implements additional constraints. - - Represents a certain constraint for the 3D magnetization distribution whose cost is to minimize - in addition to the derivation from the 2D phase maps. Important is the used `norm` and the - regularisation parameter `lam` (lambda) which determines the weighting between the two cost - parts (measurements and regularisation). Additional parameters at the end of the input - vector, which are not relevant for the regularisation can be discarded by specifying the - number in `add_params`. - - Attributes - ---------- - norm : :class:`~jutil.norm.WeightedNorm` - Norm, which is used to determine the cost of the regularisation term. - lam : float - Regularisation parameter determining the weighting between measurements and regularisation. - add_params : int - Number of additional parameters which are not used in the regularisation. Used to cut - the input vector into the appropriate size. - - """ - - _log = logging.getLogger(__name__ + '.Regularisator') - - @abc.abstractmethod - def __init__(self, norm, lam, add_params=0): - self._log.debug('Calling __init__') - self.norm = norm - self.lam = lam - self.add_params = add_params - if self.add_params > 0: - self.slice = slice(-add_params) - else: - self.slice = slice(None) - self._log.debug('Created ' + str(self)) - - def __call__(self, x): - self._log.debug('Calling __call__') - return self.lam * self.norm(x[self.slice]) - - def __repr__(self): - self._log.debug('Calling __repr__') - return '%s(norm=%r, lam=%r, add_params=%r)' % (self.__class__, self.norm, self.lam, - self.add_params) - - def __str__(self): - self._log.debug('Calling __str__') - return 'Regularisator(norm=%s, lam=%s, add_params=%s)' % (self.norm, self.lam, - self.add_params) - - def jac(self, x): - """Calculate the derivative of the regularisation term for a given magnetic distribution. - - Parameters - ---------- - x: :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution, for which the Jacobi vector is calculated. - - Returns - ------- - result : :class:`~numpy.ndarray` (N=1) - Jacobi vector which represents the cost derivative of all voxels of the magnetization. - - """ - result = np.zeros_like(x) - result[self.slice] = self.lam * self.norm.jac(x[self.slice]) - return result - - def hess_dot(self, x, vector): - """Calculate the product of a `vector` with the Hessian matrix of the regularisation term. - - Parameters - ---------- - x : :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution at which the Hessian is calculated. The Hessian - is constant in this case, thus `x` can be set to None (it is not used int the - computation). It is implemented for the case that in the future nonlinear problems - have to be solved. - vector : :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution which is multiplied by the Hessian. - - Returns - ------- - result : :class:`~numpy.ndarray` (N=1) - Product of the input `vector` with the Hessian matrix. - - """ - result = np.zeros_like(vector) - result[self.slice] = self.lam * self.norm.hess_dot(x, vector[self.slice]) - return result - - def hess_diag(self, x): - """ Return the diagonal of the Hessian. - - Parameters - ---------- - x : :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution at which the Hessian is calculated. The Hessian - is constant in this case, thus `x` can be set to None (it is not used in the - computation). It is implemented for the case that in the future nonlinear problems - have to be solved. - - Returns - ------- - result : :class:`~numpy.ndarray` (N=1) - Diagonal of the Hessian matrix. - - """ - self._log.debug('Calling hess_diag') - result = np.zeros_like(x) - result[self.slice] = self.lam * self.norm.hess_diag(x[self.slice]) - return result - - -class ComboRegularisator(Regularisator): - """Class for providing a regularisation term which combines several regularisators. - - If more than one regularisation should be utilized, this class can be use. It is given a list - of :class:`~.Regularisator` objects. The input will be forwarded to each of them and the - results are summed up and returned. - - Attributes - ---------- - reg_list: :class:`~.Regularisator` - A list of regularisator objects to whom the input is passed on. - - """ - - def __init__(self, reg_list): - self._log.debug('Calling __init__') - self.reg_list = reg_list - super().__init__(norm=None, lam=None) - self._log.debug('Created ' + str(self)) - - def __call__(self, x): - self._log.debug('Calling __call__') - return np.sum([self.reg_list[i](x) for i in range(len(self.reg_list))], axis=0) - - def __repr__(self): - self._log.debug('Calling __repr__') - return '%s(reg_list=%r)' % (self.__class__, self.reg_list) - - def __str__(self): - self._log.debug('Calling __str__') - return 'ComboRegularisator(reg_list=%s)' % self.reg_list - - def jac(self, x): - """Calculate the derivative of the regularisation term for a given magnetic distribution. - - Parameters - ---------- - x: :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution, for which the Jacobi vector is calculated. - - Returns - ------- - result : :class:`~numpy.ndarray` (N=1) - Jacobi vector which represents the cost derivative of all voxels of the magnetization. - - """ - return np.sum([self.reg_list[i].jac(x) for i in range(len(self.reg_list))], axis=0) - - def hess_dot(self, x, vector): - """Calculate the product of a `vector` with the Hessian matrix of the regularisation term. - - Parameters - ---------- - x : :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution at which the Hessian is calculated. The Hessian - is constant in this case, thus `x` can be set to None (it is not used int the - computation). It is implemented for the case that in the future nonlinear problems - have to be solved. - vector : :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution which is multiplied by the Hessian. - - Returns - ------- - result : :class:`~numpy.ndarray` (N=1) - Product of the input `vector` with the Hessian matrix. - - """ - return np.sum([self.reg_list[i].hess_dot(x, vector) for i in range(len(self.reg_list))], - axis=0) - - def hess_diag(self, x): - """ Return the diagonal of the Hessian. - - Parameters - ---------- - x : :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution at which the Hessian is calculated. The Hessian - is constant in this case, thus `x` can be set to None (it is not used in the - computation). It is implemented for the case that in the future nonlinear problems - have to be solved. - - Returns - ------- - result : :class:`~numpy.ndarray` (N=1) - Diagonal of the Hessian matrix. - - """ - self._log.debug('Calling hess_diag') - return np.sum([self.reg_list[i].hess_diag(x) for i in range(len(self.reg_list))], axis=0) - - -class NoneRegularisator(Regularisator): - """Placeholder class if no regularization is used. - - This class is instantiated in the :class:`~pyramid.costfunction.Costfunction`, which means - no regularisation is used. All associated functions return appropriate zero-values. - - Attributes - ---------- - norm: None - No regularization is used, thus also no norm. - lam: 0 - Not used. - - """ - - _log = logging.getLogger(__name__ + '.NoneRegularisator') - - def __init__(self): - self._log.debug('Calling __init__') - self.norm = None - self.lam = 0 - self.add_params = None - super().__init__(norm=None, lam=None) - self._log.debug('Created ' + str(self)) - - def __call__(self, x): - self._log.debug('Calling __call__') - return 0 - - def jac(self, x): - """Calculate the derivative of the regularisation term for a given magnetic distribution. - - Parameters - ---------- - x: :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution, for which the Jacobi vector is calculated. - - Returns - ------- - result : :class:`~numpy.ndarray` (N=1) - Jacobi vector which represents the cost derivative of all voxels of the magnetization. - - """ - return np.zeros_like(x) - - def hess_dot(self, x, vector): - """Calculate the product of a `vector` with the Hessian matrix of the regularisation term. - - Parameters - ---------- - x : :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution at which the Hessian is calculated. The Hessian - is constant in this case, thus `x` can be set to None (it is not used in the - computation). It is implemented for the case that in the future nonlinear problems - have to be solved. - vector : a :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution which is multiplied by the Hessian. - - Returns - ------- - result : :class:`~numpy.ndarray` (N=1) - Product of the input `vector` with the Hessian matrix of the costfunction. - - """ - return np.zeros_like(vector) - - def hess_diag(self, x): - """ Return the diagonal of the Hessian. - - Parameters - ---------- - x : :class:`~numpy.ndarray` (N=1) - Vectorized magnetization distribution at which the Hessian is calculated. The Hessian - is constant in this case, thus `x` can be set to None (it is not used int the - computation). It is implemented for the case that in the future nonlinear problems - have to be solved. - - Returns - ------- - result : :class:`~numpy.ndarray` (N=1) - Diagonal of the Hessian matrix. - - """ - self._log.debug('Calling hess_diag') - return np.zeros_like(x) - - -class ZeroOrderRegularisator(Regularisator): - """Class for providing a regularisation term which implements Lp norm minimization. - - The constraint this class represents is the minimization of the Lp norm for the 3D - magnetization distribution. Important is the regularisation parameter `lam` (lambda) which - determines the weighting between the two cost parts (measurements and regularisation). - - Attributes - ---------- - lam: float - Regularisation parameter determining the weighting between measurements and regularisation. - p: int, optional - Order of the norm (default: 2, which means a standard L2-norm). - add_params : int - Number of additional parameters which are not used in the regularisation. Used to cut - the input vector into the appropriate size. - - """ - - _log = logging.getLogger(__name__ + '.ZeroOrderRegularisator') - - def __init__(self, _=None, lam=1E-4, p=2, add_params=0): - self._log.debug('Calling __init__') - self.p = p - if p == 2: - norm = jnorm.L2Square() - else: - norm = jnorm.LPPow(p, 1e-12) - super().__init__(norm, lam, add_params) - self._log.debug('Created ' + str(self)) - - -class FirstOrderRegularisator(Regularisator): - """Class for providing a regularisation term which implements derivation minimization. - - The constraint this class represents is the minimization of the first order derivative of the - 3D magnetization distribution using a Lp norm. Important is the regularisation parameter `lam` - (lambda) which determines the weighting between the two cost parts (measurements and - regularisation). - - Attributes - ---------- - mask: :class:`~numpy.ndarray` (N=3) - A boolean mask which defines the magnetized volume in 3D. - lam: float - Regularisation parameter determining the weighting between measurements and regularisation. - p: int, optional - Order of the norm (default: 2, which means a standard L2-norm). - add_params : int - Number of additional parameters which are not used in the regularisation. Used to cut - the input vector into the appropriate size. - - """ - - def __init__(self, mask, lam=1E-4, p=2, add_params=0): - self.p = p - D0 = jdiff.get_diff_operator(mask, 0, 3) - D1 = jdiff.get_diff_operator(mask, 1, 3) - D2 = jdiff.get_diff_operator(mask, 2, 3) - D = sparse.vstack([D0, D1, D2]) - if p == 2: - norm = jnorm.WeightedL2Square(D) - else: - norm = jnorm.WeightedTV(jnorm.LPPow(p, 1e-12), D, [D0.shape[0], D.shape[0]]) - super().__init__(norm, lam, add_params) - self._log.debug('Created ' + str(self)) +# -*- coding: utf-8 -*- +# Copyright 2014 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides the :class:`~.Regularisator` class which represents a regularisation term +which adds additional constraints to a costfunction to minimize.""" + +import abc +import logging + +import numpy as np +from scipy import sparse + +import jutil.diff as jdiff +import jutil.norms as jnorm + +__all__ = ['NoneRegularisator', 'ZeroOrderRegularisator', 'FirstOrderRegularisator', + 'ComboRegularisator'] + + +class Regularisator(object, metaclass=abc.ABCMeta): + """Class for providing a regularisation term which implements additional constraints. + + Represents a certain constraint for the 3D magnetization distribution whose cost is to minimize + in addition to the derivation from the 2D phase maps. Important is the used `norm` and the + regularisation parameter `lam` (lambda) which determines the weighting between the two cost + parts (measurements and regularisation). Additional parameters at the end of the input + vector, which are not relevant for the regularisation can be discarded by specifying the + number in `add_params`. + + Attributes + ---------- + norm : :class:`~jutil.norm.WeightedNorm` + Norm, which is used to determine the cost of the regularisation term. + lam : float + Regularisation parameter determining the weighting between measurements and regularisation. + add_params : int + Number of additional parameters which are not used in the regularisation. Used to cut + the input vector into the appropriate size. + + """ + + _log = logging.getLogger(__name__ + '.Regularisator') + + @abc.abstractmethod + def __init__(self, norm, lam, add_params=0): + self._log.debug('Calling __init__') + self.norm = norm + self.lam = lam + self.add_params = add_params + if self.add_params > 0: + self.slice = slice(-add_params) + else: + self.slice = slice(None) + self._log.debug('Created ' + str(self)) + + def __call__(self, x): + self._log.debug('Calling __call__') + return self.lam * self.norm(x[self.slice]) + + def __repr__(self): + self._log.debug('Calling __repr__') + return '%s(norm=%r, lam=%r, add_params=%r)' % (self.__class__, self.norm, self.lam, + self.add_params) + + def __str__(self): + self._log.debug('Calling __str__') + return 'Regularisator(norm=%s, lam=%s, add_params=%s)' % (self.norm, self.lam, + self.add_params) + + def jac(self, x): + """Calculate the derivative of the regularisation term for a given magnetic distribution. + + Parameters + ---------- + x: :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution, for which the Jacobi vector is calculated. + + Returns + ------- + result : :class:`~numpy.ndarray` (N=1) + Jacobi vector which represents the cost derivative of all voxels of the magnetization. + + """ + result = np.zeros_like(x) + result[self.slice] = self.lam * self.norm.jac(x[self.slice]) + return result + + def hess_dot(self, x, vector): + """Calculate the product of a `vector` with the Hessian matrix of the regularisation term. + + Parameters + ---------- + x : :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution at which the Hessian is calculated. The Hessian + is constant in this case, thus `x` can be set to None (it is not used int the + computation). It is implemented for the case that in the future nonlinear problems + have to be solved. + vector : :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution which is multiplied by the Hessian. + + Returns + ------- + result : :class:`~numpy.ndarray` (N=1) + Product of the input `vector` with the Hessian matrix. + + """ + result = np.zeros_like(vector) + result[self.slice] = self.lam * self.norm.hess_dot(x, vector[self.slice]) + return result + + def hess_diag(self, x): + """ Return the diagonal of the Hessian. + + Parameters + ---------- + x : :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution at which the Hessian is calculated. The Hessian + is constant in this case, thus `x` can be set to None (it is not used in the + computation). It is implemented for the case that in the future nonlinear problems + have to be solved. + + Returns + ------- + result : :class:`~numpy.ndarray` (N=1) + Diagonal of the Hessian matrix. + + """ + self._log.debug('Calling hess_diag') + result = np.zeros_like(x) + result[self.slice] = self.lam * self.norm.hess_diag(x[self.slice]) + return result + + +class ComboRegularisator(Regularisator): + """Class for providing a regularisation term which combines several regularisators. + + If more than one regularisation should be utilized, this class can be use. It is given a list + of :class:`~.Regularisator` objects. The input will be forwarded to each of them and the + results are summed up and returned. + + Attributes + ---------- + reg_list: :class:`~.Regularisator` + A list of regularisator objects to whom the input is passed on. + + """ + + def __init__(self, reg_list): + self._log.debug('Calling __init__') + self.reg_list = reg_list + super().__init__(norm=None, lam=None) + self._log.debug('Created ' + str(self)) + + def __call__(self, x): + self._log.debug('Calling __call__') + return np.sum([self.reg_list[i](x) for i in range(len(self.reg_list))], axis=0) + + def __repr__(self): + self._log.debug('Calling __repr__') + return '%s(reg_list=%r)' % (self.__class__, self.reg_list) + + def __str__(self): + self._log.debug('Calling __str__') + return 'ComboRegularisator(reg_list=%s)' % self.reg_list + + def jac(self, x): + """Calculate the derivative of the regularisation term for a given magnetic distribution. + + Parameters + ---------- + x: :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution, for which the Jacobi vector is calculated. + + Returns + ------- + result : :class:`~numpy.ndarray` (N=1) + Jacobi vector which represents the cost derivative of all voxels of the magnetization. + + """ + return np.sum([self.reg_list[i].jac(x) for i in range(len(self.reg_list))], axis=0) + + def hess_dot(self, x, vector): + """Calculate the product of a `vector` with the Hessian matrix of the regularisation term. + + Parameters + ---------- + x : :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution at which the Hessian is calculated. The Hessian + is constant in this case, thus `x` can be set to None (it is not used int the + computation). It is implemented for the case that in the future nonlinear problems + have to be solved. + vector : :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution which is multiplied by the Hessian. + + Returns + ------- + result : :class:`~numpy.ndarray` (N=1) + Product of the input `vector` with the Hessian matrix. + + """ + return np.sum([self.reg_list[i].hess_dot(x, vector) for i in range(len(self.reg_list))], + axis=0) + + def hess_diag(self, x): + """ Return the diagonal of the Hessian. + + Parameters + ---------- + x : :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution at which the Hessian is calculated. The Hessian + is constant in this case, thus `x` can be set to None (it is not used in the + computation). It is implemented for the case that in the future nonlinear problems + have to be solved. + + Returns + ------- + result : :class:`~numpy.ndarray` (N=1) + Diagonal of the Hessian matrix. + + """ + self._log.debug('Calling hess_diag') + return np.sum([self.reg_list[i].hess_diag(x) for i in range(len(self.reg_list))], axis=0) + + +class NoneRegularisator(Regularisator): + """Placeholder class if no regularization is used. + + This class is instantiated in the :class:`~pyramid.costfunction.Costfunction`, which means + no regularisation is used. All associated functions return appropriate zero-values. + + Attributes + ---------- + norm: None + No regularization is used, thus also no norm. + lam: 0 + Not used. + + """ + + _log = logging.getLogger(__name__ + '.NoneRegularisator') + + def __init__(self): + self._log.debug('Calling __init__') + self.norm = None + self.lam = 0 + self.add_params = None + super().__init__(norm=None, lam=None) + self._log.debug('Created ' + str(self)) + + def __call__(self, x): + self._log.debug('Calling __call__') + return 0 + + def jac(self, x): + """Calculate the derivative of the regularisation term for a given magnetic distribution. + + Parameters + ---------- + x: :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution, for which the Jacobi vector is calculated. + + Returns + ------- + result : :class:`~numpy.ndarray` (N=1) + Jacobi vector which represents the cost derivative of all voxels of the magnetization. + + """ + return np.zeros_like(x) + + def hess_dot(self, x, vector): + """Calculate the product of a `vector` with the Hessian matrix of the regularisation term. + + Parameters + ---------- + x : :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution at which the Hessian is calculated. The Hessian + is constant in this case, thus `x` can be set to None (it is not used in the + computation). It is implemented for the case that in the future nonlinear problems + have to be solved. + vector : a :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution which is multiplied by the Hessian. + + Returns + ------- + result : :class:`~numpy.ndarray` (N=1) + Product of the input `vector` with the Hessian matrix of the costfunction. + + """ + return np.zeros_like(vector) + + def hess_diag(self, x): + """ Return the diagonal of the Hessian. + + Parameters + ---------- + x : :class:`~numpy.ndarray` (N=1) + Vectorized magnetization distribution at which the Hessian is calculated. The Hessian + is constant in this case, thus `x` can be set to None (it is not used int the + computation). It is implemented for the case that in the future nonlinear problems + have to be solved. + + Returns + ------- + result : :class:`~numpy.ndarray` (N=1) + Diagonal of the Hessian matrix. + + """ + self._log.debug('Calling hess_diag') + return np.zeros_like(x) + + +class ZeroOrderRegularisator(Regularisator): + """Class for providing a regularisation term which implements Lp norm minimization. + + The constraint this class represents is the minimization of the Lp norm for the 3D + magnetization distribution. Important is the regularisation parameter `lam` (lambda) which + determines the weighting between the two cost parts (measurements and regularisation). + + Attributes + ---------- + lam: float + Regularisation parameter determining the weighting between measurements and regularisation. + p: int, optional + Order of the norm (default: 2, which means a standard L2-norm). + add_params : int + Number of additional parameters which are not used in the regularisation. Used to cut + the input vector into the appropriate size. + + """ + + _log = logging.getLogger(__name__ + '.ZeroOrderRegularisator') + + def __init__(self, _=None, lam=1E-4, p=2, add_params=0): + self._log.debug('Calling __init__') + self.p = p + if p == 2: + norm = jnorm.L2Square() + else: + norm = jnorm.LPPow(p, 1e-12) + super().__init__(norm, lam, add_params) + self._log.debug('Created ' + str(self)) + + +class FirstOrderRegularisator(Regularisator): + """Class for providing a regularisation term which implements derivation minimization. + + The constraint this class represents is the minimization of the first order derivative of the + 3D magnetization distribution using a Lp norm. Important is the regularisation parameter `lam` + (lambda) which determines the weighting between the two cost parts (measurements and + regularisation). + + Attributes + ---------- + mask: :class:`~numpy.ndarray` (N=3) + A boolean mask which defines the magnetized volume in 3D. + lam: float + Regularisation parameter determining the weighting between measurements and regularisation. + p: int, optional + Order of the norm (default: 2, which means a standard L2-norm). + add_params : int + Number of additional parameters which are not used in the regularisation. Used to cut + the input vector into the appropriate size. + + """ + + def __init__(self, mask, lam=1E-4, p=2, add_params=0): + self.p = p + D0 = jdiff.get_diff_operator(mask, 0, 3) + D1 = jdiff.get_diff_operator(mask, 1, 3) + D2 = jdiff.get_diff_operator(mask, 2, 3) + D = sparse.vstack([D0, D1, D2]) + if p == 2: + norm = jnorm.WeightedL2Square(D) + else: + norm = jnorm.WeightedTV(jnorm.LPPow(p, 1e-12), D, [D0.shape[0], D.shape[0]]) + super().__init__(norm, lam, add_params) + self._log.debug('Created ' + str(self)) diff --git a/pyramid/tests/test_analytic.py b/pyramid/tests/test_analytic.py index ae2f79fec525a9db481d97fcabba91b81b966124..c5b54859860605e5763065599b90831b3fd69906 100644 --- a/pyramid/tests/test_analytic.py +++ b/pyramid/tests/test_analytic.py @@ -1,61 +1,61 @@ -# -*- coding: utf-8 -*- -"""Testcase for the analytic module.""" - -import os -import unittest - -import numpy as np -from numpy import pi -from numpy.testing import assert_allclose - -import pyramid.analytic as an - - -class TestCaseAnalytic(unittest.TestCase): - """TestCase for the analytic module.""" - - path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_analytic/') - dim = (4, 4, 4) - a = 10.0 - phi = pi / 4 - center = (dim[0] / 2, dim[1] / 2, dim[2] / 2) - radius = dim[2] / 4 - - def test_phase_mag_slab(self): - """Test of the phase_mag_slab method.""" - width = (self.dim[0] / 2, self.dim[1] / 2, self.dim[2] / 2) - phase = an.phase_mag_slab(self.dim, self.a, self.phi, self.center, width).phase - reference = np.load(os.path.join(self.path, 'ref_phase_slab.npy')) - assert_allclose(phase, reference, atol=1E-10, - err_msg='Unexpected behavior in phase_mag_slab()') - - def test_phase_mag_disc(self): - """Test of the phase_mag_disc method.""" - radius = self.dim[2] / 4 - height = self.dim[2] / 2 - phase = an.phase_mag_disc(self.dim, self.a, self.phi, self.center, radius, height).phase - reference = np.load(os.path.join(self.path, 'ref_phase_disc.npy')) - assert_allclose(phase, reference, atol=1E-10, - err_msg='Unexpected behavior in phase_mag_disc()') - - def test_phase_mag_sphere(self): - """Test of the phase_mag_sphere method.""" - radius = self.dim[2] / 4 - phase = an.phase_mag_sphere(self.dim, self.a, self.phi, self.center, radius).phase - reference = np.load(os.path.join(self.path, 'ref_phase_sphere.npy')) - assert_allclose(phase, reference, atol=1E-10, - err_msg='Unexpected behavior in phase_mag_sphere()') - - def test_phase_mag_vortex(self): - """Test of the phase_mag_vortex method.""" - radius = self.dim[2] / 4 - height = self.dim[2] / 2 - phase = an.phase_mag_vortex(self.dim, self.a, self.center, radius, height).phase - reference = np.load(os.path.join(self.path, 'ref_phase_vort.npy')) - assert_allclose(phase, reference, atol=1E-10, - err_msg='Unexpected behavior in phase_mag_vortex()') - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the analytic module.""" + +import os +import unittest + +import numpy as np +from numpy import pi +from numpy.testing import assert_allclose + +import pyramid.analytic as an + + +class TestCaseAnalytic(unittest.TestCase): + """TestCase for the analytic module.""" + + path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_analytic/') + dim = (4, 4, 4) + a = 10.0 + phi = pi / 4 + center = (dim[0] / 2, dim[1] / 2, dim[2] / 2) + radius = dim[2] / 4 + + def test_phase_mag_slab(self): + """Test of the phase_mag_slab method.""" + width = (self.dim[0] / 2, self.dim[1] / 2, self.dim[2] / 2) + phase = an.phase_mag_slab(self.dim, self.a, self.phi, self.center, width).phase + reference = np.load(os.path.join(self.path, 'ref_phase_slab.npy')) + assert_allclose(phase, reference, atol=1E-10, + err_msg='Unexpected behavior in phase_mag_slab()') + + def test_phase_mag_disc(self): + """Test of the phase_mag_disc method.""" + radius = self.dim[2] / 4 + height = self.dim[2] / 2 + phase = an.phase_mag_disc(self.dim, self.a, self.phi, self.center, radius, height).phase + reference = np.load(os.path.join(self.path, 'ref_phase_disc.npy')) + assert_allclose(phase, reference, atol=1E-10, + err_msg='Unexpected behavior in phase_mag_disc()') + + def test_phase_mag_sphere(self): + """Test of the phase_mag_sphere method.""" + radius = self.dim[2] / 4 + phase = an.phase_mag_sphere(self.dim, self.a, self.phi, self.center, radius).phase + reference = np.load(os.path.join(self.path, 'ref_phase_sphere.npy')) + assert_allclose(phase, reference, atol=1E-10, + err_msg='Unexpected behavior in phase_mag_sphere()') + + def test_phase_mag_vortex(self): + """Test of the phase_mag_vortex method.""" + radius = self.dim[2] / 4 + height = self.dim[2] / 2 + phase = an.phase_mag_vortex(self.dim, self.a, self.center, radius, height).phase + reference = np.load(os.path.join(self.path, 'ref_phase_vort.npy')) + assert_allclose(phase, reference, atol=1E-10, + err_msg='Unexpected behavior in phase_mag_vortex()') + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_costfunction.py b/pyramid/tests/test_costfunction.py index 9bbc82b501d116b6795dd767d2a719c338ec7c71..44d44f77a1e2d49bd50459d3529989ce8c6fa0fc 100644 --- a/pyramid/tests/test_costfunction.py +++ b/pyramid/tests/test_costfunction.py @@ -1,81 +1,81 @@ -# -*- coding: utf-8 -*- -"""Testcase for the costfunction module""" - -import os -import unittest - -import numpy as np -from numpy.testing import assert_allclose - -from pyramid.costfunction import Costfunction -from pyramid.dataset import DataSet -from pyramid.forwardmodel import ForwardModel -from pyramid.projector import SimpleProjector -from pyramid.regularisator import FirstOrderRegularisator -from pyramid import load_phasemap - - -class TestCaseCostfunction(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_costfunction') - self.a = 10. - self.dim = (4, 5, 6) - self.mask = np.zeros(self.dim, dtype=bool) - self.mask[1:-1, 1:-1, 1:-1] = True - self.data = DataSet(self.a, self.dim, mask=self.mask) - self.projector = SimpleProjector(self.dim) - self.phasemap = load_phasemap(os.path.join(self.path, 'phasemap_ref.hdf5')) - self.data.append(self.phasemap, self.projector) - self.data.append(self.phasemap, self.projector) - self.reg = FirstOrderRegularisator(self.mask, lam=1E-4) - self.cost = Costfunction(ForwardModel(self.data), self.reg) - - def tearDown(self): - self.path = None - self.a = None - self.dim = None - self.mask = None - self.data = None - self.projector = None - self.phasemap = None - self.reg = None - self.cost = None - - def test_call(self): - assert_allclose(self.cost(np.ones(self.cost.n)), 0., atol=1E-7, - err_msg='Unexpected behaviour in __call__()!') - zero_vec_cost = np.load(os.path.join(self.path, 'zero_vec_cost.npy')) - assert_allclose(self.cost(np.zeros(self.cost.n)), zero_vec_cost, - err_msg='Unexpected behaviour in __call__()!') - - def test_jac(self): - assert_allclose(self.cost.jac(np.ones(self.cost.n)), np.zeros(self.cost.n), atol=1E-7, - err_msg='Unexpected behaviour in jac()!') - jac_vec_ref = np.load(os.path.join(self.path, 'jac_vec_ref.npy')) - assert_allclose(self.cost.jac(np.zeros(self.cost.n)), jac_vec_ref, atol=1E-7, - err_msg='Unexpected behaviour in jac()!') - jac = np.array([self.cost.jac(np.eye(self.cost.n)[:, i]) for i in range(self.cost.n)]).T - jac_ref = np.load(os.path.join(self.path, 'jac_ref.npy')) - assert_allclose(jac, jac_ref, atol=1E-7, - err_msg='Unexpected behaviour in jac()!') - - def test_hess_dot(self): - assert_allclose(self.cost.hess_dot(None, np.zeros(self.cost.n)), np.zeros(self.cost.n), - atol=1E-7, err_msg='Unexpected behaviour in jac()!') - hess_vec_ref = -np.load(os.path.join(self.path, 'jac_vec_ref.npy')) - assert_allclose(self.cost.hess_dot(None, np.ones(self.cost.n)), hess_vec_ref, atol=1E-7, - err_msg='Unexpected behaviour in jac()!') - hess = np.array([self.cost.hess_dot(None, np.eye(self.cost.n)[:, i]) - for i in range(self.cost.n)]).T - hess_ref = np.load(os.path.join(self.path, 'hess_ref.npy')) - assert_allclose(hess, hess_ref, atol=1E-7, - err_msg='Unexpected behaviour in hess_dot()!') - - def test_hess_diag(self): - assert_allclose(self.cost.hess_diag(None), np.ones(self.cost.n), - err_msg='Unexpected behaviour in hess_diag()!') - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the costfunction module""" + +import os +import unittest + +import numpy as np +from numpy.testing import assert_allclose + +from pyramid.costfunction import Costfunction +from pyramid.dataset import DataSet +from pyramid.forwardmodel import ForwardModel +from pyramid.projector import SimpleProjector +from pyramid.regularisator import FirstOrderRegularisator +from pyramid import load_phasemap + + +class TestCaseCostfunction(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_costfunction') + self.a = 10. + self.dim = (4, 5, 6) + self.mask = np.zeros(self.dim, dtype=bool) + self.mask[1:-1, 1:-1, 1:-1] = True + self.data = DataSet(self.a, self.dim, mask=self.mask) + self.projector = SimpleProjector(self.dim) + self.phasemap = load_phasemap(os.path.join(self.path, 'phasemap_ref.hdf5')) + self.data.append(self.phasemap, self.projector) + self.data.append(self.phasemap, self.projector) + self.reg = FirstOrderRegularisator(self.mask, lam=1E-4) + self.cost = Costfunction(ForwardModel(self.data), self.reg) + + def tearDown(self): + self.path = None + self.a = None + self.dim = None + self.mask = None + self.data = None + self.projector = None + self.phasemap = None + self.reg = None + self.cost = None + + def test_call(self): + assert_allclose(self.cost(np.ones(self.cost.n)), 0., atol=1E-7, + err_msg='Unexpected behaviour in __call__()!') + zero_vec_cost = np.load(os.path.join(self.path, 'zero_vec_cost.npy')) + assert_allclose(self.cost(np.zeros(self.cost.n)), zero_vec_cost, + err_msg='Unexpected behaviour in __call__()!') + + def test_jac(self): + assert_allclose(self.cost.jac(np.ones(self.cost.n)), np.zeros(self.cost.n), atol=1E-7, + err_msg='Unexpected behaviour in jac()!') + jac_vec_ref = np.load(os.path.join(self.path, 'jac_vec_ref.npy')) + assert_allclose(self.cost.jac(np.zeros(self.cost.n)), jac_vec_ref, atol=1E-7, + err_msg='Unexpected behaviour in jac()!') + jac = np.array([self.cost.jac(np.eye(self.cost.n)[:, i]) for i in range(self.cost.n)]).T + jac_ref = np.load(os.path.join(self.path, 'jac_ref.npy')) + assert_allclose(jac, jac_ref, atol=1E-7, + err_msg='Unexpected behaviour in jac()!') + + def test_hess_dot(self): + assert_allclose(self.cost.hess_dot(None, np.zeros(self.cost.n)), np.zeros(self.cost.n), + atol=1E-7, err_msg='Unexpected behaviour in jac()!') + hess_vec_ref = -np.load(os.path.join(self.path, 'jac_vec_ref.npy')) + assert_allclose(self.cost.hess_dot(None, np.ones(self.cost.n)), hess_vec_ref, atol=1E-7, + err_msg='Unexpected behaviour in jac()!') + hess = np.array([self.cost.hess_dot(None, np.eye(self.cost.n)[:, i]) + for i in range(self.cost.n)]).T + hess_ref = np.load(os.path.join(self.path, 'hess_ref.npy')) + assert_allclose(hess, hess_ref, atol=1E-7, + err_msg='Unexpected behaviour in hess_dot()!') + + def test_hess_diag(self): + assert_allclose(self.cost.hess_diag(None), np.ones(self.cost.n), + err_msg='Unexpected behaviour in hess_diag()!') + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_dataset.py b/pyramid/tests/test_dataset.py index fbeb8eb95c753aad8b4a697e57501436a5fcca03..cbc5e4716be7ffde8d0b08ecfa083b5456091357 100644 --- a/pyramid/tests/test_dataset.py +++ b/pyramid/tests/test_dataset.py @@ -1,95 +1,95 @@ -# -*- coding: utf-8 -*- -"""Testcase for the dataset module""" - -import os -import unittest - -import numpy as np -from numpy.testing import assert_allclose - -from pyramid.dataset import DataSet -from pyramid.fielddata import VectorData -from pyramid.phasemap import PhaseMap -from pyramid.projector import SimpleProjector - - -class TestCaseDataSet(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_dataset') - self.a = 10. - self.dim = (4, 5, 6) - self.mask = np.zeros(self.dim, dtype=bool) - self.mask[1:-1, 1:-1, 1:-1] = True - self.data = DataSet(self.a, self.dim, mask=self.mask) - self.projector = SimpleProjector(self.dim) - self.phasemap = PhaseMap(self.a, np.ones(self.dim[1:3])) - - def tearDown(self): - self.path = None - self.a = None - self.dim = None - self.mask = None - self.data = None - self.projector = None - self.phasemap = None - - def test_append(self): - self.data.append(self.phasemap, self.projector) - assert self.data.phasemaps[0] == self.phasemap, 'Phase map not correctly assigned!' - assert self.data.projectors[0] == self.projector, 'Projector not correctly assigned!' - - def test_create_phasemaps(self): - self.data.append(PhaseMap(self.a, np.zeros(self.projector.dim_uv)), self.projector) - magdata = VectorData(self.a, np.ones((3,) + self.dim)) - phasemaps = self.data.create_phasemaps(magdata) - phase_vec = phasemaps[0].phase_vec - phase_vec_ref = np.load(os.path.join(self.path, 'phase_vec_ref.npy')) - assert_allclose(phase_vec, phase_vec_ref, atol=1E-6, - err_msg='Unexpected behaviour in create_phasemaps()!') - - def test_set_Se_inv_block_diag(self): - self.data.append(self.phasemap, self.projector) - self.data.append(self.phasemap, self.projector) - cov = np.diag(np.ones(np.prod(self.phasemap.dim_uv))) - self.data.set_Se_inv_block_diag([cov, cov]) - assert self.data.Se_inv.shape == (self.data.m, self.data.m), \ - 'Unexpected behaviour in set_Se_inv_block_diag()!' - assert self.data.Se_inv.diagonal().sum() == self.data.m, \ - 'Unexpected behaviour in set_Se_inv_block_diag()!' - - def test_set_Se_inv_diag_with_conf(self): - self.data.append(self.phasemap, self.projector) - self.data.append(self.phasemap, self.projector) - confidence = self.mask[1, ...] - self.data.set_Se_inv_diag_with_conf([confidence, confidence]) - assert self.data.Se_inv.shape == (self.data.m, self.data.m), \ - 'Unexpected behaviour in set_Se_inv_diag_with_masks()!' - assert self.data.Se_inv.diagonal().sum() == 2 * confidence.sum(), \ - 'Unexpected behaviour in set_Se_inv_diag_with_masks()!' - - def test_set_3d_mask(self): - projector_z = SimpleProjector(self.dim, axis='z') - projector_y = SimpleProjector(self.dim, axis='y') - projector_x = SimpleProjector(self.dim, axis='x') - mask_z = np.zeros(projector_z.dim_uv, dtype=bool) - mask_y = np.zeros(projector_y.dim_uv, dtype=bool) - mask_x = np.zeros(projector_x.dim_uv, dtype=bool) - mask_z[1:-1, 1:-1] = True - mask_y[1:-1, 1:-1] = True - mask_x[1:-1, 1:-1] = True - phasemap_z = PhaseMap(self.a, np.zeros(projector_z.dim_uv), mask_z) - phasemap_y = PhaseMap(self.a, np.zeros(projector_y.dim_uv), mask_y) - phasemap_x = PhaseMap(self.a, np.zeros(projector_x.dim_uv), mask_x) - self.data.append(phasemap_z, projector_z) - self.data.append(phasemap_y, projector_y) - self.data.append(phasemap_x, projector_x) - self.data.set_3d_mask() - mask_ref = np.zeros(self.dim, dtype=bool) - mask_ref[1:-1, 1:-1, 1:-1] = True - np.testing.assert_equal(self.data.mask, mask_ref, - err_msg='Unexpected behaviour in set_3d_mask') - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the dataset module""" + +import os +import unittest + +import numpy as np +from numpy.testing import assert_allclose + +from pyramid.dataset import DataSet +from pyramid.fielddata import VectorData +from pyramid.phasemap import PhaseMap +from pyramid.projector import SimpleProjector + + +class TestCaseDataSet(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_dataset') + self.a = 10. + self.dim = (4, 5, 6) + self.mask = np.zeros(self.dim, dtype=bool) + self.mask[1:-1, 1:-1, 1:-1] = True + self.data = DataSet(self.a, self.dim, mask=self.mask) + self.projector = SimpleProjector(self.dim) + self.phasemap = PhaseMap(self.a, np.ones(self.dim[1:3])) + + def tearDown(self): + self.path = None + self.a = None + self.dim = None + self.mask = None + self.data = None + self.projector = None + self.phasemap = None + + def test_append(self): + self.data.append(self.phasemap, self.projector) + assert self.data.phasemaps[0] == self.phasemap, 'Phase map not correctly assigned!' + assert self.data.projectors[0] == self.projector, 'Projector not correctly assigned!' + + def test_create_phasemaps(self): + self.data.append(PhaseMap(self.a, np.zeros(self.projector.dim_uv)), self.projector) + magdata = VectorData(self.a, np.ones((3,) + self.dim)) + phasemaps = self.data.create_phasemaps(magdata) + phase_vec = phasemaps[0].phase_vec + phase_vec_ref = np.load(os.path.join(self.path, 'phase_vec_ref.npy')) + assert_allclose(phase_vec, phase_vec_ref, atol=1E-6, + err_msg='Unexpected behaviour in create_phasemaps()!') + + def test_set_Se_inv_block_diag(self): + self.data.append(self.phasemap, self.projector) + self.data.append(self.phasemap, self.projector) + cov = np.diag(np.ones(np.prod(self.phasemap.dim_uv))) + self.data.set_Se_inv_block_diag([cov, cov]) + assert self.data.Se_inv.shape == (self.data.m, self.data.m), \ + 'Unexpected behaviour in set_Se_inv_block_diag()!' + assert self.data.Se_inv.diagonal().sum() == self.data.m, \ + 'Unexpected behaviour in set_Se_inv_block_diag()!' + + def test_set_Se_inv_diag_with_conf(self): + self.data.append(self.phasemap, self.projector) + self.data.append(self.phasemap, self.projector) + confidence = self.mask[1, ...] + self.data.set_Se_inv_diag_with_conf([confidence, confidence]) + assert self.data.Se_inv.shape == (self.data.m, self.data.m), \ + 'Unexpected behaviour in set_Se_inv_diag_with_masks()!' + assert self.data.Se_inv.diagonal().sum() == 2 * confidence.sum(), \ + 'Unexpected behaviour in set_Se_inv_diag_with_masks()!' + + def test_set_3d_mask(self): + projector_z = SimpleProjector(self.dim, axis='z') + projector_y = SimpleProjector(self.dim, axis='y') + projector_x = SimpleProjector(self.dim, axis='x') + mask_z = np.zeros(projector_z.dim_uv, dtype=bool) + mask_y = np.zeros(projector_y.dim_uv, dtype=bool) + mask_x = np.zeros(projector_x.dim_uv, dtype=bool) + mask_z[1:-1, 1:-1] = True + mask_y[1:-1, 1:-1] = True + mask_x[1:-1, 1:-1] = True + phasemap_z = PhaseMap(self.a, np.zeros(projector_z.dim_uv), mask_z) + phasemap_y = PhaseMap(self.a, np.zeros(projector_y.dim_uv), mask_y) + phasemap_x = PhaseMap(self.a, np.zeros(projector_x.dim_uv), mask_x) + self.data.append(phasemap_z, projector_z) + self.data.append(phasemap_y, projector_y) + self.data.append(phasemap_x, projector_x) + self.data.set_3d_mask() + mask_ref = np.zeros(self.dim, dtype=bool) + mask_ref[1:-1, 1:-1, 1:-1] = True + np.testing.assert_equal(self.data.mask, mask_ref, + err_msg='Unexpected behaviour in set_3d_mask') + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_fielddata.py b/pyramid/tests/test_fielddata.py index 172e5da566a5e7a55b5785c6f8ac083cea0d0018..39cc04f77f713122377c71521930fa92e725fcee 100644 --- a/pyramid/tests/test_fielddata.py +++ b/pyramid/tests/test_fielddata.py @@ -1,123 +1,123 @@ -# -*- coding: utf-8 -*- -"""Testcase for the magdata module.""" - -import os -import unittest - -import numpy as np -from numpy.testing import assert_allclose - -from pyramid.fielddata import VectorData -from pyramid import load_vectordata - - -class TestCaseVectorData(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_fielddata') - magnitude = np.zeros((3, 4, 4, 4)) - magnitude[:, 1:-1, 1:-1, 1:-1] = 1 - self.magdata = VectorData(10.0, magnitude) - - def tearDown(self): - self.path = None - self.magdata = None - - def test_copy(self): - magdata = self.magdata - magdata_copy = self.magdata.copy() - assert magdata == self.magdata, 'Unexpected behaviour in copy()!' - assert magdata_copy != self.magdata, 'Unexpected behaviour in copy()!' - - def test_scale_down(self): - self.magdata.scale_down() - reference = 1 / 8. * np.ones((3, 2, 2, 2)) - assert_allclose(self.magdata.field, reference, - err_msg='Unexpected behavior in scale_down()!') - assert_allclose(self.magdata.a, 20, - err_msg='Unexpected behavior in scale_down()!') - - def test_scale_up(self): - self.magdata.scale_up() - reference = np.zeros((3, 8, 8, 8)) - reference[:, 2:6, 2:6, 2:6] = 1 - assert_allclose(self.magdata.field, reference, - err_msg='Unexpected behavior in scale_down()!') - assert_allclose(self.magdata.a, 5, - err_msg='Unexpected behavior in scale_down()!') - - def test_pad(self): - reference = self.magdata.field.copy() - self.magdata.pad((1, 1, 1)) - reference = np.pad(reference, ((0, 0), (1, 1), (1, 1), (1, 1)), mode='constant') - assert_allclose(self.magdata.field, reference, - err_msg='Unexpected behavior in scale_down()!') - self.magdata.pad(((1, 1), (1, 1), (1, 1))) - reference = np.pad(reference, ((0, 0), (1, 1), (1, 1), (1, 1)), mode='constant') - assert_allclose(self.magdata.field, reference, - err_msg='Unexpected behavior in scale_down()!') - - def test_get_mask(self): - mask = self.magdata.get_mask() - reference = np.zeros((4, 4, 4)) - reference[1:-1, 1:-1, 1:-1] = True - assert_allclose(mask, reference, - err_msg='Unexpected behavior in get_mask()!') - - def test_get_vector(self): - mask = self.magdata.get_mask() - vector = self.magdata.get_vector(mask) - reference = np.ones(np.sum(mask) * 3) - assert_allclose(vector, reference, - err_msg='Unexpected behavior in get_vector()!') - - def test_set_vector(self): - mask = self.magdata.get_mask() - vector = 2 * np.ones(np.sum(mask) * 3) - self.magdata.set_vector(vector, mask) - reference = np.zeros((3, 4, 4, 4)) - reference[:, 1:-1, 1:-1, 1:-1] = 2 - assert_allclose(self.magdata.field, reference, - err_msg='Unexpected behavior in set_vector()!') - - def test_flip(self): - magdata = load_vectordata(os.path.join(self.path, 'magdata_orig.hdf5')) - magdata_flipx = load_vectordata(os.path.join(self.path, 'magdata_flipx.hdf5')) - magdata_flipy = load_vectordata(os.path.join(self.path, 'magdata_flipy.hdf5')) - magdata_flipz = load_vectordata(os.path.join(self.path, 'magdata_flipz.hdf5')) - assert_allclose(magdata.flip('x').field, magdata_flipx.field, - err_msg='Unexpected behavior in flip()! (x)') - assert_allclose(magdata.flip('y').field, magdata_flipy.field, - err_msg='Unexpected behavior in flip()! (y)') - assert_allclose(magdata.flip('z').field, magdata_flipz.field, - err_msg='Unexpected behavior in flip()! (z)') - - def test_rot(self): - magdata = load_vectordata(os.path.join(self.path, 'magdata_orig.hdf5')) - magdata_rotx = load_vectordata(os.path.join(self.path, 'magdata_rotx.hdf5')) - magdata_roty = load_vectordata(os.path.join(self.path, 'magdata_roty.hdf5')) - magdata_rotz = load_vectordata(os.path.join(self.path, 'magdata_rotz.hdf5')) - assert_allclose(magdata.rot90('x').field, magdata_rotx.field, - err_msg='Unexpected behavior in rot()! (x)') - assert_allclose(magdata.rot90('y').field, magdata_roty.field, - err_msg='Unexpected behavior in rot()! (y)') - assert_allclose(magdata.rot90('z').field, magdata_rotz.field, - err_msg='Unexpected behavior in rot()! (z)') - - def test_load_from_llg(self): - magdata = load_vectordata(os.path.join(self.path, 'magdata_ref_load.txt')) - assert_allclose(magdata.field, self.magdata.field, - err_msg='Unexpected behavior in load_from_llg()!') - assert_allclose(magdata.a, self.magdata.a, - err_msg='Unexpected behavior in load_from_llg()!') - - def test_load_from_hdf5(self): - magdata = load_vectordata(os.path.join(self.path, 'magdata_ref_load.hdf5')) - assert_allclose(magdata.field, self.magdata.field, - err_msg='Unexpected behavior in load_from_hdf5()!') - assert_allclose(magdata.a, self.magdata.a, - err_msg='Unexpected behavior in load_from_hdf5()!') - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the magdata module.""" + +import os +import unittest + +import numpy as np +from numpy.testing import assert_allclose + +from pyramid.fielddata import VectorData +from pyramid import load_vectordata + + +class TestCaseVectorData(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_fielddata') + magnitude = np.zeros((3, 4, 4, 4)) + magnitude[:, 1:-1, 1:-1, 1:-1] = 1 + self.magdata = VectorData(10.0, magnitude) + + def tearDown(self): + self.path = None + self.magdata = None + + def test_copy(self): + magdata = self.magdata + magdata_copy = self.magdata.copy() + assert magdata == self.magdata, 'Unexpected behaviour in copy()!' + assert magdata_copy != self.magdata, 'Unexpected behaviour in copy()!' + + def test_scale_down(self): + self.magdata.scale_down() + reference = 1 / 8. * np.ones((3, 2, 2, 2)) + assert_allclose(self.magdata.field, reference, + err_msg='Unexpected behavior in scale_down()!') + assert_allclose(self.magdata.a, 20, + err_msg='Unexpected behavior in scale_down()!') + + def test_scale_up(self): + self.magdata.scale_up() + reference = np.zeros((3, 8, 8, 8)) + reference[:, 2:6, 2:6, 2:6] = 1 + assert_allclose(self.magdata.field, reference, + err_msg='Unexpected behavior in scale_down()!') + assert_allclose(self.magdata.a, 5, + err_msg='Unexpected behavior in scale_down()!') + + def test_pad(self): + reference = self.magdata.field.copy() + self.magdata.pad((1, 1, 1)) + reference = np.pad(reference, ((0, 0), (1, 1), (1, 1), (1, 1)), mode='constant') + assert_allclose(self.magdata.field, reference, + err_msg='Unexpected behavior in scale_down()!') + self.magdata.pad(((1, 1), (1, 1), (1, 1))) + reference = np.pad(reference, ((0, 0), (1, 1), (1, 1), (1, 1)), mode='constant') + assert_allclose(self.magdata.field, reference, + err_msg='Unexpected behavior in scale_down()!') + + def test_get_mask(self): + mask = self.magdata.get_mask() + reference = np.zeros((4, 4, 4)) + reference[1:-1, 1:-1, 1:-1] = True + assert_allclose(mask, reference, + err_msg='Unexpected behavior in get_mask()!') + + def test_get_vector(self): + mask = self.magdata.get_mask() + vector = self.magdata.get_vector(mask) + reference = np.ones(np.sum(mask) * 3) + assert_allclose(vector, reference, + err_msg='Unexpected behavior in get_vector()!') + + def test_set_vector(self): + mask = self.magdata.get_mask() + vector = 2 * np.ones(np.sum(mask) * 3) + self.magdata.set_vector(vector, mask) + reference = np.zeros((3, 4, 4, 4)) + reference[:, 1:-1, 1:-1, 1:-1] = 2 + assert_allclose(self.magdata.field, reference, + err_msg='Unexpected behavior in set_vector()!') + + def test_flip(self): + magdata = load_vectordata(os.path.join(self.path, 'magdata_orig.hdf5')) + magdata_flipx = load_vectordata(os.path.join(self.path, 'magdata_flipx.hdf5')) + magdata_flipy = load_vectordata(os.path.join(self.path, 'magdata_flipy.hdf5')) + magdata_flipz = load_vectordata(os.path.join(self.path, 'magdata_flipz.hdf5')) + assert_allclose(magdata.flip('x').field, magdata_flipx.field, + err_msg='Unexpected behavior in flip()! (x)') + assert_allclose(magdata.flip('y').field, magdata_flipy.field, + err_msg='Unexpected behavior in flip()! (y)') + assert_allclose(magdata.flip('z').field, magdata_flipz.field, + err_msg='Unexpected behavior in flip()! (z)') + + def test_rot(self): + magdata = load_vectordata(os.path.join(self.path, 'magdata_orig.hdf5')) + magdata_rotx = load_vectordata(os.path.join(self.path, 'magdata_rotx.hdf5')) + magdata_roty = load_vectordata(os.path.join(self.path, 'magdata_roty.hdf5')) + magdata_rotz = load_vectordata(os.path.join(self.path, 'magdata_rotz.hdf5')) + assert_allclose(magdata.rot90('x').field, magdata_rotx.field, + err_msg='Unexpected behavior in rot()! (x)') + assert_allclose(magdata.rot90('y').field, magdata_roty.field, + err_msg='Unexpected behavior in rot()! (y)') + assert_allclose(magdata.rot90('z').field, magdata_rotz.field, + err_msg='Unexpected behavior in rot()! (z)') + + def test_load_from_llg(self): + magdata = load_vectordata(os.path.join(self.path, 'magdata_ref_load.txt')) + assert_allclose(magdata.field, self.magdata.field, + err_msg='Unexpected behavior in load_from_llg()!') + assert_allclose(magdata.a, self.magdata.a, + err_msg='Unexpected behavior in load_from_llg()!') + + def test_load_from_hdf5(self): + magdata = load_vectordata(os.path.join(self.path, 'magdata_ref_load.hdf5')) + assert_allclose(magdata.field, self.magdata.field, + err_msg='Unexpected behavior in load_from_hdf5()!') + assert_allclose(magdata.a, self.magdata.a, + err_msg='Unexpected behavior in load_from_hdf5()!') + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_fielddata/magdata_ref_load.txt b/pyramid/tests/test_fielddata/magdata_ref_load.txt index 293c185a73fce0669baba800745f16fcc48d7a2e..2534662a3e443634de59702735544e8e0a0894a1 100644 --- a/pyramid/tests/test_fielddata/magdata_ref_load.txt +++ b/pyramid/tests/test_fielddata/magdata_ref_load.txt @@ -1,66 +1,66 @@ -LLGFileCreator: test_magdata/ref_mag_data - 4 4 4 -5.000000e-07 5.000000e-07 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 5.000000e-07 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 5.000000e-07 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 5.000000e-07 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 1.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 1.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 1.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 1.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 2.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 2.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 2.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 2.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 3.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 3.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 3.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 3.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 5.000000e-07 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 5.000000e-07 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 5.000000e-07 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 5.000000e-07 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 1.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 1.500000e-06 1.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 -2.500000e-06 1.500000e-06 1.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 -3.500000e-06 1.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 2.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 2.500000e-06 1.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 -2.500000e-06 2.500000e-06 1.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 -3.500000e-06 2.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 3.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 3.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 3.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 3.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 5.000000e-07 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 5.000000e-07 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 5.000000e-07 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 5.000000e-07 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 1.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 1.500000e-06 2.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 -2.500000e-06 1.500000e-06 2.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 -3.500000e-06 1.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 2.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 2.500000e-06 2.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 -2.500000e-06 2.500000e-06 2.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 -3.500000e-06 2.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 3.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 3.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 3.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 3.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 5.000000e-07 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 5.000000e-07 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 5.000000e-07 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 5.000000e-07 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 1.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 1.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 1.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 1.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 2.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 2.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 2.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -3.500000e-06 2.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -5.000000e-07 3.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -1.500000e-06 3.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 -2.500000e-06 3.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +LLGFileCreator: test_magdata/ref_mag_data + 4 4 4 +5.000000e-07 5.000000e-07 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 5.000000e-07 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 5.000000e-07 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 5.000000e-07 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 1.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 1.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 1.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 1.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 2.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 2.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 2.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 2.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 3.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 3.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 3.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 3.500000e-06 5.000000e-07 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 5.000000e-07 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 5.000000e-07 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 5.000000e-07 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 5.000000e-07 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 1.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 1.500000e-06 1.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 +2.500000e-06 1.500000e-06 1.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 +3.500000e-06 1.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 2.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 2.500000e-06 1.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 +2.500000e-06 2.500000e-06 1.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 +3.500000e-06 2.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 3.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 3.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 3.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 3.500000e-06 1.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 5.000000e-07 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 5.000000e-07 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 5.000000e-07 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 5.000000e-07 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 1.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 1.500000e-06 2.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 +2.500000e-06 1.500000e-06 2.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 +3.500000e-06 1.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 2.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 2.500000e-06 2.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 +2.500000e-06 2.500000e-06 2.500000e-06 1.000000e+00 1.000000e+00 1.000000e+00 +3.500000e-06 2.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 3.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 3.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 3.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 3.500000e-06 2.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 5.000000e-07 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 5.000000e-07 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 5.000000e-07 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 5.000000e-07 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 1.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 1.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 1.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 1.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 2.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 2.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 2.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +3.500000e-06 2.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +5.000000e-07 3.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +1.500000e-06 3.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 +2.500000e-06 3.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 3.500000e-06 3.500000e-06 3.500000e-06 0.000000e+00 0.000000e+00 0.000000e+00 \ No newline at end of file diff --git a/pyramid/tests/test_forwardmodel.py b/pyramid/tests/test_forwardmodel.py index 401a36aa1ffe5204bab79c2cc4285f6bcdfae040..1ca4fc4e148dd12319600b1dde16a53afabd20ee 100644 --- a/pyramid/tests/test_forwardmodel.py +++ b/pyramid/tests/test_forwardmodel.py @@ -1,74 +1,74 @@ -# -*- coding: utf-8 -*- -"""Testcase for the forwardmodel module""" - -import os -import unittest - -import numpy as np -from numpy.testing import assert_allclose - -from pyramid.dataset import DataSet -from pyramid.forwardmodel import ForwardModel -from pyramid.projector import SimpleProjector -from pyramid import load_phasemap - - -class TestCaseForwardModel(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_forwardmodel') - self.a = 10. - self.dim = (4, 5, 6) - self.mask = np.zeros(self.dim, dtype=bool) - self.mask[1:-1, 1:-1, 1:-1] = True - self.data = DataSet(self.a, self.dim, mask=self.mask) - self.projector = SimpleProjector(self.dim) - self.phasemap = load_phasemap(os.path.join(self.path, 'phasemap_ref.hdf5')) - self.data.append(self.phasemap, self.projector) - self.data.append(self.phasemap, self.projector) - self.fwd_model = ForwardModel(self.data) - - def tearDown(self): - self.path = None - self.a = None - self.dim = None - self.mask = None - self.data = None - self.projector = None - self.phasemap = None - self.fwdmodel = None - - def test_call(self): - n = self.fwd_model.n - result = self.fwd_model(np.ones(n)) - hp = self.data.hook_points - assert_allclose(result[hp[0]:hp[1]], self.phasemap.phase.ravel(), atol=1E-7, - err_msg='Unexpected behavior in __call__()!') - assert_allclose(result[hp[1]:hp[2]], self.phasemap.phase.ravel(), atol=1E-7, - err_msg='Unexpected behavior in __call__()!') - - def test_jac_dot(self): - n = self.fwd_model.n - vector = np.ones(n) - result = self.fwd_model(vector) - result_jac = self.fwd_model.jac_dot(None, vector) - assert_allclose(result, result_jac, atol=1E-7, - err_msg='Inconsistency between __call__() and jac_dot()!') - jac = np.array([self.fwd_model.jac_dot(None, np.eye(n)[:, i]) for i in range(n)]).T - hp = self.data.hook_points - assert_allclose(jac[hp[0]:hp[1], :], jac[hp[1]:hp[2], :], atol=1E-7, - err_msg='Unexpected behaviour in the the jacobi matrix!') - jac_ref = np.load(os.path.join(self.path, 'jac.npy')) - assert_allclose(jac, jac_ref, atol=1E-7, - err_msg='Unexpected behaviour in the the jacobi matrix!') - - def test_jac_T_dot(self): - m = self.fwd_model.m - jac_T = np.array([self.fwd_model.jac_T_dot(None, np.eye(m)[:, i]) for i in range(m)]).T - jac_T_ref = np.load(os.path.join(self.path, 'jac.npy')).T - assert_allclose(jac_T, jac_T_ref, atol=1E-7, - err_msg='Unexpected behaviour in the the transposed jacobi matrix!') - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the forwardmodel module""" + +import os +import unittest + +import numpy as np +from numpy.testing import assert_allclose + +from pyramid.dataset import DataSet +from pyramid.forwardmodel import ForwardModel +from pyramid.projector import SimpleProjector +from pyramid import load_phasemap + + +class TestCaseForwardModel(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_forwardmodel') + self.a = 10. + self.dim = (4, 5, 6) + self.mask = np.zeros(self.dim, dtype=bool) + self.mask[1:-1, 1:-1, 1:-1] = True + self.data = DataSet(self.a, self.dim, mask=self.mask) + self.projector = SimpleProjector(self.dim) + self.phasemap = load_phasemap(os.path.join(self.path, 'phasemap_ref.hdf5')) + self.data.append(self.phasemap, self.projector) + self.data.append(self.phasemap, self.projector) + self.fwd_model = ForwardModel(self.data) + + def tearDown(self): + self.path = None + self.a = None + self.dim = None + self.mask = None + self.data = None + self.projector = None + self.phasemap = None + self.fwdmodel = None + + def test_call(self): + n = self.fwd_model.n + result = self.fwd_model(np.ones(n)) + hp = self.data.hook_points + assert_allclose(result[hp[0]:hp[1]], self.phasemap.phase.ravel(), atol=1E-7, + err_msg='Unexpected behavior in __call__()!') + assert_allclose(result[hp[1]:hp[2]], self.phasemap.phase.ravel(), atol=1E-7, + err_msg='Unexpected behavior in __call__()!') + + def test_jac_dot(self): + n = self.fwd_model.n + vector = np.ones(n) + result = self.fwd_model(vector) + result_jac = self.fwd_model.jac_dot(None, vector) + assert_allclose(result, result_jac, atol=1E-7, + err_msg='Inconsistency between __call__() and jac_dot()!') + jac = np.array([self.fwd_model.jac_dot(None, np.eye(n)[:, i]) for i in range(n)]).T + hp = self.data.hook_points + assert_allclose(jac[hp[0]:hp[1], :], jac[hp[1]:hp[2], :], atol=1E-7, + err_msg='Unexpected behaviour in the the jacobi matrix!') + jac_ref = np.load(os.path.join(self.path, 'jac.npy')) + assert_allclose(jac, jac_ref, atol=1E-7, + err_msg='Unexpected behaviour in the the jacobi matrix!') + + def test_jac_T_dot(self): + m = self.fwd_model.m + jac_T = np.array([self.fwd_model.jac_T_dot(None, np.eye(m)[:, i]) for i in range(m)]).T + jac_T_ref = np.load(os.path.join(self.path, 'jac.npy')).T + assert_allclose(jac_T, jac_T_ref, atol=1E-7, + err_msg='Unexpected behaviour in the the transposed jacobi matrix!') + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_kernel.py b/pyramid/tests/test_kernel.py index 46d19c398ec18836a958a7ebedaa2fbd37533b08..c16e603f1314b3a6ea98eca95f34b616a51a6d14 100644 --- a/pyramid/tests/test_kernel.py +++ b/pyramid/tests/test_kernel.py @@ -1,37 +1,37 @@ -# -*- coding: utf-8 -*- -"""Testcase for the magdata module.""" - -import os -import unittest - -import numpy as np -from numpy.testing import assert_allclose - -from pyramid.kernel import Kernel - - -class TestCaseKernel(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_kernel') - self.kernel = Kernel(1., dim_uv=(4, 4), b_0=1., geometry='disc') - - def tearDown(self): - self.path = None - self.kernel = None - - def test_kernel(self): - ref_u = np.load(os.path.join(self.path, 'ref_u.npy')) - ref_v = np.load(os.path.join(self.path, 'ref_v.npy')) - ref_u_fft = np.load(os.path.join(self.path, 'ref_u_fft.npy')) - ref_v_fft = np.load(os.path.join(self.path, 'ref_v_fft.npy')) - assert_allclose(self.kernel.u, ref_u, err_msg='Unexpected behavior in u') - assert_allclose(self.kernel.v, ref_v, err_msg='Unexpected behavior in v') - assert_allclose(self.kernel.u_fft, ref_u_fft, atol=1E-7, - err_msg='Unexpected behavior in u_fft') - assert_allclose(self.kernel.v_fft, ref_v_fft, atol=1E-7, - err_msg='Unexpected behavior in v_fft') - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the magdata module.""" + +import os +import unittest + +import numpy as np +from numpy.testing import assert_allclose + +from pyramid.kernel import Kernel + + +class TestCaseKernel(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_kernel') + self.kernel = Kernel(1., dim_uv=(4, 4), b_0=1., geometry='disc') + + def tearDown(self): + self.path = None + self.kernel = None + + def test_kernel(self): + ref_u = np.load(os.path.join(self.path, 'ref_u.npy')) + ref_v = np.load(os.path.join(self.path, 'ref_v.npy')) + ref_u_fft = np.load(os.path.join(self.path, 'ref_u_fft.npy')) + ref_v_fft = np.load(os.path.join(self.path, 'ref_v_fft.npy')) + assert_allclose(self.kernel.u, ref_u, err_msg='Unexpected behavior in u') + assert_allclose(self.kernel.v, ref_v, err_msg='Unexpected behavior in v') + assert_allclose(self.kernel.u_fft, ref_u_fft, atol=1E-7, + err_msg='Unexpected behavior in u_fft') + assert_allclose(self.kernel.v_fft, ref_v_fft, atol=1E-7, + err_msg='Unexpected behavior in v_fft') + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_magcreator.py b/pyramid/tests/test_magcreator.py index 68d4b3b62d3ed94449dc6174c88583e46fbf10d2..75686bd20ba9887bd23aee2358eb7ef88e466131 100644 --- a/pyramid/tests/test_magcreator.py +++ b/pyramid/tests/test_magcreator.py @@ -1,85 +1,85 @@ -# -*- coding: utf-8 -*- -"""Testcase for the magcreator module.""" - -import os -import unittest - -import numpy as np -from numpy import pi -from numpy.testing import assert_allclose - -import pyramid.magcreator as mc - - -class TestCaseMagCreator(unittest.TestCase): - path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_magcreator') - - def test_shape_slab(self): - test_slab = mc.shapes.slab((5, 6, 7), (2.5, 3.5, 4.5), (1, 3, 5)) - assert_allclose(test_slab, np.load(os.path.join(self.path, 'ref_slab.npy')), - err_msg='Created slab does not match expectation!') - - def test_shape_disc(self): - test_disc_z = mc.shapes.disc((5, 6, 7), (2.5, 3.5, 4.5), 2, 3, 'z') - test_disc_y = mc.shapes.disc((5, 6, 7), (2.5, 3.5, 4.5), 2, 3, 'y') - test_disc_x = mc.shapes.disc((5, 6, 7), (2.5, 3.5, 4.5), 2, 3, 'x') - assert_allclose(test_disc_z, np.load(os.path.join(self.path, 'ref_disc_z.npy')), - err_msg='Created disc in z-direction does not match expectation!') - assert_allclose(test_disc_y, np.load(os.path.join(self.path, 'ref_disc_y.npy')), - err_msg='Created disc in y-direction does not match expectation!') - assert_allclose(test_disc_x, np.load(os.path.join(self.path, 'ref_disc_x.npy')), - err_msg='Created disc in x-direction does not match expectation!') - - def test_shape_ellipse(self): - test_ellipse_z = mc.shapes.ellipse((7, 8, 9), (3.5, 4.5, 5.5), (3, 5), 1, 'z') - test_ellipse_y = mc.shapes.ellipse((7, 8, 9), (3.5, 4.5, 5.5), (3, 5), 1, 'y') - test_ellipse_x = mc.shapes.ellipse((7, 8, 9), (3.5, 4.5, 5.5), (3, 5), 1, 'x') - assert_allclose(test_ellipse_z, np.load(os.path.join(self.path, 'ref_ellipse_z.npy')), - err_msg='Created ellipse does not match expectation (z)!') - assert_allclose(test_ellipse_y, np.load(os.path.join(self.path, 'ref_ellipse_y.npy')), - err_msg='Created ellipse does not match expectation (y)!') - assert_allclose(test_ellipse_x, np.load(os.path.join(self.path, 'ref_ellipse_x.npy')), - err_msg='Created ellipse does not match expectation (x)!') - - def test_shape_sphere(self): - test_sphere = mc.shapes.sphere((5, 6, 7), (2.5, 3.5, 4.5), 2) - assert_allclose(test_sphere, np.load(os.path.join(self.path, 'ref_sphere.npy')), - err_msg='Created sphere does not match expectation!') - - def test_shape_ellipsoid(self): - test_ellipsoid = mc.shapes.ellipsoid((7, 8, 9), (3.5, 4.5, 4.5), (3, 5, 7)) - assert_allclose(test_ellipsoid, np.load(os.path.join(self.path, 'ref_ellipsoid.npy')), - err_msg='Created ellipsoid does not match expectation!') - - def test_shape_filament(self): - test_filament_z = mc.shapes.filament((5, 6, 7), (2, 3), 'z') - test_filament_y = mc.shapes.filament((5, 6, 7), (2, 3), 'y') - test_filament_x = mc.shapes.filament((5, 6, 7), (2, 3), 'x') - assert_allclose(test_filament_z, np.load(os.path.join(self.path, 'ref_fil_z.npy')), - err_msg='Created filament in z-direction does not match expectation!') - assert_allclose(test_filament_y, np.load(os.path.join(self.path, 'ref_fil_y.npy')), - err_msg='Created filament in y-direction does not match expectation!') - assert_allclose(test_filament_x, np.load(os.path.join(self.path, 'ref_fil_x.npy')), - err_msg='Created filament in x-direction does not match expectation!') - - def test_shape_pixel(self): - test_pixel = mc.shapes.pixel((5, 6, 7), (2, 3, 4)) - assert_allclose(test_pixel, np.load(os.path.join(self.path, 'ref_pixel.npy')), - err_msg='Created pixel does not match expectation!') - - def test_create_mag_dist_homog(self): - mag_shape = mc.shapes.disc((1, 10, 10), (0, 5, 5), 3, 1) - magnitude = mc.create_mag_dist_homog(mag_shape, pi / 4) - assert_allclose(magnitude, np.load(os.path.join(self.path, 'ref_mag_disc.npy')), - err_msg='Created homog. magnetic distribution does not match expectation') - - def test_create_mag_dist_vortex(self): - mag_shape = mc.shapes.disc((1, 10, 10), (0, 5, 5), 3, 1) - magnitude = mc.create_mag_dist_vortex(mag_shape) - assert_allclose(magnitude, np.load(os.path.join(self.path, 'ref_mag_vort.npy')), - err_msg='Created vortex magnetic distribution does not match expectation') - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the magcreator module.""" + +import os +import unittest + +import numpy as np +from numpy import pi +from numpy.testing import assert_allclose + +import pyramid.magcreator as mc + + +class TestCaseMagCreator(unittest.TestCase): + path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_magcreator') + + def test_shape_slab(self): + test_slab = mc.shapes.slab((5, 6, 7), (2.5, 3.5, 4.5), (1, 3, 5)) + assert_allclose(test_slab, np.load(os.path.join(self.path, 'ref_slab.npy')), + err_msg='Created slab does not match expectation!') + + def test_shape_disc(self): + test_disc_z = mc.shapes.disc((5, 6, 7), (2.5, 3.5, 4.5), 2, 3, 'z') + test_disc_y = mc.shapes.disc((5, 6, 7), (2.5, 3.5, 4.5), 2, 3, 'y') + test_disc_x = mc.shapes.disc((5, 6, 7), (2.5, 3.5, 4.5), 2, 3, 'x') + assert_allclose(test_disc_z, np.load(os.path.join(self.path, 'ref_disc_z.npy')), + err_msg='Created disc in z-direction does not match expectation!') + assert_allclose(test_disc_y, np.load(os.path.join(self.path, 'ref_disc_y.npy')), + err_msg='Created disc in y-direction does not match expectation!') + assert_allclose(test_disc_x, np.load(os.path.join(self.path, 'ref_disc_x.npy')), + err_msg='Created disc in x-direction does not match expectation!') + + def test_shape_ellipse(self): + test_ellipse_z = mc.shapes.ellipse((7, 8, 9), (3.5, 4.5, 5.5), (3, 5), 1, 'z') + test_ellipse_y = mc.shapes.ellipse((7, 8, 9), (3.5, 4.5, 5.5), (3, 5), 1, 'y') + test_ellipse_x = mc.shapes.ellipse((7, 8, 9), (3.5, 4.5, 5.5), (3, 5), 1, 'x') + assert_allclose(test_ellipse_z, np.load(os.path.join(self.path, 'ref_ellipse_z.npy')), + err_msg='Created ellipse does not match expectation (z)!') + assert_allclose(test_ellipse_y, np.load(os.path.join(self.path, 'ref_ellipse_y.npy')), + err_msg='Created ellipse does not match expectation (y)!') + assert_allclose(test_ellipse_x, np.load(os.path.join(self.path, 'ref_ellipse_x.npy')), + err_msg='Created ellipse does not match expectation (x)!') + + def test_shape_sphere(self): + test_sphere = mc.shapes.sphere((5, 6, 7), (2.5, 3.5, 4.5), 2) + assert_allclose(test_sphere, np.load(os.path.join(self.path, 'ref_sphere.npy')), + err_msg='Created sphere does not match expectation!') + + def test_shape_ellipsoid(self): + test_ellipsoid = mc.shapes.ellipsoid((7, 8, 9), (3.5, 4.5, 4.5), (3, 5, 7)) + assert_allclose(test_ellipsoid, np.load(os.path.join(self.path, 'ref_ellipsoid.npy')), + err_msg='Created ellipsoid does not match expectation!') + + def test_shape_filament(self): + test_filament_z = mc.shapes.filament((5, 6, 7), (2, 3), 'z') + test_filament_y = mc.shapes.filament((5, 6, 7), (2, 3), 'y') + test_filament_x = mc.shapes.filament((5, 6, 7), (2, 3), 'x') + assert_allclose(test_filament_z, np.load(os.path.join(self.path, 'ref_fil_z.npy')), + err_msg='Created filament in z-direction does not match expectation!') + assert_allclose(test_filament_y, np.load(os.path.join(self.path, 'ref_fil_y.npy')), + err_msg='Created filament in y-direction does not match expectation!') + assert_allclose(test_filament_x, np.load(os.path.join(self.path, 'ref_fil_x.npy')), + err_msg='Created filament in x-direction does not match expectation!') + + def test_shape_pixel(self): + test_pixel = mc.shapes.pixel((5, 6, 7), (2, 3, 4)) + assert_allclose(test_pixel, np.load(os.path.join(self.path, 'ref_pixel.npy')), + err_msg='Created pixel does not match expectation!') + + def test_create_mag_dist_homog(self): + mag_shape = mc.shapes.disc((1, 10, 10), (0, 5, 5), 3, 1) + magnitude = mc.create_mag_dist_homog(mag_shape, pi / 4) + assert_allclose(magnitude, np.load(os.path.join(self.path, 'ref_mag_disc.npy')), + err_msg='Created homog. magnetic distribution does not match expectation') + + def test_create_mag_dist_vortex(self): + mag_shape = mc.shapes.disc((1, 10, 10), (0, 5, 5), 3, 1) + magnitude = mc.create_mag_dist_vortex(mag_shape) + assert_allclose(magnitude, np.load(os.path.join(self.path, 'ref_mag_vort.npy')), + err_msg='Created vortex magnetic distribution does not match expectation') + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_phasemap.py b/pyramid/tests/test_phasemap.py index 5bc585f2f64ff78f7577613e18da5a46e807fe4c..2af8a33373ee3c3b7114475408e3bbf9c6881e62 100644 --- a/pyramid/tests/test_phasemap.py +++ b/pyramid/tests/test_phasemap.py @@ -1,79 +1,79 @@ -# -*- coding: utf-8 -*- -"""Testcase for the phasemap module.""" - -import os -import unittest - -import numpy as np -from numpy.testing import assert_allclose - -from pyramid.phasemap import PhaseMap -from pyramid import load_phasemap - - -class TestCasePhaseMap(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemap') - phase = np.zeros((4, 4)) - phase[1:-1, 1:-1] = 1 - mask = phase.astype(dtype=np.bool) - confidence = np.ones((4, 4)) - self.phasemap = PhaseMap(10.0, phase, mask, confidence) - - def tearDown(self): - self.path = None - self.phasemap = None - - def test_copy(self): - phasemap = self.phasemap - phasemap_copy = self.phasemap.copy() - assert phasemap == self.phasemap, 'Unexpected behaviour in copy()!' - assert phasemap_copy != self.phasemap, 'Unexpected behaviour in copy()!' - - def test_scale_down(self): - self.phasemap.scale_down() - reference = 1 / 4. * np.ones((2, 2)) - assert_allclose(self.phasemap.phase, reference, - err_msg='Unexpected behavior in scale_down()!') - assert_allclose(self.phasemap.mask, np.zeros((2, 2), dtype=np.bool), - err_msg='Unexpected behavior in scale_down()!') - assert_allclose(self.phasemap.confidence, np.ones((2, 2)), - err_msg='Unexpected behavior in scale_down()!') - assert_allclose(self.phasemap.a, 20, - err_msg='Unexpected behavior in scale_down()!') - - def test_scale_up(self): - self.phasemap.scale_up() - reference = np.zeros((8, 8)) - reference[2:-2, 2:-2] = 1 - assert_allclose(self.phasemap.phase, reference, - err_msg='Unexpected behavior in scale_up()!') - assert_allclose(self.phasemap.mask, reference.astype(dtype=np.bool), - err_msg='Unexpected behavior in scale_up()!') - assert_allclose(self.phasemap.confidence, np.ones((8, 8)), - err_msg='Unexpected behavior in scale_up()!') - assert_allclose(self.phasemap.a, 5, - err_msg='Unexpected behavior in scale_up()!') - - def test_load_from_txt(self): - phasemap = load_phasemap(os.path.join(self.path, 'ref_phasemap.txt')) - assert_allclose(self.phasemap.phase, phasemap.phase, - err_msg='Unexpected behavior in load_from_txt()!') - assert_allclose(phasemap.a, self.phasemap.a, - err_msg='Unexpected behavior in load_from_txt()!') - - def test_load_from_hdf5(self): - phasemap = load_phasemap(os.path.join(self.path, 'ref_phasemap.hdf5')) - assert_allclose(self.phasemap.phase, phasemap.phase, - err_msg='Unexpected behavior in load_from_netcdf4()!') - assert_allclose(self.phasemap.mask, phasemap.mask, - err_msg='Unexpected behavior in load_from_netcdf4()!') - assert_allclose(self.phasemap.confidence, phasemap.confidence, - err_msg='Unexpected behavior in load_from_netcdf4()!') - assert_allclose(phasemap.a, self.phasemap.a, - err_msg='Unexpected behavior in load_from_netcdf4()!') - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the phasemap module.""" + +import os +import unittest + +import numpy as np +from numpy.testing import assert_allclose + +from pyramid.phasemap import PhaseMap +from pyramid import load_phasemap + + +class TestCasePhaseMap(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemap') + phase = np.zeros((4, 4)) + phase[1:-1, 1:-1] = 1 + mask = phase.astype(dtype=np.bool) + confidence = np.ones((4, 4)) + self.phasemap = PhaseMap(10.0, phase, mask, confidence) + + def tearDown(self): + self.path = None + self.phasemap = None + + def test_copy(self): + phasemap = self.phasemap + phasemap_copy = self.phasemap.copy() + assert phasemap == self.phasemap, 'Unexpected behaviour in copy()!' + assert phasemap_copy != self.phasemap, 'Unexpected behaviour in copy()!' + + def test_scale_down(self): + self.phasemap.scale_down() + reference = 1 / 4. * np.ones((2, 2)) + assert_allclose(self.phasemap.phase, reference, + err_msg='Unexpected behavior in scale_down()!') + assert_allclose(self.phasemap.mask, np.zeros((2, 2), dtype=np.bool), + err_msg='Unexpected behavior in scale_down()!') + assert_allclose(self.phasemap.confidence, np.ones((2, 2)), + err_msg='Unexpected behavior in scale_down()!') + assert_allclose(self.phasemap.a, 20, + err_msg='Unexpected behavior in scale_down()!') + + def test_scale_up(self): + self.phasemap.scale_up() + reference = np.zeros((8, 8)) + reference[2:-2, 2:-2] = 1 + assert_allclose(self.phasemap.phase, reference, + err_msg='Unexpected behavior in scale_up()!') + assert_allclose(self.phasemap.mask, reference.astype(dtype=np.bool), + err_msg='Unexpected behavior in scale_up()!') + assert_allclose(self.phasemap.confidence, np.ones((8, 8)), + err_msg='Unexpected behavior in scale_up()!') + assert_allclose(self.phasemap.a, 5, + err_msg='Unexpected behavior in scale_up()!') + + def test_load_from_txt(self): + phasemap = load_phasemap(os.path.join(self.path, 'ref_phasemap.txt')) + assert_allclose(self.phasemap.phase, phasemap.phase, + err_msg='Unexpected behavior in load_from_txt()!') + assert_allclose(phasemap.a, self.phasemap.a, + err_msg='Unexpected behavior in load_from_txt()!') + + def test_load_from_hdf5(self): + phasemap = load_phasemap(os.path.join(self.path, 'ref_phasemap.hdf5')) + assert_allclose(self.phasemap.phase, phasemap.phase, + err_msg='Unexpected behavior in load_from_netcdf4()!') + assert_allclose(self.phasemap.mask, phasemap.mask, + err_msg='Unexpected behavior in load_from_netcdf4()!') + assert_allclose(self.phasemap.confidence, phasemap.confidence, + err_msg='Unexpected behavior in load_from_netcdf4()!') + assert_allclose(phasemap.a, self.phasemap.a, + err_msg='Unexpected behavior in load_from_netcdf4()!') + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_phasemap/ref_phasemap.txt b/pyramid/tests/test_phasemap/ref_phasemap.txt index 006147419362bac346528f8cf9594882ef3a341a..f8dbb9a4621fdac678f9a0d0ae14f773b2a378a2 100644 --- a/pyramid/tests/test_phasemap/ref_phasemap.txt +++ b/pyramid/tests/test_phasemap/ref_phasemap.txt @@ -1,6 +1,6 @@ -PYRAMID-PHASEMAP: ref_phase_map -grid spacing = 10.0 nm -0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 -0.000000e+00 1.000000e+00 1.000000e+00 0.000000e+00 -0.000000e+00 1.000000e+00 1.000000e+00 0.000000e+00 -0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 +PYRAMID-PHASEMAP: ref_phase_map +grid spacing = 10.0 nm +0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 +0.000000e+00 1.000000e+00 1.000000e+00 0.000000e+00 +0.000000e+00 1.000000e+00 1.000000e+00 0.000000e+00 +0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 diff --git a/pyramid/tests/test_phasemapper.py b/pyramid/tests/test_phasemapper.py index 69a4ffeef03ce56c9aa6c9b0395e0a7634073a8d..0f86250724c7aaf82f526a97bf441312faa5e3e9 100644 --- a/pyramid/tests/test_phasemapper.py +++ b/pyramid/tests/test_phasemapper.py @@ -1,182 +1,182 @@ -# -*- coding: utf-8 -*- -"""Testcase for the phasemapper module.""" - -import os -import unittest - -import numpy as np -from numpy.testing import assert_allclose - -from pyramid.kernel import Kernel -from pyramid.phasemapper import PhaseMapperRDFC, PhaseMapperFDFC, PhaseMapperMIP -from pyramid import load_phasemap, load_vectordata, load_scalardata - - -class TestCasePhaseMapperRDFC(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') - self.mag_proj = load_vectordata(os.path.join(self.path, 'mag_proj.hdf5')) - self.mapper = PhaseMapperRDFC(Kernel(self.mag_proj.a, self.mag_proj.dim[1:])) - - def tearDown(self): - self.path = None - self.mag_proj = None - self.mapper = None - - def test_PhaseMapperRDFC_call(self): - phase_ref = load_phasemap(os.path.join(self.path, 'phasemap.hdf5')) - phasemap = self.mapper(self.mag_proj) - assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, - err_msg='Unexpected behavior in __call__()!') - assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in __call__()!') - - def test_PhaseMapperRDFC_jac_dot(self): - phase = self.mapper(self.mag_proj).phase - mag_proj_vec = self.mag_proj.field[:2, ...].ravel() - phase_jac = self.mapper.jac_dot(mag_proj_vec).reshape(self.mapper.kernel.dim_uv) - assert_allclose(phase, phase_jac, atol=1E-7, - err_msg='Inconsistency between __call__() and jac_dot()!') - n = self.mapper.n - jac = np.array([self.mapper.jac_dot(np.eye(n)[:, i]) for i in range(n)]).T - jac_ref = np.load(os.path.join(self.path, 'jac.npy')) - assert_allclose(jac, jac_ref, atol=1E-7, - err_msg='Unexpected behaviour in the the jacobi matrix!') - - def test_PhaseMapperRDFC_jac_T_dot(self): - m = self.mapper.m - jac_T = np.array([self.mapper.jac_T_dot(np.eye(m)[:, i]) for i in range(m)]).T - jac_T_ref = np.load(os.path.join(self.path, 'jac.npy')).T - assert_allclose(jac_T, jac_T_ref, atol=1E-7, - err_msg='Unexpected behaviour in the the transposed jacobi matrix!') - - -class TestCasePhaseMapperFDFCpad0(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') - self.mag_proj = load_vectordata(os.path.join(self.path, 'mag_proj.hdf5')) - self.mapper = PhaseMapperFDFC(self.mag_proj.a, self.mag_proj.dim[1:], padding=0) - - def tearDown(self): - self.path = None - self.mag_proj = None - self.mapper = None - - def test_PhaseMapperFDFC_call(self): - phase_ref = load_phasemap(os.path.join(self.path, 'phasemap_fc.hdf5')) - phasemap = self.mapper(self.mag_proj) - assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, - err_msg='Unexpected behavior in __call__()!') - assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in __call__()!') - - def test_PhaseMapperFDFC_jac_dot(self): - phase = self.mapper(self.mag_proj).phase - mag_proj_vec = self.mag_proj.field[:2, ...].ravel() - phase_jac = self.mapper.jac_dot(mag_proj_vec).reshape(self.mapper.dim_uv) - assert_allclose(phase, phase_jac, atol=1E-7, - err_msg='Inconsistency between __call__() and jac_dot()!') - n = self.mapper.n - jac = np.array([self.mapper.jac_dot(np.eye(n)[:, i]) for i in range(n)]).T - jac_ref = np.load(os.path.join(self.path, 'jac_fc.npy')) - assert_allclose(jac, jac_ref, atol=1E-7, - err_msg='Unexpected behaviour in the the jacobi matrix!') - - def test_PhaseMapperFDFC_jac_T_dot(self): - self.assertRaises(NotImplementedError, self.mapper.jac_T_dot, None) - - -class TestCasePhaseMapperFDFCpad1(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') - self.mag_proj = load_vectordata(os.path.join(self.path, 'mag_proj.hdf5')) - self.mapper = PhaseMapperFDFC(self.mag_proj.a, self.mag_proj.dim[1:], padding=1) - - def tearDown(self): - self.path = None - self.mag_proj = None - self.mapper = None - - def test_PhaseMapperFDFC_call(self): - phase_ref = load_phasemap(os.path.join(self.path, 'phasemap_fc_pad1.hdf5')) - phasemap = self.mapper(self.mag_proj) - assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, - err_msg='Unexpected behavior in __call__()!') - assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in __call__()!') - - def test_PhaseMapperFDFC_jac_dot(self): - phase = self.mapper(self.mag_proj).phase - mag_proj_vec = self.mag_proj.field[:2, ...].ravel() - phase_jac = self.mapper.jac_dot(mag_proj_vec).reshape(self.mapper.dim_uv) - assert_allclose(phase, phase_jac, atol=1E-7, - err_msg='Inconsistency between __call__() and jac_dot()!') - n = self.mapper.n - jac = np.array([self.mapper.jac_dot(np.eye(n)[:, i]) for i in range(n)]).T - jac_ref = np.load(os.path.join(self.path, 'jac_fc_pad1.npy')) - assert_allclose(jac, jac_ref, atol=1E-7, - err_msg='Unexpected behaviour in the the jacobi matrix!') - - def test_PhaseMapperFDFC_jac_T_dot(self): - self.assertRaises(NotImplementedError, self.mapper.jac_T_dot, None) - - -class TestCasePhaseMapperFDFCpad10(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') - self.mag_proj = load_vectordata(os.path.join(self.path, 'mag_proj.hdf5')) - self.mapper = PhaseMapperFDFC(self.mag_proj.a, self.mag_proj.dim[1:], padding=200) - - def tearDown(self): - self.path = None - self.mag_proj = None - self.mapper = None - - def test_PhaseMapperFDFC_call(self): - phase_ref = load_phasemap(os.path.join(self.path, 'phasemap_fc_pad10.hdf5')) - phasemap = self.mapper(self.mag_proj) - assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, - err_msg='Unexpected behavior in __call__()!') - assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in __call__()!') - - def test_PhaseMapperFDFC_jac_dot(self): - phase = self.mapper(self.mag_proj).phase - mag_proj_vec = self.mag_proj.field[:2, ...].ravel() - phase_jac = self.mapper.jac_dot(mag_proj_vec).reshape(self.mapper.dim_uv) - assert_allclose(phase, phase_jac, atol=1E-7, - err_msg='Inconsistency between __call__() and jac_dot()!') - n = self.mapper.n - jac = np.array([self.mapper.jac_dot(np.eye(n)[:, i]) for i in range(n)]).T - jac_ref = np.load(os.path.join(self.path, 'jac_fc_pad10.npy')) - assert_allclose(jac, jac_ref, atol=1E-7, - err_msg='Unexpected behaviour in the the jacobi matrix!') - - def test_PhaseMapperFDFC_jac_T_dot(self): - self.assertRaises(NotImplementedError, self.mapper.jac_T_dot, None) - - -class TestCasePhaseMapperMIP(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') - self.elec_proj = load_scalardata(os.path.join(self.path, 'elec_proj.hdf5')) - self.mapper = PhaseMapperMIP(self.elec_proj.a, self.elec_proj.dim[1:]) - - def tearDown(self): - self.path = None - self.elec_proj = None - self.mapper = None - - def test_call(self): - phase_ref = load_phasemap(os.path.join(self.path, 'phasemap_elec.hdf5')) - phasemap = self.mapper(self.elec_proj) - assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, - err_msg='Unexpected behavior in __call__()!') - assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in __call__()!') - - def test_jac_dot(self): - self.assertRaises(NotImplementedError, self.mapper.jac_dot, None) - - def test_jac_T_dot(self): - self.assertRaises(NotImplementedError, self.mapper.jac_T_dot, None) - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the phasemapper module.""" + +import os +import unittest + +import numpy as np +from numpy.testing import assert_allclose + +from pyramid.kernel import Kernel +from pyramid.phasemapper import PhaseMapperRDFC, PhaseMapperFDFC, PhaseMapperMIP +from pyramid import load_phasemap, load_vectordata, load_scalardata + + +class TestCasePhaseMapperRDFC(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') + self.mag_proj = load_vectordata(os.path.join(self.path, 'mag_proj.hdf5')) + self.mapper = PhaseMapperRDFC(Kernel(self.mag_proj.a, self.mag_proj.dim[1:])) + + def tearDown(self): + self.path = None + self.mag_proj = None + self.mapper = None + + def test_PhaseMapperRDFC_call(self): + phase_ref = load_phasemap(os.path.join(self.path, 'phasemap.hdf5')) + phasemap = self.mapper(self.mag_proj) + assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, + err_msg='Unexpected behavior in __call__()!') + assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in __call__()!') + + def test_PhaseMapperRDFC_jac_dot(self): + phase = self.mapper(self.mag_proj).phase + mag_proj_vec = self.mag_proj.field[:2, ...].ravel() + phase_jac = self.mapper.jac_dot(mag_proj_vec).reshape(self.mapper.kernel.dim_uv) + assert_allclose(phase, phase_jac, atol=1E-7, + err_msg='Inconsistency between __call__() and jac_dot()!') + n = self.mapper.n + jac = np.array([self.mapper.jac_dot(np.eye(n)[:, i]) for i in range(n)]).T + jac_ref = np.load(os.path.join(self.path, 'jac.npy')) + assert_allclose(jac, jac_ref, atol=1E-7, + err_msg='Unexpected behaviour in the the jacobi matrix!') + + def test_PhaseMapperRDFC_jac_T_dot(self): + m = self.mapper.m + jac_T = np.array([self.mapper.jac_T_dot(np.eye(m)[:, i]) for i in range(m)]).T + jac_T_ref = np.load(os.path.join(self.path, 'jac.npy')).T + assert_allclose(jac_T, jac_T_ref, atol=1E-7, + err_msg='Unexpected behaviour in the the transposed jacobi matrix!') + + +class TestCasePhaseMapperFDFCpad0(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') + self.mag_proj = load_vectordata(os.path.join(self.path, 'mag_proj.hdf5')) + self.mapper = PhaseMapperFDFC(self.mag_proj.a, self.mag_proj.dim[1:], padding=0) + + def tearDown(self): + self.path = None + self.mag_proj = None + self.mapper = None + + def test_PhaseMapperFDFC_call(self): + phase_ref = load_phasemap(os.path.join(self.path, 'phasemap_fc.hdf5')) + phasemap = self.mapper(self.mag_proj) + assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, + err_msg='Unexpected behavior in __call__()!') + assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in __call__()!') + + def test_PhaseMapperFDFC_jac_dot(self): + phase = self.mapper(self.mag_proj).phase + mag_proj_vec = self.mag_proj.field[:2, ...].ravel() + phase_jac = self.mapper.jac_dot(mag_proj_vec).reshape(self.mapper.dim_uv) + assert_allclose(phase, phase_jac, atol=1E-7, + err_msg='Inconsistency between __call__() and jac_dot()!') + n = self.mapper.n + jac = np.array([self.mapper.jac_dot(np.eye(n)[:, i]) for i in range(n)]).T + jac_ref = np.load(os.path.join(self.path, 'jac_fc.npy')) + assert_allclose(jac, jac_ref, atol=1E-7, + err_msg='Unexpected behaviour in the the jacobi matrix!') + + def test_PhaseMapperFDFC_jac_T_dot(self): + self.assertRaises(NotImplementedError, self.mapper.jac_T_dot, None) + + +class TestCasePhaseMapperFDFCpad1(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') + self.mag_proj = load_vectordata(os.path.join(self.path, 'mag_proj.hdf5')) + self.mapper = PhaseMapperFDFC(self.mag_proj.a, self.mag_proj.dim[1:], padding=1) + + def tearDown(self): + self.path = None + self.mag_proj = None + self.mapper = None + + def test_PhaseMapperFDFC_call(self): + phase_ref = load_phasemap(os.path.join(self.path, 'phasemap_fc_pad1.hdf5')) + phasemap = self.mapper(self.mag_proj) + assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, + err_msg='Unexpected behavior in __call__()!') + assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in __call__()!') + + def test_PhaseMapperFDFC_jac_dot(self): + phase = self.mapper(self.mag_proj).phase + mag_proj_vec = self.mag_proj.field[:2, ...].ravel() + phase_jac = self.mapper.jac_dot(mag_proj_vec).reshape(self.mapper.dim_uv) + assert_allclose(phase, phase_jac, atol=1E-7, + err_msg='Inconsistency between __call__() and jac_dot()!') + n = self.mapper.n + jac = np.array([self.mapper.jac_dot(np.eye(n)[:, i]) for i in range(n)]).T + jac_ref = np.load(os.path.join(self.path, 'jac_fc_pad1.npy')) + assert_allclose(jac, jac_ref, atol=1E-7, + err_msg='Unexpected behaviour in the the jacobi matrix!') + + def test_PhaseMapperFDFC_jac_T_dot(self): + self.assertRaises(NotImplementedError, self.mapper.jac_T_dot, None) + + +class TestCasePhaseMapperFDFCpad10(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') + self.mag_proj = load_vectordata(os.path.join(self.path, 'mag_proj.hdf5')) + self.mapper = PhaseMapperFDFC(self.mag_proj.a, self.mag_proj.dim[1:], padding=200) + + def tearDown(self): + self.path = None + self.mag_proj = None + self.mapper = None + + def test_PhaseMapperFDFC_call(self): + phase_ref = load_phasemap(os.path.join(self.path, 'phasemap_fc_pad10.hdf5')) + phasemap = self.mapper(self.mag_proj) + assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, + err_msg='Unexpected behavior in __call__()!') + assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in __call__()!') + + def test_PhaseMapperFDFC_jac_dot(self): + phase = self.mapper(self.mag_proj).phase + mag_proj_vec = self.mag_proj.field[:2, ...].ravel() + phase_jac = self.mapper.jac_dot(mag_proj_vec).reshape(self.mapper.dim_uv) + assert_allclose(phase, phase_jac, atol=1E-7, + err_msg='Inconsistency between __call__() and jac_dot()!') + n = self.mapper.n + jac = np.array([self.mapper.jac_dot(np.eye(n)[:, i]) for i in range(n)]).T + jac_ref = np.load(os.path.join(self.path, 'jac_fc_pad10.npy')) + assert_allclose(jac, jac_ref, atol=1E-7, + err_msg='Unexpected behaviour in the the jacobi matrix!') + + def test_PhaseMapperFDFC_jac_T_dot(self): + self.assertRaises(NotImplementedError, self.mapper.jac_T_dot, None) + + +class TestCasePhaseMapperMIP(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') + self.elec_proj = load_scalardata(os.path.join(self.path, 'elec_proj.hdf5')) + self.mapper = PhaseMapperMIP(self.elec_proj.a, self.elec_proj.dim[1:]) + + def tearDown(self): + self.path = None + self.elec_proj = None + self.mapper = None + + def test_call(self): + phase_ref = load_phasemap(os.path.join(self.path, 'phasemap_elec.hdf5')) + phasemap = self.mapper(self.elec_proj) + assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, + err_msg='Unexpected behavior in __call__()!') + assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in __call__()!') + + def test_jac_dot(self): + self.assertRaises(NotImplementedError, self.mapper.jac_dot, None) + + def test_jac_T_dot(self): + self.assertRaises(NotImplementedError, self.mapper.jac_T_dot, None) + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_pm.py b/pyramid/tests/test_pm.py index 2bb8ec2a271a177e9559744aa163c40becc81034..061d5c54230b48a23f81fa75e10820efeaeccadb 100644 --- a/pyramid/tests/test_pm.py +++ b/pyramid/tests/test_pm.py @@ -1,33 +1,33 @@ -# -*- coding: utf-8 -*- -"""Testcase for the pm function.""" - -import os -import unittest - -from numpy.testing import assert_allclose - -from pyramid.utils import pm -from pyramid import load_phasemap, load_vectordata - - -class TestCasePM(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') - self.mag_proj = load_vectordata(os.path.join(self.path, 'mag_proj.hdf5')) - - def tearDown(self): - self.path = None - self.mag_proj = None - self.mapper = None - - def test_pm(self): - phase_ref = load_phasemap(os.path.join(self.path, 'phasemap.hdf5')) - phasemap = pm(self.mag_proj) - assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, - err_msg='Unexpected behavior in pm()!') - assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in pm()!') - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the pm function.""" + +import os +import unittest + +from numpy.testing import assert_allclose + +from pyramid.utils import pm +from pyramid import load_phasemap, load_vectordata + + +class TestCasePM(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') + self.mag_proj = load_vectordata(os.path.join(self.path, 'mag_proj.hdf5')) + + def tearDown(self): + self.path = None + self.mag_proj = None + self.mapper = None + + def test_pm(self): + phase_ref = load_phasemap(os.path.join(self.path, 'phasemap.hdf5')) + phasemap = pm(self.mag_proj) + assert_allclose(phasemap.phase, phase_ref.phase, atol=1E-7, + err_msg='Unexpected behavior in pm()!') + assert_allclose(phasemap.a, phase_ref.a, err_msg='Unexpected behavior in pm()!') + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_projector.py b/pyramid/tests/test_projector.py index c0dec264ad7a5f13dfd40a3a12045bb87a3edc6d..b87d447deeaecc832fc876cc767176a17c126ecd 100644 --- a/pyramid/tests/test_projector.py +++ b/pyramid/tests/test_projector.py @@ -1,262 +1,262 @@ -# -*- coding: utf-8 -*- -"""Testcase for the projector module.""" - -import os -import unittest - -import numpy as np -from numpy import pi -from numpy.testing import assert_allclose - -from pyramid.projector import XTiltProjector, YTiltProjector, SimpleProjector -from pyramid import load_vectordata - - -class TestCaseSimpleProjector(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_projector') - self.magdata = load_vectordata(os.path.join(self.path, 'ref_magdata.hdf5')) - self.proj_z = SimpleProjector(self.magdata.dim, axis='z') - self.proj_y = SimpleProjector(self.magdata.dim, axis='y') - self.proj_x = SimpleProjector(self.magdata.dim, axis='x') - - def tearDown(self): - self.path = None - self.magdata = None - self.proj_z = None - self.proj_y = None - self.proj_x = None - - def test_SimpleProjector_call(self): - mag_proj_z = self.proj_z(self.magdata) - mag_proj_y = self.proj_y(self.magdata) - mag_proj_x = self.proj_x(self.magdata) - mag_proj_z_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_z.hdf5')) - mag_proj_y_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_y.hdf5')) - mag_proj_x_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_x.hdf5')) - assert_allclose(mag_proj_z.field, mag_proj_z_ref.field, - err_msg='Unexpected behaviour in SimpleProjector (z-axis)') - assert_allclose(mag_proj_y.field, mag_proj_y_ref.field, - err_msg='Unexpected behaviour in SimpleProjector (y-axis)') - assert_allclose(mag_proj_x.field, mag_proj_x_ref.field, - err_msg='Unexpected behaviour in SimpleProjector (x-axis)') - - def test_SimpleProjector_jac_dot(self): - mag_vec = self.magdata.field_vec - mag_proj_z = self.proj_z.jac_dot(mag_vec).reshape((2,) + self.proj_z.dim_uv) - mag_proj_y = self.proj_y.jac_dot(mag_vec).reshape((2,) + self.proj_y.dim_uv) - mag_proj_x = self.proj_x.jac_dot(mag_vec).reshape((2,) + self.proj_x.dim_uv) - mag_proj_z_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_z.hdf5')) - mag_proj_y_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_y.hdf5')) - mag_proj_x_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_x.hdf5')) - assert_allclose(mag_proj_z, mag_proj_z_ref.field[:2, 0, ...], - err_msg='Inconsistency between __call__() and jac_dot()! (z-axis)') - assert_allclose(mag_proj_y, mag_proj_y_ref.field[:2, 0, ...], - err_msg='Inconsistency between __call__() and jac_dot()! (y-axis)') - assert_allclose(mag_proj_x, mag_proj_x_ref.field[:2, 0, ...], - err_msg='Inconsistency between __call__() and jac_dot()! (x-axis)') - nz = self.proj_z.n - ny = self.proj_y.n - nx = self.proj_x.n - jac_z = np.array([self.proj_z.jac_dot(np.eye(nz)[:, i]) for i in range(nz)]).T - jac_y = np.array([self.proj_y.jac_dot(np.eye(ny)[:, i]) for i in range(ny)]).T - jac_x = np.array([self.proj_x.jac_dot(np.eye(nx)[:, i]) for i in range(nx)]).T - jac_z_ref = np.load(os.path.join(self.path, 'jac_z.npy')) - jac_y_ref = np.load(os.path.join(self.path, 'jac_y.npy')) - jac_x_ref = np.load(os.path.join(self.path, 'jac_x.npy')) - assert_allclose(jac_z, jac_z_ref, - err_msg='Unexpected behaviour in the the jacobi matrix! (z-axis)') - assert_allclose(jac_y, jac_y_ref, - err_msg='Unexpected behaviour in the the jacobi matrix! (y-axis)') - assert_allclose(jac_x, jac_x_ref, - err_msg='Unexpected behaviour in the the jacobi matrix! (x-axis)') - - def test_SimpleProjector_jac_T_dot(self): - mz = self.proj_z.m - my = self.proj_y.m - mx = self.proj_x.m - jac_T_z = np.array([self.proj_z.jac_T_dot(np.eye(mz)[:, i]) for i in range(mz)]).T - jac_T_y = np.array([self.proj_y.jac_T_dot(np.eye(my)[:, i]) for i in range(my)]).T - jac_T_x = np.array([self.proj_x.jac_T_dot(np.eye(mx)[:, i]) for i in range(mx)]).T - jac_T_z_ref = np.load(os.path.join(self.path, 'jac_z.npy')).T - jac_T_y_ref = np.load(os.path.join(self.path, 'jac_y.npy')).T - jac_T_x_ref = np.load(os.path.join(self.path, 'jac_x.npy')).T - assert_allclose(jac_T_z, jac_T_z_ref, - err_msg='Unexpected behaviour in the the transp. jacobi matrix! (z-axis)') - assert_allclose(jac_T_y, jac_T_y_ref, - err_msg='Unexpected behaviour in the the transp. jacobi matrix! (y-axis)') - assert_allclose(jac_T_x, jac_T_x_ref, - err_msg='Unexpected behaviour in the the transp. jacobi matrix! (x-axis)') - - -class TestCaseXTiltProjector(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_projector') - self.magdata = load_vectordata(os.path.join(self.path, 'ref_magdata.hdf5')) - self.proj_00 = XTiltProjector(self.magdata.dim, tilt=0) - self.proj_45 = XTiltProjector(self.magdata.dim, tilt=pi / 4) - self.proj_90 = XTiltProjector(self.magdata.dim, tilt=pi / 2) - - def tearDown(self): - self.path = None - self.magdata = None - self.proj_00 = None - self.proj_45 = None - self.proj_90 = None - - def test_XTiltProjector_call(self): - mag_proj_00 = self.proj_00(self.magdata) - mag_proj_45 = self.proj_45(self.magdata) - mag_proj_90 = self.proj_90(self.magdata) - mag_proj_00_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_x00.hdf5')) - mag_proj_45_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_x45.hdf5')) - mag_proj_90_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_x90.hdf5')) - assert_allclose(mag_proj_00.field, mag_proj_00_ref.field, - err_msg='Unexpected behaviour in XTiltProjector (0°)') - assert_allclose(mag_proj_45.field, mag_proj_45_ref.field, - err_msg='Unexpected behaviour in XTiltProjector (45°)') - assert_allclose(mag_proj_90.field, mag_proj_90_ref.field, - err_msg='Unexpected behaviour in XTiltProjector (90°)') - - def test_XTiltProjector_jac_dot(self): - mag_vec = self.magdata.field_vec - mag_proj_00 = self.proj_00.jac_dot(mag_vec).reshape((2,) + self.proj_00.dim_uv) - mag_proj_45 = self.proj_45.jac_dot(mag_vec).reshape((2,) + self.proj_45.dim_uv) - mag_proj_90 = self.proj_90.jac_dot(mag_vec).reshape((2,) + self.proj_90.dim_uv) - mag_proj_00_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_x00.hdf5')) - mag_proj_45_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_x45.hdf5')) - mag_proj_90_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_x90.hdf5')) - assert_allclose(mag_proj_00, mag_proj_00_ref.field[:2, 0, ...], - err_msg='Inconsistency between __call__() and jac_dot()! (0°)') - assert_allclose(mag_proj_45, mag_proj_45_ref.field[:2, 0, ...], - err_msg='Inconsistency between __call__() and jac_dot()! (45°)') - assert_allclose(mag_proj_90, mag_proj_90_ref.field[:2, 0, ...], - err_msg='Inconsistency between __call__() and jac_dot()! (90°)') - n00 = self.proj_00.n - n45 = self.proj_45.n - n90 = self.proj_90.n - jac_00 = np.array([self.proj_00.jac_dot(np.eye(n00)[:, i]) for i in range(n00)]).T - jac_45 = np.array([self.proj_45.jac_dot(np.eye(n45)[:, i]) for i in range(n45)]).T - jac_90 = np.array([self.proj_90.jac_dot(np.eye(n90)[:, i]) for i in range(n90)]).T - jac_00_ref = np.load(os.path.join(self.path, 'jac_x00.npy')) - jac_45_ref = np.load(os.path.join(self.path, 'jac_x45.npy')) - jac_90_ref = np.load(os.path.join(self.path, 'jac_x90.npy')) - assert_allclose(jac_00, jac_00_ref, - err_msg='Unexpected behaviour in the the jacobi matrix! (0°)') - assert_allclose(jac_45, jac_45_ref, - err_msg='Unexpected behaviour in the the jacobi matrix! (45°)') - assert_allclose(jac_90, jac_90_ref, - err_msg='Unexpected behaviour in the the jacobi matrix! (90°)') - - def test_XTiltProjector_jac_T_dot(self): - m00 = self.proj_00.m - m45 = self.proj_45.m - m90 = self.proj_90.m - jac_T_00 = np.array([self.proj_00.jac_T_dot(np.eye(m00)[:, i]) for i in range(m00)]).T - jac_T_45 = np.array([self.proj_45.jac_T_dot(np.eye(m45)[:, i]) for i in range(m45)]).T - jac_T_90 = np.array([self.proj_90.jac_T_dot(np.eye(m90)[:, i]) for i in range(m90)]).T - jac_T_00_ref = np.load(os.path.join(self.path, 'jac_x00.npy')).T - jac_T_45_ref = np.load(os.path.join(self.path, 'jac_x45.npy')).T - jac_T_90_ref = np.load(os.path.join(self.path, 'jac_x90.npy')).T - assert_allclose(jac_T_00, jac_T_00_ref, - err_msg='Unexpected behaviour in the the transp. jacobi matrix! (0°)') - assert_allclose(jac_T_45, jac_T_45_ref, - err_msg='Unexpected behaviour in the the transp. jacobi matrix! (45°)') - assert_allclose(jac_T_90, jac_T_90_ref, - err_msg='Unexpected behaviour in the the transp. jacobi matrix! (90°)') - - -class TestCaseYTiltProjector(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_projector') - self.magdata = load_vectordata(os.path.join(self.path, 'ref_magdata.hdf5')) - self.proj_00 = YTiltProjector(self.magdata.dim, tilt=0) - self.proj_45 = YTiltProjector(self.magdata.dim, tilt=pi / 4) - self.proj_90 = YTiltProjector(self.magdata.dim, tilt=pi / 2) - - def tearDown(self): - self.path = None - self.magdata = None - self.proj_00 = None - self.proj_45 = None - self.proj_90 = None - - def test_XTiltProjector_call(self): - mag_proj_00 = self.proj_00(self.magdata) - mag_proj_45 = self.proj_45(self.magdata) - mag_proj_90 = self.proj_90(self.magdata) - mag_proj_00_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_y00.hdf5')) - mag_proj_45_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_y45.hdf5')) - mag_proj_90_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_y90.hdf5')) - assert_allclose(mag_proj_00.field, mag_proj_00_ref.field, - err_msg='Unexpected behaviour in XTiltProjector (0°)') - assert_allclose(mag_proj_45.field, mag_proj_45_ref.field, - err_msg='Unexpected behaviour in XTiltProjector (45°)') - assert_allclose(mag_proj_90.field, mag_proj_90_ref.field, - err_msg='Unexpected behaviour in XTiltProjector (90°)') - - def test_XTiltProjector_jac_dot(self): - mag_vec = self.magdata.field_vec - mag_proj_00 = self.proj_00.jac_dot(mag_vec).reshape((2,) + self.proj_00.dim_uv) - mag_proj_45 = self.proj_45.jac_dot(mag_vec).reshape((2,) + self.proj_45.dim_uv) - mag_proj_90 = self.proj_90.jac_dot(mag_vec).reshape((2,) + self.proj_90.dim_uv) - mag_proj_00_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_y00.hdf5')) - mag_proj_45_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_y45.hdf5')) - mag_proj_90_ref = load_vectordata( - os.path.join(self.path, 'ref_mag_proj_y90.hdf5')) - assert_allclose(mag_proj_00, mag_proj_00_ref.field[:2, 0, ...], - err_msg='Inconsistency between __call__() and jac_dot()! (0°)') - assert_allclose(mag_proj_45, mag_proj_45_ref.field[:2, 0, ...], - err_msg='Inconsistency between __call__() and jac_dot()! (45°)') - assert_allclose(mag_proj_90, mag_proj_90_ref.field[:2, 0, ...], - err_msg='Inconsistency between __call__() and jac_dot()! (90°)') - n00 = self.proj_00.n - n45 = self.proj_45.n - n90 = self.proj_90.n - jac_00 = np.array([self.proj_00.jac_dot(np.eye(n00)[:, i]) for i in range(n00)]).T - jac_45 = np.array([self.proj_45.jac_dot(np.eye(n45)[:, i]) for i in range(n45)]).T - jac_90 = np.array([self.proj_90.jac_dot(np.eye(n90)[:, i]) for i in range(n90)]).T - jac_00_ref = np.load(os.path.join(self.path, 'jac_y00.npy')) - jac_45_ref = np.load(os.path.join(self.path, 'jac_y45.npy')) - jac_90_ref = np.load(os.path.join(self.path, 'jac_y90.npy')) - assert_allclose(jac_00, jac_00_ref, - err_msg='Unexpected behaviour in the the jacobi matrix! (0°)') - assert_allclose(jac_45, jac_45_ref, - err_msg='Unexpected behaviour in the the jacobi matrix! (45°)') - assert_allclose(jac_90, jac_90_ref, - err_msg='Unexpected behaviour in the the jacobi matrix! (90°)') - - def test_YTiltProjector_jac_T_dot(self): - m00 = self.proj_00.m - m45 = self.proj_45.m - m90 = self.proj_90.m - jac_T_00 = np.array([self.proj_00.jac_T_dot(np.eye(m00)[:, i]) for i in range(m00)]).T - jac_T_45 = np.array([self.proj_45.jac_T_dot(np.eye(m45)[:, i]) for i in range(m45)]).T - jac_T_90 = np.array([self.proj_90.jac_T_dot(np.eye(m90)[:, i]) for i in range(m90)]).T - jac_T_00_ref = np.load(os.path.join(self.path, 'jac_y00.npy')).T - jac_T_45_ref = np.load(os.path.join(self.path, 'jac_y45.npy')).T - jac_T_90_ref = np.load(os.path.join(self.path, 'jac_y90.npy')).T - assert_allclose(jac_T_00, jac_T_00_ref, - err_msg='Unexpected behaviour in the the transp. jacobi matrix! (0°)') - assert_allclose(jac_T_45, jac_T_45_ref, - err_msg='Unexpected behaviour in the the transp. jacobi matrix! (45°)') - assert_allclose(jac_T_90, jac_T_90_ref, - err_msg='Unexpected behaviour in the the transp. jacobi matrix! (90°)') - - -# TODO: Test RotTiltProjector!!! - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the projector module.""" + +import os +import unittest + +import numpy as np +from numpy import pi +from numpy.testing import assert_allclose + +from pyramid.projector import XTiltProjector, YTiltProjector, SimpleProjector +from pyramid import load_vectordata + + +class TestCaseSimpleProjector(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_projector') + self.magdata = load_vectordata(os.path.join(self.path, 'ref_magdata.hdf5')) + self.proj_z = SimpleProjector(self.magdata.dim, axis='z') + self.proj_y = SimpleProjector(self.magdata.dim, axis='y') + self.proj_x = SimpleProjector(self.magdata.dim, axis='x') + + def tearDown(self): + self.path = None + self.magdata = None + self.proj_z = None + self.proj_y = None + self.proj_x = None + + def test_SimpleProjector_call(self): + mag_proj_z = self.proj_z(self.magdata) + mag_proj_y = self.proj_y(self.magdata) + mag_proj_x = self.proj_x(self.magdata) + mag_proj_z_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_z.hdf5')) + mag_proj_y_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_y.hdf5')) + mag_proj_x_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_x.hdf5')) + assert_allclose(mag_proj_z.field, mag_proj_z_ref.field, + err_msg='Unexpected behaviour in SimpleProjector (z-axis)') + assert_allclose(mag_proj_y.field, mag_proj_y_ref.field, + err_msg='Unexpected behaviour in SimpleProjector (y-axis)') + assert_allclose(mag_proj_x.field, mag_proj_x_ref.field, + err_msg='Unexpected behaviour in SimpleProjector (x-axis)') + + def test_SimpleProjector_jac_dot(self): + mag_vec = self.magdata.field_vec + mag_proj_z = self.proj_z.jac_dot(mag_vec).reshape((2,) + self.proj_z.dim_uv) + mag_proj_y = self.proj_y.jac_dot(mag_vec).reshape((2,) + self.proj_y.dim_uv) + mag_proj_x = self.proj_x.jac_dot(mag_vec).reshape((2,) + self.proj_x.dim_uv) + mag_proj_z_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_z.hdf5')) + mag_proj_y_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_y.hdf5')) + mag_proj_x_ref = load_vectordata(os.path.join(self.path, 'ref_mag_proj_x.hdf5')) + assert_allclose(mag_proj_z, mag_proj_z_ref.field[:2, 0, ...], + err_msg='Inconsistency between __call__() and jac_dot()! (z-axis)') + assert_allclose(mag_proj_y, mag_proj_y_ref.field[:2, 0, ...], + err_msg='Inconsistency between __call__() and jac_dot()! (y-axis)') + assert_allclose(mag_proj_x, mag_proj_x_ref.field[:2, 0, ...], + err_msg='Inconsistency between __call__() and jac_dot()! (x-axis)') + nz = self.proj_z.n + ny = self.proj_y.n + nx = self.proj_x.n + jac_z = np.array([self.proj_z.jac_dot(np.eye(nz)[:, i]) for i in range(nz)]).T + jac_y = np.array([self.proj_y.jac_dot(np.eye(ny)[:, i]) for i in range(ny)]).T + jac_x = np.array([self.proj_x.jac_dot(np.eye(nx)[:, i]) for i in range(nx)]).T + jac_z_ref = np.load(os.path.join(self.path, 'jac_z.npy')) + jac_y_ref = np.load(os.path.join(self.path, 'jac_y.npy')) + jac_x_ref = np.load(os.path.join(self.path, 'jac_x.npy')) + assert_allclose(jac_z, jac_z_ref, + err_msg='Unexpected behaviour in the the jacobi matrix! (z-axis)') + assert_allclose(jac_y, jac_y_ref, + err_msg='Unexpected behaviour in the the jacobi matrix! (y-axis)') + assert_allclose(jac_x, jac_x_ref, + err_msg='Unexpected behaviour in the the jacobi matrix! (x-axis)') + + def test_SimpleProjector_jac_T_dot(self): + mz = self.proj_z.m + my = self.proj_y.m + mx = self.proj_x.m + jac_T_z = np.array([self.proj_z.jac_T_dot(np.eye(mz)[:, i]) for i in range(mz)]).T + jac_T_y = np.array([self.proj_y.jac_T_dot(np.eye(my)[:, i]) for i in range(my)]).T + jac_T_x = np.array([self.proj_x.jac_T_dot(np.eye(mx)[:, i]) for i in range(mx)]).T + jac_T_z_ref = np.load(os.path.join(self.path, 'jac_z.npy')).T + jac_T_y_ref = np.load(os.path.join(self.path, 'jac_y.npy')).T + jac_T_x_ref = np.load(os.path.join(self.path, 'jac_x.npy')).T + assert_allclose(jac_T_z, jac_T_z_ref, + err_msg='Unexpected behaviour in the the transp. jacobi matrix! (z-axis)') + assert_allclose(jac_T_y, jac_T_y_ref, + err_msg='Unexpected behaviour in the the transp. jacobi matrix! (y-axis)') + assert_allclose(jac_T_x, jac_T_x_ref, + err_msg='Unexpected behaviour in the the transp. jacobi matrix! (x-axis)') + + +class TestCaseXTiltProjector(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_projector') + self.magdata = load_vectordata(os.path.join(self.path, 'ref_magdata.hdf5')) + self.proj_00 = XTiltProjector(self.magdata.dim, tilt=0) + self.proj_45 = XTiltProjector(self.magdata.dim, tilt=pi / 4) + self.proj_90 = XTiltProjector(self.magdata.dim, tilt=pi / 2) + + def tearDown(self): + self.path = None + self.magdata = None + self.proj_00 = None + self.proj_45 = None + self.proj_90 = None + + def test_XTiltProjector_call(self): + mag_proj_00 = self.proj_00(self.magdata) + mag_proj_45 = self.proj_45(self.magdata) + mag_proj_90 = self.proj_90(self.magdata) + mag_proj_00_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_x00.hdf5')) + mag_proj_45_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_x45.hdf5')) + mag_proj_90_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_x90.hdf5')) + assert_allclose(mag_proj_00.field, mag_proj_00_ref.field, + err_msg='Unexpected behaviour in XTiltProjector (0°)') + assert_allclose(mag_proj_45.field, mag_proj_45_ref.field, + err_msg='Unexpected behaviour in XTiltProjector (45°)') + assert_allclose(mag_proj_90.field, mag_proj_90_ref.field, + err_msg='Unexpected behaviour in XTiltProjector (90°)') + + def test_XTiltProjector_jac_dot(self): + mag_vec = self.magdata.field_vec + mag_proj_00 = self.proj_00.jac_dot(mag_vec).reshape((2,) + self.proj_00.dim_uv) + mag_proj_45 = self.proj_45.jac_dot(mag_vec).reshape((2,) + self.proj_45.dim_uv) + mag_proj_90 = self.proj_90.jac_dot(mag_vec).reshape((2,) + self.proj_90.dim_uv) + mag_proj_00_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_x00.hdf5')) + mag_proj_45_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_x45.hdf5')) + mag_proj_90_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_x90.hdf5')) + assert_allclose(mag_proj_00, mag_proj_00_ref.field[:2, 0, ...], + err_msg='Inconsistency between __call__() and jac_dot()! (0°)') + assert_allclose(mag_proj_45, mag_proj_45_ref.field[:2, 0, ...], + err_msg='Inconsistency between __call__() and jac_dot()! (45°)') + assert_allclose(mag_proj_90, mag_proj_90_ref.field[:2, 0, ...], + err_msg='Inconsistency between __call__() and jac_dot()! (90°)') + n00 = self.proj_00.n + n45 = self.proj_45.n + n90 = self.proj_90.n + jac_00 = np.array([self.proj_00.jac_dot(np.eye(n00)[:, i]) for i in range(n00)]).T + jac_45 = np.array([self.proj_45.jac_dot(np.eye(n45)[:, i]) for i in range(n45)]).T + jac_90 = np.array([self.proj_90.jac_dot(np.eye(n90)[:, i]) for i in range(n90)]).T + jac_00_ref = np.load(os.path.join(self.path, 'jac_x00.npy')) + jac_45_ref = np.load(os.path.join(self.path, 'jac_x45.npy')) + jac_90_ref = np.load(os.path.join(self.path, 'jac_x90.npy')) + assert_allclose(jac_00, jac_00_ref, + err_msg='Unexpected behaviour in the the jacobi matrix! (0°)') + assert_allclose(jac_45, jac_45_ref, + err_msg='Unexpected behaviour in the the jacobi matrix! (45°)') + assert_allclose(jac_90, jac_90_ref, + err_msg='Unexpected behaviour in the the jacobi matrix! (90°)') + + def test_XTiltProjector_jac_T_dot(self): + m00 = self.proj_00.m + m45 = self.proj_45.m + m90 = self.proj_90.m + jac_T_00 = np.array([self.proj_00.jac_T_dot(np.eye(m00)[:, i]) for i in range(m00)]).T + jac_T_45 = np.array([self.proj_45.jac_T_dot(np.eye(m45)[:, i]) for i in range(m45)]).T + jac_T_90 = np.array([self.proj_90.jac_T_dot(np.eye(m90)[:, i]) for i in range(m90)]).T + jac_T_00_ref = np.load(os.path.join(self.path, 'jac_x00.npy')).T + jac_T_45_ref = np.load(os.path.join(self.path, 'jac_x45.npy')).T + jac_T_90_ref = np.load(os.path.join(self.path, 'jac_x90.npy')).T + assert_allclose(jac_T_00, jac_T_00_ref, + err_msg='Unexpected behaviour in the the transp. jacobi matrix! (0°)') + assert_allclose(jac_T_45, jac_T_45_ref, + err_msg='Unexpected behaviour in the the transp. jacobi matrix! (45°)') + assert_allclose(jac_T_90, jac_T_90_ref, + err_msg='Unexpected behaviour in the the transp. jacobi matrix! (90°)') + + +class TestCaseYTiltProjector(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_projector') + self.magdata = load_vectordata(os.path.join(self.path, 'ref_magdata.hdf5')) + self.proj_00 = YTiltProjector(self.magdata.dim, tilt=0) + self.proj_45 = YTiltProjector(self.magdata.dim, tilt=pi / 4) + self.proj_90 = YTiltProjector(self.magdata.dim, tilt=pi / 2) + + def tearDown(self): + self.path = None + self.magdata = None + self.proj_00 = None + self.proj_45 = None + self.proj_90 = None + + def test_XTiltProjector_call(self): + mag_proj_00 = self.proj_00(self.magdata) + mag_proj_45 = self.proj_45(self.magdata) + mag_proj_90 = self.proj_90(self.magdata) + mag_proj_00_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_y00.hdf5')) + mag_proj_45_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_y45.hdf5')) + mag_proj_90_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_y90.hdf5')) + assert_allclose(mag_proj_00.field, mag_proj_00_ref.field, + err_msg='Unexpected behaviour in XTiltProjector (0°)') + assert_allclose(mag_proj_45.field, mag_proj_45_ref.field, + err_msg='Unexpected behaviour in XTiltProjector (45°)') + assert_allclose(mag_proj_90.field, mag_proj_90_ref.field, + err_msg='Unexpected behaviour in XTiltProjector (90°)') + + def test_XTiltProjector_jac_dot(self): + mag_vec = self.magdata.field_vec + mag_proj_00 = self.proj_00.jac_dot(mag_vec).reshape((2,) + self.proj_00.dim_uv) + mag_proj_45 = self.proj_45.jac_dot(mag_vec).reshape((2,) + self.proj_45.dim_uv) + mag_proj_90 = self.proj_90.jac_dot(mag_vec).reshape((2,) + self.proj_90.dim_uv) + mag_proj_00_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_y00.hdf5')) + mag_proj_45_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_y45.hdf5')) + mag_proj_90_ref = load_vectordata( + os.path.join(self.path, 'ref_mag_proj_y90.hdf5')) + assert_allclose(mag_proj_00, mag_proj_00_ref.field[:2, 0, ...], + err_msg='Inconsistency between __call__() and jac_dot()! (0°)') + assert_allclose(mag_proj_45, mag_proj_45_ref.field[:2, 0, ...], + err_msg='Inconsistency between __call__() and jac_dot()! (45°)') + assert_allclose(mag_proj_90, mag_proj_90_ref.field[:2, 0, ...], + err_msg='Inconsistency between __call__() and jac_dot()! (90°)') + n00 = self.proj_00.n + n45 = self.proj_45.n + n90 = self.proj_90.n + jac_00 = np.array([self.proj_00.jac_dot(np.eye(n00)[:, i]) for i in range(n00)]).T + jac_45 = np.array([self.proj_45.jac_dot(np.eye(n45)[:, i]) for i in range(n45)]).T + jac_90 = np.array([self.proj_90.jac_dot(np.eye(n90)[:, i]) for i in range(n90)]).T + jac_00_ref = np.load(os.path.join(self.path, 'jac_y00.npy')) + jac_45_ref = np.load(os.path.join(self.path, 'jac_y45.npy')) + jac_90_ref = np.load(os.path.join(self.path, 'jac_y90.npy')) + assert_allclose(jac_00, jac_00_ref, + err_msg='Unexpected behaviour in the the jacobi matrix! (0°)') + assert_allclose(jac_45, jac_45_ref, + err_msg='Unexpected behaviour in the the jacobi matrix! (45°)') + assert_allclose(jac_90, jac_90_ref, + err_msg='Unexpected behaviour in the the jacobi matrix! (90°)') + + def test_YTiltProjector_jac_T_dot(self): + m00 = self.proj_00.m + m45 = self.proj_45.m + m90 = self.proj_90.m + jac_T_00 = np.array([self.proj_00.jac_T_dot(np.eye(m00)[:, i]) for i in range(m00)]).T + jac_T_45 = np.array([self.proj_45.jac_T_dot(np.eye(m45)[:, i]) for i in range(m45)]).T + jac_T_90 = np.array([self.proj_90.jac_T_dot(np.eye(m90)[:, i]) for i in range(m90)]).T + jac_T_00_ref = np.load(os.path.join(self.path, 'jac_y00.npy')).T + jac_T_45_ref = np.load(os.path.join(self.path, 'jac_y45.npy')).T + jac_T_90_ref = np.load(os.path.join(self.path, 'jac_y90.npy')).T + assert_allclose(jac_T_00, jac_T_00_ref, + err_msg='Unexpected behaviour in the the transp. jacobi matrix! (0°)') + assert_allclose(jac_T_45, jac_T_45_ref, + err_msg='Unexpected behaviour in the the transp. jacobi matrix! (45°)') + assert_allclose(jac_T_90, jac_T_90_ref, + err_msg='Unexpected behaviour in the the transp. jacobi matrix! (90°)') + + +# TODO: Test RotTiltProjector!!! + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_regularisator.py b/pyramid/tests/test_regularisator.py index 443bbcbf91d7c8d79fae0e644c5f283edc81429f..3ed52ea50dd467f75ab2e0f1322299bca6fdbc22 100644 --- a/pyramid/tests/test_regularisator.py +++ b/pyramid/tests/test_regularisator.py @@ -1,135 +1,135 @@ -# -*- coding: utf-8 -*- -"""Testcase for the regularisator module""" - -import os -import unittest - -import numpy as np -from numpy.testing import assert_allclose - -from pyramid.regularisator import FirstOrderRegularisator -from pyramid.regularisator import NoneRegularisator -from pyramid.regularisator import ZeroOrderRegularisator - - -class TestCaseNoneRegularisator(unittest.TestCase): - def setUp(self): - self.n = 9 - self.reg = NoneRegularisator() - - def tearDown(self): - self.n = None - self.reg = None - - def test_call(self): - assert_allclose(self.reg(np.arange(self.n)), 0, - err_msg='Unexpected behaviour in __call__()!') - - def test_jac(self): - assert_allclose(self.reg.jac(np.arange(self.n)), np.zeros(self.n), - err_msg='Unexpected behaviour in jac()!') - - def test_hess_dot(self): - assert_allclose(self.reg.hess_dot(None, np.arange(self.n)), np.zeros(self.n), - err_msg='Unexpected behaviour in jac()!') - - def test_hess_diag(self): - assert_allclose(self.reg.hess_diag(np.arange(self.n)), np.zeros(self.n), - err_msg='Unexpected behaviour in hess_diag()!') - - -class TestCaseZeroOrderRegularisator(unittest.TestCase): - def setUp(self): - self.n = 9 - self.lam = 1 - self.reg = ZeroOrderRegularisator(lam=self.lam) - - def tearDown(self): - self.n = None - self.lam = None - self.reg = None - - def test_call(self): - assert_allclose(self.reg(np.arange(self.n)), np.sum(np.arange(self.n) ** 2), - err_msg='Unexpected behaviour in __call__()!') - - def test_jac(self): - assert_allclose(self.reg.jac(np.arange(self.n)), 2 * np.arange(self.n), - err_msg='Unexpected behaviour in jac()!') - jac = np.array([self.reg.jac(np.eye(self.n)[:, i]) for i in range(self.n)]).T - assert_allclose(jac, 2 * np.eye(self.n), err_msg='Unexpected behaviour in jac()!') - - def test_hess_dot(self): - assert_allclose(self.reg.hess_dot(None, np.arange(self.n)), 2 * np.arange(self.n), - err_msg='Unexpected behaviour in jac()!') - hess = np.array([self.reg.hess_dot(None, np.eye(self.n)[:, i]) for i in range(self.n)]).T - assert_allclose(hess, 2 * np.eye(self.n), err_msg='Unexpected behaviour in hess_dot()!') - - def test_hess_diag(self): - assert_allclose(self.reg.hess_diag(np.arange(self.n)), 2 * np.ones(self.n), - err_msg='Unexpected behaviour in hess_diag()!') - - -class TestCaseFirstOrderRegularisator(unittest.TestCase): - def setUp(self): - self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_regularisator') - self.dim = (4, 5, 6) - self.mask = np.zeros(self.dim, dtype=bool) - self.mask[1:-1, 1:-1, 1:-1] = True - self.n = 3 * self.mask.sum() - self.lam = 1. - self.reg = FirstOrderRegularisator(self.mask, lam=self.lam) - - def tearDown(self): - self.path = None - self.dim = None - self.mask = None - self.n = None - self.lam = None - self.reg = None - - def test_call(self): - assert_allclose(self.reg(np.ones(self.n)), 0., - err_msg='Unexpected behaviour in __call__()!') - assert_allclose(self.reg(np.zeros(self.n)), 0., - err_msg='Unexpected behaviour in __call__()!') - cost_ref = np.load(os.path.join(self.path, 'first_order_cost_ref.npy')) - assert_allclose(self.reg(np.arange(self.n)), cost_ref, - err_msg='Unexpected behaviour in __call__()!') - - def test_jac(self): - assert_allclose(self.reg.jac(np.ones(self.n)), np.zeros(self.n), - err_msg='Unexpected behaviour in jac()!') - assert_allclose(self.reg.jac(np.zeros(self.n)), np.zeros(self.n), - err_msg='Unexpected behaviour in jac()!') - jac_vec_ref = np.load(os.path.join(self.path, 'first_order_jac_vec_ref.npy')) - assert_allclose(self.reg.jac(np.arange(self.n)), jac_vec_ref, atol=1E-7, - err_msg='Unexpected behaviour in jac()!') - jac = np.array([self.reg.jac(np.eye(self.n)[:, i]) for i in range(self.n)]).T - jac_ref = np.load(os.path.join(self.path, 'first_order_jac_ref.npy')) - assert_allclose(jac, jac_ref, atol=1E-7, - err_msg='Unexpected behaviour in jac()!') - - def test_hess_dot(self): - assert_allclose(self.reg.hess_dot(None, np.ones(self.n)), np.zeros(self.n), - err_msg='Unexpected behaviour in hess_dot()!') - assert_allclose(self.reg.hess_dot(None, np.zeros(self.n)), np.zeros(self.n), - err_msg='Unexpected behaviour in hess_dot()!') - hess_vec_ref = np.load(os.path.join(self.path, 'first_order_jac_vec_ref.npy')) - assert_allclose(self.reg.hess_dot(None, np.arange(self.n)), hess_vec_ref, atol=1E-7, - err_msg='Unexpected behaviour in hess_dot()!') - hess = np.array([self.reg.hess_dot(None, np.eye(self.n)[:, i]) for i in range(self.n)]).T - hess_ref = np.load(os.path.join(self.path, 'first_order_jac_ref.npy')) - assert_allclose(hess, hess_ref, atol=1E-7, - err_msg='Unexpected behaviour in hess_dot()!') - - def test_hess_diag(self): - hess_diag = self.reg.hess_diag(np.ones(self.n)) - hess_diag_ref = np.diag(np.load(os.path.join(self.path, 'first_order_jac_ref.npy'))) - assert_allclose(hess_diag, hess_diag_ref, atol=1E-7, - err_msg='Unexpected behaviour in hess_diag()!') - - -if __name__ == '__main__': - import nose - nose.run(defaultTest=__name__) +# -*- coding: utf-8 -*- +"""Testcase for the regularisator module""" + +import os +import unittest + +import numpy as np +from numpy.testing import assert_allclose + +from pyramid.regularisator import FirstOrderRegularisator +from pyramid.regularisator import NoneRegularisator +from pyramid.regularisator import ZeroOrderRegularisator + + +class TestCaseNoneRegularisator(unittest.TestCase): + def setUp(self): + self.n = 9 + self.reg = NoneRegularisator() + + def tearDown(self): + self.n = None + self.reg = None + + def test_call(self): + assert_allclose(self.reg(np.arange(self.n)), 0, + err_msg='Unexpected behaviour in __call__()!') + + def test_jac(self): + assert_allclose(self.reg.jac(np.arange(self.n)), np.zeros(self.n), + err_msg='Unexpected behaviour in jac()!') + + def test_hess_dot(self): + assert_allclose(self.reg.hess_dot(None, np.arange(self.n)), np.zeros(self.n), + err_msg='Unexpected behaviour in jac()!') + + def test_hess_diag(self): + assert_allclose(self.reg.hess_diag(np.arange(self.n)), np.zeros(self.n), + err_msg='Unexpected behaviour in hess_diag()!') + + +class TestCaseZeroOrderRegularisator(unittest.TestCase): + def setUp(self): + self.n = 9 + self.lam = 1 + self.reg = ZeroOrderRegularisator(lam=self.lam) + + def tearDown(self): + self.n = None + self.lam = None + self.reg = None + + def test_call(self): + assert_allclose(self.reg(np.arange(self.n)), np.sum(np.arange(self.n) ** 2), + err_msg='Unexpected behaviour in __call__()!') + + def test_jac(self): + assert_allclose(self.reg.jac(np.arange(self.n)), 2 * np.arange(self.n), + err_msg='Unexpected behaviour in jac()!') + jac = np.array([self.reg.jac(np.eye(self.n)[:, i]) for i in range(self.n)]).T + assert_allclose(jac, 2 * np.eye(self.n), err_msg='Unexpected behaviour in jac()!') + + def test_hess_dot(self): + assert_allclose(self.reg.hess_dot(None, np.arange(self.n)), 2 * np.arange(self.n), + err_msg='Unexpected behaviour in jac()!') + hess = np.array([self.reg.hess_dot(None, np.eye(self.n)[:, i]) for i in range(self.n)]).T + assert_allclose(hess, 2 * np.eye(self.n), err_msg='Unexpected behaviour in hess_dot()!') + + def test_hess_diag(self): + assert_allclose(self.reg.hess_diag(np.arange(self.n)), 2 * np.ones(self.n), + err_msg='Unexpected behaviour in hess_diag()!') + + +class TestCaseFirstOrderRegularisator(unittest.TestCase): + def setUp(self): + self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_regularisator') + self.dim = (4, 5, 6) + self.mask = np.zeros(self.dim, dtype=bool) + self.mask[1:-1, 1:-1, 1:-1] = True + self.n = 3 * self.mask.sum() + self.lam = 1. + self.reg = FirstOrderRegularisator(self.mask, lam=self.lam) + + def tearDown(self): + self.path = None + self.dim = None + self.mask = None + self.n = None + self.lam = None + self.reg = None + + def test_call(self): + assert_allclose(self.reg(np.ones(self.n)), 0., + err_msg='Unexpected behaviour in __call__()!') + assert_allclose(self.reg(np.zeros(self.n)), 0., + err_msg='Unexpected behaviour in __call__()!') + cost_ref = np.load(os.path.join(self.path, 'first_order_cost_ref.npy')) + assert_allclose(self.reg(np.arange(self.n)), cost_ref, + err_msg='Unexpected behaviour in __call__()!') + + def test_jac(self): + assert_allclose(self.reg.jac(np.ones(self.n)), np.zeros(self.n), + err_msg='Unexpected behaviour in jac()!') + assert_allclose(self.reg.jac(np.zeros(self.n)), np.zeros(self.n), + err_msg='Unexpected behaviour in jac()!') + jac_vec_ref = np.load(os.path.join(self.path, 'first_order_jac_vec_ref.npy')) + assert_allclose(self.reg.jac(np.arange(self.n)), jac_vec_ref, atol=1E-7, + err_msg='Unexpected behaviour in jac()!') + jac = np.array([self.reg.jac(np.eye(self.n)[:, i]) for i in range(self.n)]).T + jac_ref = np.load(os.path.join(self.path, 'first_order_jac_ref.npy')) + assert_allclose(jac, jac_ref, atol=1E-7, + err_msg='Unexpected behaviour in jac()!') + + def test_hess_dot(self): + assert_allclose(self.reg.hess_dot(None, np.ones(self.n)), np.zeros(self.n), + err_msg='Unexpected behaviour in hess_dot()!') + assert_allclose(self.reg.hess_dot(None, np.zeros(self.n)), np.zeros(self.n), + err_msg='Unexpected behaviour in hess_dot()!') + hess_vec_ref = np.load(os.path.join(self.path, 'first_order_jac_vec_ref.npy')) + assert_allclose(self.reg.hess_dot(None, np.arange(self.n)), hess_vec_ref, atol=1E-7, + err_msg='Unexpected behaviour in hess_dot()!') + hess = np.array([self.reg.hess_dot(None, np.eye(self.n)[:, i]) for i in range(self.n)]).T + hess_ref = np.load(os.path.join(self.path, 'first_order_jac_ref.npy')) + assert_allclose(hess, hess_ref, atol=1E-7, + err_msg='Unexpected behaviour in hess_dot()!') + + def test_hess_diag(self): + hess_diag = self.reg.hess_diag(np.ones(self.n)) + hess_diag_ref = np.diag(np.load(os.path.join(self.path, 'first_order_jac_ref.npy'))) + assert_allclose(hess_diag, hess_diag_ref, atol=1E-7, + err_msg='Unexpected behaviour in hess_diag()!') + + +if __name__ == '__main__': + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/utils/__init__.py b/pyramid/utils/__init__.py index c475bec84bb990761e812551b7e217dcbf1dffdd..1c550a0fc3284040769e6291f7c496e41382355c 100644 --- a/pyramid/utils/__init__.py +++ b/pyramid/utils/__init__.py @@ -1,14 +1,14 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Subpackage containing Pyramid utility functions.""" - -from .pm import pm -from .reconstruction_2d_from_phasemap import reconstruction_2d_from_phasemap -from .reconstruction_3d_from_magdata import reconstruction_3d_from_magdata -from .phasemap_creator import gui_phasemap_creator -from .mag_slicer import gui_mag_slicer - -__all__ = ['pm', 'reconstruction_2d_from_phasemap', 'reconstruction_3d_from_magdata', - 'gui_phasemap_creator', 'gui_mag_slicer'] +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Subpackage containing Pyramid utility functions.""" + +from .pm import pm +from .reconstruction_2d_from_phasemap import reconstruction_2d_from_phasemap +from .reconstruction_3d_from_magdata import reconstruction_3d_from_magdata +from .phasemap_creator import gui_phasemap_creator +from .mag_slicer import gui_mag_slicer + +__all__ = ['pm', 'reconstruction_2d_from_phasemap', 'reconstruction_3d_from_magdata', + 'gui_phasemap_creator', 'gui_mag_slicer'] diff --git a/pyramid/utils/mag_slicer.py b/pyramid/utils/mag_slicer.py index fee9925614f23047ee301a750dff28021bf1478a..52d480b9e82a3b8cc536612bd8bf557881f718d0 100644 --- a/pyramid/utils/mag_slicer.py +++ b/pyramid/utils/mag_slicer.py @@ -1,150 +1,150 @@ -# -*- coding: utf-8 -*- -# Form implementation generated from reading ui file 'mag_slicer.ui' -# -# Created: Sun Aug 31 20:39:52 2014 -# by: PyQt4 UI code generator 4.9.6 -# -# WARNING! All changes made in this file will be lost! -"""GUI for slicing 3D magnetization distributions.""" - -import logging - -import os -import sys - -from PyQt4 import QtGui, QtCore -from PyQt4.uic import loadUiType - -from matplotlib.figure import Figure -from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt4agg import NavigationToolbar2QT as NavigationToolbar - -from ..projector import SimpleProjector -from ..kernel import Kernel -from ..phasemapper import PhaseMapperRDFC -from ..file_io.io_vectordata import load_vectordata - -__all__ = ['gui_mag_slicer'] -_log = logging.getLogger(__name__) - - -ui_location = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mag_slicer.ui') -UI_MainWindow, QMainWindow = loadUiType(ui_location) - - -class Main(QMainWindow, UI_MainWindow): - - def __init__(self): - super().__init__() - self.setupUi(self) - self.connect(self.checkBoxLog, QtCore.SIGNAL('clicked()'), - self.update_slice) - self.connect(self.checkBoxScale, QtCore.SIGNAL('clicked()'), - self.update_slice) - self.connect(self.spinBoxSlice, QtCore.SIGNAL('valueChanged(int)'), - self.update_slice) - self.connect(self.comboBoxSlice, QtCore.SIGNAL('currentIndexChanged(int)'), - self.update_phase) - self.connect(self.spinBoxGain, QtCore.SIGNAL('valueChanged(double)'), - self.update_phase) - self.connect(self.checkBoxAuto, QtCore.SIGNAL('toggled(bool)'), - self.update_phase) - self.connect(self.checkBoxSmooth, QtCore.SIGNAL('toggled(bool)'), - self.update_phase) - self.connect(self.pushButtonLoad, QtCore.SIGNAL('clicked()'), - self.load) - self.is_magdata_loaded = False - self.magdata = None - - def addmpl(self): - fig = Figure() - fig.add_subplot(111, aspect='equal') - self.canvasMag = FigureCanvas(fig) - self.layoutMag.addWidget(self.canvasMag) - self.canvasMag.draw() - self.toolbarMag = NavigationToolbar(self.canvasMag, self, coordinates=True) - self.layoutMag.addWidget(self.toolbarMag) - fig = Figure() - fig.add_subplot(111, aspect='equal') - self.canvasPhase = FigureCanvas(fig) - self.layoutPhase.addWidget(self.canvasPhase) - self.canvasPhase.draw() - self.toolbarPhase = NavigationToolbar(self.canvasPhase, self, coordinates=True) - self.layoutPhase.addWidget(self.toolbarPhase) - fig = Figure() - fig.add_subplot(111, aspect='equal') - self.canvasHolo = FigureCanvas(fig) - self.layoutHolo.addWidget(self.canvasHolo) - self.canvasHolo.draw() - self.toolbarHolo = NavigationToolbar(self.canvasHolo, self, coordinates=True) - self.layoutHolo.addWidget(self.toolbarHolo) - - def update_phase(self): - if self.is_magdata_loaded: - mode_ind = self.comboBoxSlice.currentIndex() - if mode_ind == 0: - self.mode = 'z' - length = self.magdata.dim[0] - 1 - elif mode_ind == 1: - self.mode = 'y' - length = self.magdata.dim[1] - 1 - else: - self.mode = 'x' - length = self.magdata.dim[2] - 1 - if self.checkBoxAuto.isChecked(): - gain = 'auto' - else: - gain = self.spinBoxGain.value() - self.projector = SimpleProjector(self.magdata.dim, axis=self.mode) - self.spinBoxSlice.setMaximum(length) - self.scrollBarSlice.setMaximum(length) - self.spinBoxSlice.setValue(int(length / 2.)) - self.update_slice() - kernel = Kernel(self.magdata.a, self.projector.dim_uv) - self.phasemapper = PhaseMapperRDFC(kernel) - self.phasemap = self.phasemapper(self.projector(self.magdata)) - self.canvasPhase.figure.axes[0].clear() - self.phasemap.plot_phase(axis=self.canvasPhase.figure.axes[0], cbar=False) - if self.checkBoxSmooth.isChecked(): - interpolation = 'bilinear' - else: - interpolation = 'none' - self.canvasHolo.figure.axes[0].clear() - self.phasemap.plot_holo(axis=self.canvasHolo.figure.axes[0], gain=gain, - interpolation=interpolation) - self.canvasPhase.draw() - self.canvasHolo.draw() - - def update_slice(self): - if self.is_magdata_loaded: - self.canvasMag.figure.axes[0].clear() - self.magdata.plot_quiver(axis=self.canvasMag.figure.axes[0], proj_axis=self.mode, - ax_slice=self.spinBoxSlice.value(), - log=self.checkBoxLog.isChecked(), - scaled=self.checkBoxScale.isChecked()) - self.canvasMag.draw() - - def load(self): - try: - mag_file = QtGui.QFileDialog.getOpenFileName(self, 'Open Data File', '', - 'HDF5 files (*.hdf5)') - except ValueError: - return # Abort if no conf_path is selected! - import hyperspy.api as hs - print(hs.load(mag_file)) - self.magdata = load_vectordata(mag_file) - if not self.is_magdata_loaded: - self.addmpl() - self.is_magdata_loaded = True - self.comboBoxSlice.setCurrentIndex(0) - self.update_phase() - - -def gui_mag_slicer(): - """Call the GUI for viewing magnetic distributions.""" - _log.debug('Calling gui_mag_slicer') - app = QtGui.QApplication(sys.argv) - main = Main() - main.show() - app.exec() - return main.magdata +# -*- coding: utf-8 -*- +# Form implementation generated from reading ui file 'mag_slicer.ui' +# +# Created: Sun Aug 31 20:39:52 2014 +# by: PyQt4 UI code generator 4.9.6 +# +# WARNING! All changes made in this file will be lost! +"""GUI for slicing 3D magnetization distributions.""" + +import logging + +import os +import sys + +from PyQt4 import QtGui, QtCore +from PyQt4.uic import loadUiType + +from matplotlib.figure import Figure +from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.backends.backend_qt4agg import NavigationToolbar2QT as NavigationToolbar + +from ..projector import SimpleProjector +from ..kernel import Kernel +from ..phasemapper import PhaseMapperRDFC +from ..file_io.io_vectordata import load_vectordata + +__all__ = ['gui_mag_slicer'] +_log = logging.getLogger(__name__) + + +ui_location = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mag_slicer.ui') +UI_MainWindow, QMainWindow = loadUiType(ui_location) + + +class Main(QMainWindow, UI_MainWindow): + + def __init__(self): + super().__init__() + self.setupUi(self) + self.connect(self.checkBoxLog, QtCore.SIGNAL('clicked()'), + self.update_slice) + self.connect(self.checkBoxScale, QtCore.SIGNAL('clicked()'), + self.update_slice) + self.connect(self.spinBoxSlice, QtCore.SIGNAL('valueChanged(int)'), + self.update_slice) + self.connect(self.comboBoxSlice, QtCore.SIGNAL('currentIndexChanged(int)'), + self.update_phase) + self.connect(self.spinBoxGain, QtCore.SIGNAL('valueChanged(double)'), + self.update_phase) + self.connect(self.checkBoxAuto, QtCore.SIGNAL('toggled(bool)'), + self.update_phase) + self.connect(self.checkBoxSmooth, QtCore.SIGNAL('toggled(bool)'), + self.update_phase) + self.connect(self.pushButtonLoad, QtCore.SIGNAL('clicked()'), + self.load) + self.is_magdata_loaded = False + self.magdata = None + + def addmpl(self): + fig = Figure() + fig.add_subplot(111, aspect='equal') + self.canvasMag = FigureCanvas(fig) + self.layoutMag.addWidget(self.canvasMag) + self.canvasMag.draw() + self.toolbarMag = NavigationToolbar(self.canvasMag, self, coordinates=True) + self.layoutMag.addWidget(self.toolbarMag) + fig = Figure() + fig.add_subplot(111, aspect='equal') + self.canvasPhase = FigureCanvas(fig) + self.layoutPhase.addWidget(self.canvasPhase) + self.canvasPhase.draw() + self.toolbarPhase = NavigationToolbar(self.canvasPhase, self, coordinates=True) + self.layoutPhase.addWidget(self.toolbarPhase) + fig = Figure() + fig.add_subplot(111, aspect='equal') + self.canvasHolo = FigureCanvas(fig) + self.layoutHolo.addWidget(self.canvasHolo) + self.canvasHolo.draw() + self.toolbarHolo = NavigationToolbar(self.canvasHolo, self, coordinates=True) + self.layoutHolo.addWidget(self.toolbarHolo) + + def update_phase(self): + if self.is_magdata_loaded: + mode_ind = self.comboBoxSlice.currentIndex() + if mode_ind == 0: + self.mode = 'z' + length = self.magdata.dim[0] - 1 + elif mode_ind == 1: + self.mode = 'y' + length = self.magdata.dim[1] - 1 + else: + self.mode = 'x' + length = self.magdata.dim[2] - 1 + if self.checkBoxAuto.isChecked(): + gain = 'auto' + else: + gain = self.spinBoxGain.value() + self.projector = SimpleProjector(self.magdata.dim, axis=self.mode) + self.spinBoxSlice.setMaximum(length) + self.scrollBarSlice.setMaximum(length) + self.spinBoxSlice.setValue(int(length / 2.)) + self.update_slice() + kernel = Kernel(self.magdata.a, self.projector.dim_uv) + self.phasemapper = PhaseMapperRDFC(kernel) + self.phasemap = self.phasemapper(self.projector(self.magdata)) + self.canvasPhase.figure.axes[0].clear() + self.phasemap.plot_phase(axis=self.canvasPhase.figure.axes[0], cbar=False) + if self.checkBoxSmooth.isChecked(): + interpolation = 'bilinear' + else: + interpolation = 'none' + self.canvasHolo.figure.axes[0].clear() + self.phasemap.plot_holo(axis=self.canvasHolo.figure.axes[0], gain=gain, + interpolation=interpolation) + self.canvasPhase.draw() + self.canvasHolo.draw() + + def update_slice(self): + if self.is_magdata_loaded: + self.canvasMag.figure.axes[0].clear() + self.magdata.plot_quiver(axis=self.canvasMag.figure.axes[0], proj_axis=self.mode, + ax_slice=self.spinBoxSlice.value(), + log=self.checkBoxLog.isChecked(), + scaled=self.checkBoxScale.isChecked()) + self.canvasMag.draw() + + def load(self): + try: + mag_file = QtGui.QFileDialog.getOpenFileName(self, 'Open Data File', '', + 'HDF5 files (*.hdf5)') + except ValueError: + return # Abort if no conf_path is selected! + import hyperspy.api as hs + print(hs.load(mag_file)) + self.magdata = load_vectordata(mag_file) + if not self.is_magdata_loaded: + self.addmpl() + self.is_magdata_loaded = True + self.comboBoxSlice.setCurrentIndex(0) + self.update_phase() + + +def gui_mag_slicer(): + """Call the GUI for viewing magnetic distributions.""" + _log.debug('Calling gui_mag_slicer') + app = QtGui.QApplication(sys.argv) + main = Main() + main.show() + app.exec() + return main.magdata diff --git a/pyramid/utils/mag_slicer.ui b/pyramid/utils/mag_slicer.ui index bd983878330be31ef9f6bacfed4d4b2f537e7259..8971a07ddb583305d0bfd66cf7b7439deb278dcf 100644 --- a/pyramid/utils/mag_slicer.ui +++ b/pyramid/utils/mag_slicer.ui @@ -1,300 +1,300 @@ -<?xml version="1.0" encoding="UTF-8"?> -<ui version="4.0"> - <class>MainWindow</class> - <widget class="QMainWindow" name="MainWindow"> - <property name="geometry"> - <rect> - <x>0</x> - <y>0</y> - <width>1222</width> - <height>480</height> - </rect> - </property> - <property name="windowTitle"> - <string>MagSlicer</string> - </property> - <widget class="QWidget" name="centralwidget"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Expanding" vsizetype="Expanding"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="windowTitle"> - <string>Mag Slicer</string> - </property> - <layout class="QVBoxLayout" name="verticalLayout"> - <item> - <layout class="QHBoxLayout" name="horizontalLayout"> - <item> - <widget class="QPushButton" name="pushButtonLoad"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="text"> - <string>Laden</string> - </property> - </widget> - </item> - <item> - <widget class="QCheckBox" name="checkBoxScale"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="text"> - <string>Scaled</string> - </property> - <property name="checked"> - <bool>true</bool> - </property> - <property name="tristate"> - <bool>false</bool> - </property> - </widget> - </item> - <item> - <widget class="QCheckBox" name="checkBoxLog"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="text"> - <string>Log</string> - </property> - </widget> - </item> - <item> - <widget class="QComboBox" name="comboBoxSlice"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Preferred" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <item> - <property name="text"> - <string>xy-plane</string> - </property> - </item> - <item> - <property name="text"> - <string>xz-plane</string> - </property> - </item> - <item> - <property name="text"> - <string>zy-plane</string> - </property> - </item> - </widget> - </item> - <item> - <widget class="QLabel" name="labelSlice"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="text"> - <string>Slice:</string> - </property> - <property name="alignment"> - <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> - </property> - </widget> - </item> - <item> - <widget class="QSpinBox" name="spinBoxSlice"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="alignment"> - <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> - </property> - <property name="maximum"> - <number>0</number> - </property> - <property name="value"> - <number>0</number> - </property> - </widget> - </item> - <item> - <widget class="QScrollBar" name="scrollBarSlice"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Expanding" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="minimumSize"> - <size> - <width>400</width> - <height>0</height> - </size> - </property> - <property name="maximum"> - <number>0</number> - </property> - <property name="orientation"> - <enum>Qt::Horizontal</enum> - </property> - </widget> - </item> - <item> - <widget class="QCheckBox" name="checkBoxSmooth"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="text"> - <string>Smooth</string> - </property> - <property name="checked"> - <bool>true</bool> - </property> - </widget> - </item> - <item> - <widget class="QCheckBox" name="checkBoxAuto"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="text"> - <string>Auto</string> - </property> - <property name="checked"> - <bool>true</bool> - </property> - </widget> - </item> - <item> - <widget class="QLabel" name="labelGain"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="text"> - <string>Gain:</string> - </property> - </widget> - </item> - <item> - <widget class="QDoubleSpinBox" name="spinBoxGain"> - <property name="enabled"> - <bool>false</bool> - </property> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="maximum"> - <double>1000000.000000000000000</double> - </property> - <property name="singleStep"> - <double>0.100000000000000</double> - </property> - <property name="value"> - <double>1.000000000000000</double> - </property> - </widget> - </item> - </layout> - </item> - <item> - <widget class="QWidget" name="layoutPlots" native="true"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Preferred" vsizetype="Expanding"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <layout class="QHBoxLayout" name="horizontalLayoutPlots"> - <item> - <layout class="QVBoxLayout" name="layoutMag"/> - </item> - <item> - <layout class="QVBoxLayout" name="layoutPhase"/> - </item> - <item> - <layout class="QVBoxLayout" name="layoutHolo"/> - </item> - </layout> - </widget> - </item> - </layout> - </widget> - </widget> - <resources/> - <connections> - <connection> - <sender>spinBoxSlice</sender> - <signal>valueChanged(int)</signal> - <receiver>scrollBarSlice</receiver> - <slot>setValue(int)</slot> - <hints> - <hint type="sourcelabel"> - <x>322</x> - <y>19</y> - </hint> - <hint type="destinationlabel"> - <x>386</x> - <y>20</y> - </hint> - </hints> - </connection> - <connection> - <sender>scrollBarSlice</sender> - <signal>sliderMoved(int)</signal> - <receiver>spinBoxSlice</receiver> - <slot>setValue(int)</slot> - <hints> - <hint type="sourcelabel"> - <x>431</x> - <y>24</y> - </hint> - <hint type="destinationlabel"> - <x>321</x> - <y>20</y> - </hint> - </hints> - </connection> - <connection> - <sender>checkBoxAuto</sender> - <signal>toggled(bool)</signal> - <receiver>spinBoxGain</receiver> - <slot>setDisabled(bool)</slot> - <hints> - <hint type="sourcelabel"> - <x>1065</x> - <y>22</y> - </hint> - <hint type="destinationlabel"> - <x>1166</x> - <y>21</y> - </hint> - </hints> - </connection> - </connections> -</ui> +<?xml version="1.0" encoding="UTF-8"?> +<ui version="4.0"> + <class>MainWindow</class> + <widget class="QMainWindow" name="MainWindow"> + <property name="geometry"> + <rect> + <x>0</x> + <y>0</y> + <width>1222</width> + <height>480</height> + </rect> + </property> + <property name="windowTitle"> + <string>MagSlicer</string> + </property> + <widget class="QWidget" name="centralwidget"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Expanding" vsizetype="Expanding"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="windowTitle"> + <string>Mag Slicer</string> + </property> + <layout class="QVBoxLayout" name="verticalLayout"> + <item> + <layout class="QHBoxLayout" name="horizontalLayout"> + <item> + <widget class="QPushButton" name="pushButtonLoad"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="text"> + <string>Laden</string> + </property> + </widget> + </item> + <item> + <widget class="QCheckBox" name="checkBoxScale"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="text"> + <string>Scaled</string> + </property> + <property name="checked"> + <bool>true</bool> + </property> + <property name="tristate"> + <bool>false</bool> + </property> + </widget> + </item> + <item> + <widget class="QCheckBox" name="checkBoxLog"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="text"> + <string>Log</string> + </property> + </widget> + </item> + <item> + <widget class="QComboBox" name="comboBoxSlice"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Preferred" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <item> + <property name="text"> + <string>xy-plane</string> + </property> + </item> + <item> + <property name="text"> + <string>xz-plane</string> + </property> + </item> + <item> + <property name="text"> + <string>zy-plane</string> + </property> + </item> + </widget> + </item> + <item> + <widget class="QLabel" name="labelSlice"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="text"> + <string>Slice:</string> + </property> + <property name="alignment"> + <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> + </property> + </widget> + </item> + <item> + <widget class="QSpinBox" name="spinBoxSlice"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="alignment"> + <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> + </property> + <property name="maximum"> + <number>0</number> + </property> + <property name="value"> + <number>0</number> + </property> + </widget> + </item> + <item> + <widget class="QScrollBar" name="scrollBarSlice"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Expanding" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="minimumSize"> + <size> + <width>400</width> + <height>0</height> + </size> + </property> + <property name="maximum"> + <number>0</number> + </property> + <property name="orientation"> + <enum>Qt::Horizontal</enum> + </property> + </widget> + </item> + <item> + <widget class="QCheckBox" name="checkBoxSmooth"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="text"> + <string>Smooth</string> + </property> + <property name="checked"> + <bool>true</bool> + </property> + </widget> + </item> + <item> + <widget class="QCheckBox" name="checkBoxAuto"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="text"> + <string>Auto</string> + </property> + <property name="checked"> + <bool>true</bool> + </property> + </widget> + </item> + <item> + <widget class="QLabel" name="labelGain"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="text"> + <string>Gain:</string> + </property> + </widget> + </item> + <item> + <widget class="QDoubleSpinBox" name="spinBoxGain"> + <property name="enabled"> + <bool>false</bool> + </property> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="maximum"> + <double>1000000.000000000000000</double> + </property> + <property name="singleStep"> + <double>0.100000000000000</double> + </property> + <property name="value"> + <double>1.000000000000000</double> + </property> + </widget> + </item> + </layout> + </item> + <item> + <widget class="QWidget" name="layoutPlots" native="true"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Preferred" vsizetype="Expanding"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <layout class="QHBoxLayout" name="horizontalLayoutPlots"> + <item> + <layout class="QVBoxLayout" name="layoutMag"/> + </item> + <item> + <layout class="QVBoxLayout" name="layoutPhase"/> + </item> + <item> + <layout class="QVBoxLayout" name="layoutHolo"/> + </item> + </layout> + </widget> + </item> + </layout> + </widget> + </widget> + <resources/> + <connections> + <connection> + <sender>spinBoxSlice</sender> + <signal>valueChanged(int)</signal> + <receiver>scrollBarSlice</receiver> + <slot>setValue(int)</slot> + <hints> + <hint type="sourcelabel"> + <x>322</x> + <y>19</y> + </hint> + <hint type="destinationlabel"> + <x>386</x> + <y>20</y> + </hint> + </hints> + </connection> + <connection> + <sender>scrollBarSlice</sender> + <signal>sliderMoved(int)</signal> + <receiver>spinBoxSlice</receiver> + <slot>setValue(int)</slot> + <hints> + <hint type="sourcelabel"> + <x>431</x> + <y>24</y> + </hint> + <hint type="destinationlabel"> + <x>321</x> + <y>20</y> + </hint> + </hints> + </connection> + <connection> + <sender>checkBoxAuto</sender> + <signal>toggled(bool)</signal> + <receiver>spinBoxGain</receiver> + <slot>setDisabled(bool)</slot> + <hints> + <hint type="sourcelabel"> + <x>1065</x> + <y>22</y> + </hint> + <hint type="destinationlabel"> + <x>1166</x> + <y>21</y> + </hint> + </hints> + </connection> + </connections> +</ui> diff --git a/pyramid/utils/phasemap_creator.py b/pyramid/utils/phasemap_creator.py index 7f40c8903682d1c69aa9248044925e68d1f1a6d8..ac49c08698d6b204db656b538f2d516c2349bb9d 100644 --- a/pyramid/utils/phasemap_creator.py +++ b/pyramid/utils/phasemap_creator.py @@ -1,170 +1,170 @@ -# -*- coding: utf-8 -*- - -# Form implementation generated from reading ui file 'phasemap_creator.ui' -# -# Created: Thu Sep 24 11:42:11 2015 -# by: PyQt4 UI code generator 4.9.6 -# -# WARNING! All changes made in this file will be lost! -"""GUI for setting up PhasMaps from existing data in different formats.""" - -import logging - -import os -import sys - -from PyQt4 import QtGui, QtCore -from PyQt4.uic import loadUiType - -from matplotlib.figure import Figure -from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt4agg import NavigationToolbar2QT as NavigationToolbar - -from PIL import Image - -import numpy as np - -import hyperspy.api as hs - -import pyramid as pr - -__all__ = ['gui_phasemap_creator'] -_log = logging.getLogger(__name__) - - -ui_location = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'phasemap_creator.ui') -UI_MainWindow, QMainWindow = loadUiType(ui_location) - - -class Main(QMainWindow, UI_MainWindow): - - def __init__(self): - super().__init__() - self.setupUi(self) - self.connect(self.pushButton_phase, QtCore.SIGNAL('clicked()'), - self.load_phase) - self.connect(self.pushButton_mask, QtCore.SIGNAL('clicked()'), - self.load_mask) - self.connect(self.pushButton_conf, QtCore.SIGNAL('clicked()'), - self.load_conf) - self.connect(self.pushButton_export, QtCore.SIGNAL('clicked()'), - self.export) - self.connect(self.horizontalScrollBar, QtCore.SIGNAL('valueChanged(int)'), - self.doubleSpinBox_thres.setValue) - self.connect(self.doubleSpinBox_thres, QtCore.SIGNAL('valueChanged(double)'), - self.horizontalScrollBar.setValue) - self.connect(self.checkBox_mask, QtCore.SIGNAL('clicked()'), - self.update_phasemap) - self.connect(self.checkBox_conf, QtCore.SIGNAL('clicked()'), - self.update_phasemap) - self.connect(self.doubleSpinBox_a, QtCore.SIGNAL('editingFinished()'), - self.update_phasemap) - self.connect(self.doubleSpinBox_thres, QtCore.SIGNAL('valueChanged(double)'), - self.update_mask) - self.phase_loaded = False - self.mask_loaded = False - self.dir = '' - self.phasemap = None - - def addmpl(self): - fig = Figure() - fig.add_subplot(111, aspect='equal') - self.canvas = FigureCanvas(fig) - self.mplLayout.addWidget(self.canvas) - self.canvas.draw() - self.toolbar = NavigationToolbar(self.canvas, self, coordinates=True) - self.mplLayout.addWidget(self.toolbar) - - def update_phasemap(self): - if self.phase_loaded: - self.phasemap.a = self.doubleSpinBox_a.value() - show_mask = self.checkBox_mask.isChecked() - show_conf = self.checkBox_conf.isChecked() - self.canvas.figure.axes[0].clear() - self.canvas.figure.axes[0].hold(True) - self.phasemap.plot_phase('PhaseMap', axis=self.canvas.figure.axes[0], - show_mask=show_mask, show_conf=show_conf, cbar=False) - self.canvas.draw() - - def update_mask(self): - if self.mask_loaded: - threshold = self.doubleSpinBox_thres.value() - mask_img = Image.fromarray(self.raw_mask) - mask = np.asarray(mask_img.resize(list(reversed(self.phasemap.dim_uv)))) - self.phasemap.mask = np.where(mask >= threshold, True, False) - self.update_phasemap() - - def load_phase(self): - try: - self.phase_path = QtGui.QFileDialog.getOpenFileName(self, 'Load Phase', self.dir) - self.phasemap = pr.file_io.io_phasemap._load(self.phase_path, as_phasemap=True) - except ValueError: - return # Abort if no phase_path is selected! - self.doubleSpinBox_a.setValue(self.phasemap.a) - self.dir = os.path.join(os.path.dirname(self.phase_path)) - if not self.phase_loaded: - self.addmpl() - self.pushButton_mask.setEnabled(True) - self.pushButton_conf.setEnabled(True) - self.pushButton_export.setEnabled(True) - self.phase_loaded = True - self.horizontalScrollBar.setMinimum(0) - self.horizontalScrollBar.setMaximum(0) - self.horizontalScrollBar.setEnabled(False) - self.doubleSpinBox_thres.setMinimum(0) - self.doubleSpinBox_thres.setMaximum(0) - self.doubleSpinBox_thres.setValue(0) - self.doubleSpinBox_thres.setEnabled(False) - self.mask_loaded = False - self.update_phasemap() - - def load_mask(self): - try: - mask_path = QtGui.QFileDialog.getOpenFileName(self, 'Load Mask', self.dir) - self.raw_mask = pr.file_io.io_phasemap._load(mask_path) - except ValueError: - return # Abort if no mask_path is selected! - mask_min = self.raw_mask.min() - mask_max = self.raw_mask.max() - self.horizontalScrollBar.setEnabled(True) - self.horizontalScrollBar.setMinimum(mask_min) - self.horizontalScrollBar.setMaximum(mask_max) - self.horizontalScrollBar.setSingleStep((mask_max - mask_min) / 255.) - self.horizontalScrollBar.setValue((mask_max - mask_min) / 2.) - self.doubleSpinBox_thres.setEnabled(True) - self.doubleSpinBox_thres.setMinimum(mask_min) - self.doubleSpinBox_thres.setMaximum(mask_max) - self.doubleSpinBox_thres.setSingleStep((mask_max - mask_min) / 255.) - self.doubleSpinBox_thres.setValue((mask_max - mask_min) / 2.) - self.mask_loaded = True - self.update_mask() - - def load_conf(self): - try: - conf_path = QtGui.QFileDialog.getOpenFileName(self, 'Load Confidence', self.dir) - confidence = pr.file_io.io_phasemap._load(conf_path) - except ValueError: - return # Abort if no conf_path is selected! - confidence = confidence.astype(float) / confidence.max() + 1e-30 - self.phasemap.confidence = confidence - self.update_phasemap() - - def export(self): - try: - export_name = os.path.splitext(os.path.basename(self.phase_path))[0] - export_default = os.path.join(self.dir, 'phasemap_gui_{}.hdf5'.format(export_name)) - export_path = QtGui.QFileDialog.getSaveFileName(self, 'Export PhaseMap', - export_default, 'HDF5 (*.hdf5)') - self.phasemap.to_signal().save(export_path, overwrite=True) - except (ValueError, AttributeError): - return # Abort if no export_path is selected or self.phasemap doesn't exist yet! - - -def gui_phasemap_creator(): - """Call the GUI for phasemap creation.""" - _log.debug('Calling gui_phasemap_creator') - app = QtGui.QApplication(sys.argv) - main = Main() - main.show() - app.exec() - return main.phasemap +# -*- coding: utf-8 -*- + +# Form implementation generated from reading ui file 'phasemap_creator.ui' +# +# Created: Thu Sep 24 11:42:11 2015 +# by: PyQt4 UI code generator 4.9.6 +# +# WARNING! All changes made in this file will be lost! +"""GUI for setting up PhasMaps from existing data in different formats.""" + +import logging + +import os +import sys + +from PyQt4 import QtGui, QtCore +from PyQt4.uic import loadUiType + +from matplotlib.figure import Figure +from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.backends.backend_qt4agg import NavigationToolbar2QT as NavigationToolbar + +from PIL import Image + +import numpy as np + +import hyperspy.api as hs + +import pyramid as pr + +__all__ = ['gui_phasemap_creator'] +_log = logging.getLogger(__name__) + + +ui_location = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'phasemap_creator.ui') +UI_MainWindow, QMainWindow = loadUiType(ui_location) + + +class Main(QMainWindow, UI_MainWindow): + + def __init__(self): + super().__init__() + self.setupUi(self) + self.connect(self.pushButton_phase, QtCore.SIGNAL('clicked()'), + self.load_phase) + self.connect(self.pushButton_mask, QtCore.SIGNAL('clicked()'), + self.load_mask) + self.connect(self.pushButton_conf, QtCore.SIGNAL('clicked()'), + self.load_conf) + self.connect(self.pushButton_export, QtCore.SIGNAL('clicked()'), + self.export) + self.connect(self.horizontalScrollBar, QtCore.SIGNAL('valueChanged(int)'), + self.doubleSpinBox_thres.setValue) + self.connect(self.doubleSpinBox_thres, QtCore.SIGNAL('valueChanged(double)'), + self.horizontalScrollBar.setValue) + self.connect(self.checkBox_mask, QtCore.SIGNAL('clicked()'), + self.update_phasemap) + self.connect(self.checkBox_conf, QtCore.SIGNAL('clicked()'), + self.update_phasemap) + self.connect(self.doubleSpinBox_a, QtCore.SIGNAL('editingFinished()'), + self.update_phasemap) + self.connect(self.doubleSpinBox_thres, QtCore.SIGNAL('valueChanged(double)'), + self.update_mask) + self.phase_loaded = False + self.mask_loaded = False + self.dir = '' + self.phasemap = None + + def addmpl(self): + fig = Figure() + fig.add_subplot(111, aspect='equal') + self.canvas = FigureCanvas(fig) + self.mplLayout.addWidget(self.canvas) + self.canvas.draw() + self.toolbar = NavigationToolbar(self.canvas, self, coordinates=True) + self.mplLayout.addWidget(self.toolbar) + + def update_phasemap(self): + if self.phase_loaded: + self.phasemap.a = self.doubleSpinBox_a.value() + show_mask = self.checkBox_mask.isChecked() + show_conf = self.checkBox_conf.isChecked() + self.canvas.figure.axes[0].clear() + self.canvas.figure.axes[0].hold(True) + self.phasemap.plot_phase('PhaseMap', axis=self.canvas.figure.axes[0], + show_mask=show_mask, show_conf=show_conf, cbar=False) + self.canvas.draw() + + def update_mask(self): + if self.mask_loaded: + threshold = self.doubleSpinBox_thres.value() + mask_img = Image.fromarray(self.raw_mask) + mask = np.asarray(mask_img.resize(list(reversed(self.phasemap.dim_uv)))) + self.phasemap.mask = np.where(mask >= threshold, True, False) + self.update_phasemap() + + def load_phase(self): + try: + self.phase_path = QtGui.QFileDialog.getOpenFileName(self, 'Load Phase', self.dir) + self.phasemap = pr.file_io.io_phasemap._load(self.phase_path, as_phasemap=True) + except ValueError: + return # Abort if no phase_path is selected! + self.doubleSpinBox_a.setValue(self.phasemap.a) + self.dir = os.path.join(os.path.dirname(self.phase_path)) + if not self.phase_loaded: + self.addmpl() + self.pushButton_mask.setEnabled(True) + self.pushButton_conf.setEnabled(True) + self.pushButton_export.setEnabled(True) + self.phase_loaded = True + self.horizontalScrollBar.setMinimum(0) + self.horizontalScrollBar.setMaximum(0) + self.horizontalScrollBar.setEnabled(False) + self.doubleSpinBox_thres.setMinimum(0) + self.doubleSpinBox_thres.setMaximum(0) + self.doubleSpinBox_thres.setValue(0) + self.doubleSpinBox_thres.setEnabled(False) + self.mask_loaded = False + self.update_phasemap() + + def load_mask(self): + try: + mask_path = QtGui.QFileDialog.getOpenFileName(self, 'Load Mask', self.dir) + self.raw_mask = pr.file_io.io_phasemap._load(mask_path) + except ValueError: + return # Abort if no mask_path is selected! + mask_min = self.raw_mask.min() + mask_max = self.raw_mask.max() + self.horizontalScrollBar.setEnabled(True) + self.horizontalScrollBar.setMinimum(mask_min) + self.horizontalScrollBar.setMaximum(mask_max) + self.horizontalScrollBar.setSingleStep((mask_max - mask_min) / 255.) + self.horizontalScrollBar.setValue((mask_max - mask_min) / 2.) + self.doubleSpinBox_thres.setEnabled(True) + self.doubleSpinBox_thres.setMinimum(mask_min) + self.doubleSpinBox_thres.setMaximum(mask_max) + self.doubleSpinBox_thres.setSingleStep((mask_max - mask_min) / 255.) + self.doubleSpinBox_thres.setValue((mask_max - mask_min) / 2.) + self.mask_loaded = True + self.update_mask() + + def load_conf(self): + try: + conf_path = QtGui.QFileDialog.getOpenFileName(self, 'Load Confidence', self.dir) + confidence = pr.file_io.io_phasemap._load(conf_path) + except ValueError: + return # Abort if no conf_path is selected! + confidence = confidence.astype(float) / confidence.max() + 1e-30 + self.phasemap.confidence = confidence + self.update_phasemap() + + def export(self): + try: + export_name = os.path.splitext(os.path.basename(self.phase_path))[0] + export_default = os.path.join(self.dir, 'phasemap_gui_{}.hdf5'.format(export_name)) + export_path = QtGui.QFileDialog.getSaveFileName(self, 'Export PhaseMap', + export_default, 'HDF5 (*.hdf5)') + self.phasemap.to_signal().save(export_path, overwrite=True) + except (ValueError, AttributeError): + return # Abort if no export_path is selected or self.phasemap doesn't exist yet! + + +def gui_phasemap_creator(): + """Call the GUI for phasemap creation.""" + _log.debug('Calling gui_phasemap_creator') + app = QtGui.QApplication(sys.argv) + main = Main() + main.show() + app.exec() + return main.phasemap diff --git a/pyramid/utils/phasemap_creator.ui b/pyramid/utils/phasemap_creator.ui index caa64fc2d2c44c5fc59bbdf24ab48fbfe612a6ea..12db72f2b520d5f823c0667b72ba38533a076371 100644 --- a/pyramid/utils/phasemap_creator.ui +++ b/pyramid/utils/phasemap_creator.ui @@ -1,222 +1,222 @@ -<?xml version="1.0" encoding="UTF-8"?> -<ui version="4.0"> - <class>MainWindow</class> - <widget class="QMainWindow" name="MainWindow"> - <property name="geometry"> - <rect> - <x>0</x> - <y>0</y> - <width>726</width> - <height>632</height> - </rect> - </property> - <property name="windowTitle"> - <string>MainWindow</string> - </property> - <widget class="QWidget" name="centralwidget"> - <property name="windowTitle"> - <string>Form</string> - </property> - <layout class="QVBoxLayout" name="verticalLayout"> - <item> - <widget class="QWidget" name="mplwidget" native="true"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Expanding" vsizetype="Expanding"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <layout class="QVBoxLayout" name="verticalLayout_2"> - <item> - <layout class="QVBoxLayout" name="mplLayout"/> - </item> - </layout> - </widget> - </item> - <item> - <layout class="QHBoxLayout" name="horizontalLayout"> - <item> - <widget class="QPushButton" name="pushButton_phase"> - <property name="text"> - <string>Load Phase</string> - </property> - </widget> - </item> - <item> - <widget class="QPushButton" name="pushButton_mask"> - <property name="text"> - <string>Load Mask</string> - </property> - </widget> - </item> - <item> - <widget class="QPushButton" name="pushButton_conf"> - <property name="text"> - <string>Load Confidence</string> - </property> - <property name="checkable"> - <bool>false</bool> - </property> - </widget> - </item> - <item> - <widget class="QPushButton" name="pushButton_export"> - <property name="text"> - <string>Export Phasemap</string> - </property> - </widget> - </item> - </layout> - </item> - <item> - <layout class="QHBoxLayout" name="horizontalLayout_2"> - <item> - <widget class="QLabel" name="label_2"> - <property name="text"> - <string>Grid spacing [nm]:</string> - </property> - </widget> - </item> - <item> - <widget class="QDoubleSpinBox" name="doubleSpinBox_a"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="minimumSize"> - <size> - <width>60</width> - <height>0</height> - </size> - </property> - <property name="baseSize"> - <size> - <width>0</width> - <height>0</height> - </size> - </property> - <property name="alignment"> - <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> - </property> - <property name="maximum"> - <double>1000.000000000000000</double> - </property> - <property name="value"> - <double>1.000000000000000</double> - </property> - </widget> - </item> - <item> - <widget class="QLabel" name="label"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Fixed" vsizetype="Preferred"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="text"> - <string>Mask Threshold:</string> - </property> - </widget> - </item> - <item> - <widget class="QDoubleSpinBox" name="doubleSpinBox_thres"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="minimumSize"> - <size> - <width>60</width> - <height>0</height> - </size> - </property> - <property name="maximumSize"> - <size> - <width>16777215</width> - <height>16777215</height> - </size> - </property> - <property name="baseSize"> - <size> - <width>0</width> - <height>0</height> - </size> - </property> - <property name="alignment"> - <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> - </property> - <property name="prefix"> - <string/> - </property> - <property name="decimals"> - <number>2</number> - </property> - <property name="maximum"> - <double>0.000000000000000</double> - </property> - </widget> - </item> - <item> - <widget class="QScrollBar" name="horizontalScrollBar"> - <property name="sizePolicy"> - <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="maximum"> - <number>0</number> - </property> - <property name="singleStep"> - <number>1</number> - </property> - <property name="orientation"> - <enum>Qt::Horizontal</enum> - </property> - </widget> - </item> - <item> - <widget class="QCheckBox" name="checkBox_mask"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Fixed" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="text"> - <string>show mask</string> - </property> - <property name="checked"> - <bool>true</bool> - </property> - </widget> - </item> - <item> - <widget class="QCheckBox" name="checkBox_conf"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Fixed" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="text"> - <string>show confidence</string> - </property> - <property name="checked"> - <bool>true</bool> - </property> - </widget> - </item> - </layout> - </item> - </layout> - </widget> - </widget> - <resources/> - <connections/> -</ui> +<?xml version="1.0" encoding="UTF-8"?> +<ui version="4.0"> + <class>MainWindow</class> + <widget class="QMainWindow" name="MainWindow"> + <property name="geometry"> + <rect> + <x>0</x> + <y>0</y> + <width>726</width> + <height>632</height> + </rect> + </property> + <property name="windowTitle"> + <string>MainWindow</string> + </property> + <widget class="QWidget" name="centralwidget"> + <property name="windowTitle"> + <string>Form</string> + </property> + <layout class="QVBoxLayout" name="verticalLayout"> + <item> + <widget class="QWidget" name="mplwidget" native="true"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Expanding" vsizetype="Expanding"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <layout class="QVBoxLayout" name="verticalLayout_2"> + <item> + <layout class="QVBoxLayout" name="mplLayout"/> + </item> + </layout> + </widget> + </item> + <item> + <layout class="QHBoxLayout" name="horizontalLayout"> + <item> + <widget class="QPushButton" name="pushButton_phase"> + <property name="text"> + <string>Load Phase</string> + </property> + </widget> + </item> + <item> + <widget class="QPushButton" name="pushButton_mask"> + <property name="text"> + <string>Load Mask</string> + </property> + </widget> + </item> + <item> + <widget class="QPushButton" name="pushButton_conf"> + <property name="text"> + <string>Load Confidence</string> + </property> + <property name="checkable"> + <bool>false</bool> + </property> + </widget> + </item> + <item> + <widget class="QPushButton" name="pushButton_export"> + <property name="text"> + <string>Export Phasemap</string> + </property> + </widget> + </item> + </layout> + </item> + <item> + <layout class="QHBoxLayout" name="horizontalLayout_2"> + <item> + <widget class="QLabel" name="label_2"> + <property name="text"> + <string>Grid spacing [nm]:</string> + </property> + </widget> + </item> + <item> + <widget class="QDoubleSpinBox" name="doubleSpinBox_a"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="minimumSize"> + <size> + <width>60</width> + <height>0</height> + </size> + </property> + <property name="baseSize"> + <size> + <width>0</width> + <height>0</height> + </size> + </property> + <property name="alignment"> + <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> + </property> + <property name="maximum"> + <double>1000.000000000000000</double> + </property> + <property name="value"> + <double>1.000000000000000</double> + </property> + </widget> + </item> + <item> + <widget class="QLabel" name="label"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Fixed" vsizetype="Preferred"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="text"> + <string>Mask Threshold:</string> + </property> + </widget> + </item> + <item> + <widget class="QDoubleSpinBox" name="doubleSpinBox_thres"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="minimumSize"> + <size> + <width>60</width> + <height>0</height> + </size> + </property> + <property name="maximumSize"> + <size> + <width>16777215</width> + <height>16777215</height> + </size> + </property> + <property name="baseSize"> + <size> + <width>0</width> + <height>0</height> + </size> + </property> + <property name="alignment"> + <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> + </property> + <property name="prefix"> + <string/> + </property> + <property name="decimals"> + <number>2</number> + </property> + <property name="maximum"> + <double>0.000000000000000</double> + </property> + </widget> + </item> + <item> + <widget class="QScrollBar" name="horizontalScrollBar"> + <property name="sizePolicy"> + <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="maximum"> + <number>0</number> + </property> + <property name="singleStep"> + <number>1</number> + </property> + <property name="orientation"> + <enum>Qt::Horizontal</enum> + </property> + </widget> + </item> + <item> + <widget class="QCheckBox" name="checkBox_mask"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Fixed" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="text"> + <string>show mask</string> + </property> + <property name="checked"> + <bool>true</bool> + </property> + </widget> + </item> + <item> + <widget class="QCheckBox" name="checkBox_conf"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Fixed" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="text"> + <string>show confidence</string> + </property> + <property name="checked"> + <bool>true</bool> + </property> + </widget> + </item> + </layout> + </item> + </layout> + </widget> + </widget> + <resources/> + <connections/> +</ui> diff --git a/pyramid/utils/pm.py b/pyramid/utils/pm.py index ee203c06f3d7adcbe35539b68aad8ca7de59e0e3..5c4232b25e6030953bfdd541ea07964ced64fbfb 100644 --- a/pyramid/utils/pm.py +++ b/pyramid/utils/pm.py @@ -1,65 +1,65 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Convenience function for phase mapping magnetic distributions.""" - -import logging - -from ..kernel import Kernel -from ..phasemapper import PhaseMapperRDFC, PhaseMapperFDFC -from ..projector import RotTiltProjector, XTiltProjector, YTiltProjector, SimpleProjector - -__all__ = ['pm'] -_log = logging.getLogger(__name__) - - -def pm(magdata, mode='z', b_0=1, mapper='RDFC', **kwargs): - """Convenience function for fast magnetic phase mapping. - - Parameters - ---------- - magdata : :class:`~.VectorData` - A :class:`~.VectorData` object, from which the projected phase map should be calculated. - mode: {'z', 'y', 'x', 'x-tilt', 'y-tilt', 'rot-tilt'}, optional - Projection mode which determines the :class:`~.pyramid.projector.Projector` subclass, which - is used for the projection. Default is a simple projection along the `z`-direction. - b_0 : float, optional - Saturation magnetization in Tesla, which is used for the phase calculation. Default is 1. - **kwargs : additional arguments - Additional arguments like `dim_uv`, 'tilt' or 'rotation', which are passed to the - projector-constructor, defined by the `mode`. - - Returns - ------- - phasemap : :class:`~pyramid.phasemap.PhaseMap` - The calculated phase map as a :class:`~.PhaseMap` object. - - """ - _log.debug('Calling pm') - # Determine projection mode: - if mode == 'rot-tilt': - projector = RotTiltProjector(magdata.dim, **kwargs) - elif mode == 'x-tilt': - projector = XTiltProjector(magdata.dim, **kwargs) - elif mode == 'y-tilt': - projector = YTiltProjector(magdata.dim, **kwargs) - elif mode in ['x', 'y', 'z']: - projector = SimpleProjector(magdata.dim, axis=mode, **kwargs) - else: - raise ValueError("Invalid mode (use 'x', 'y', 'z', 'x-tilt', 'y-tilt' or 'rot-tilt')") - # Project: - mag_proj = projector(magdata) - # Set up phasemapper and map phase: - if mapper == 'RDFC': - phasemapper = PhaseMapperRDFC(Kernel(magdata.a, projector.dim_uv, b_0=b_0)) - elif mapper == 'FDFC': - padding = kwargs.get('padding', 0) - phasemapper = PhaseMapperFDFC(magdata.a, projector.dim_uv, b_0=b_0, padding=padding) - else: - raise ValueError("Invalid mapper (use 'RDFC' or 'FDFC'") - phasemap = phasemapper(mag_proj) - # Get mask from magdata: - phasemap.mask = mag_proj.get_mask()[0, ...] - # Return phase: - return phasemap +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Convenience function for phase mapping magnetic distributions.""" + +import logging + +from ..kernel import Kernel +from ..phasemapper import PhaseMapperRDFC, PhaseMapperFDFC +from ..projector import RotTiltProjector, XTiltProjector, YTiltProjector, SimpleProjector + +__all__ = ['pm'] +_log = logging.getLogger(__name__) + + +def pm(magdata, mode='z', b_0=1, mapper='RDFC', **kwargs): + """Convenience function for fast magnetic phase mapping. + + Parameters + ---------- + magdata : :class:`~.VectorData` + A :class:`~.VectorData` object, from which the projected phase map should be calculated. + mode: {'z', 'y', 'x', 'x-tilt', 'y-tilt', 'rot-tilt'}, optional + Projection mode which determines the :class:`~.pyramid.projector.Projector` subclass, which + is used for the projection. Default is a simple projection along the `z`-direction. + b_0 : float, optional + Saturation magnetization in Tesla, which is used for the phase calculation. Default is 1. + **kwargs : additional arguments + Additional arguments like `dim_uv`, 'tilt' or 'rotation', which are passed to the + projector-constructor, defined by the `mode`. + + Returns + ------- + phasemap : :class:`~pyramid.phasemap.PhaseMap` + The calculated phase map as a :class:`~.PhaseMap` object. + + """ + _log.debug('Calling pm') + # Determine projection mode: + if mode == 'rot-tilt': + projector = RotTiltProjector(magdata.dim, **kwargs) + elif mode == 'x-tilt': + projector = XTiltProjector(magdata.dim, **kwargs) + elif mode == 'y-tilt': + projector = YTiltProjector(magdata.dim, **kwargs) + elif mode in ['x', 'y', 'z']: + projector = SimpleProjector(magdata.dim, axis=mode, **kwargs) + else: + raise ValueError("Invalid mode (use 'x', 'y', 'z', 'x-tilt', 'y-tilt' or 'rot-tilt')") + # Project: + mag_proj = projector(magdata) + # Set up phasemapper and map phase: + if mapper == 'RDFC': + phasemapper = PhaseMapperRDFC(Kernel(magdata.a, projector.dim_uv, b_0=b_0)) + elif mapper == 'FDFC': + padding = kwargs.get('padding', 0) + phasemapper = PhaseMapperFDFC(magdata.a, projector.dim_uv, b_0=b_0, padding=padding) + else: + raise ValueError("Invalid mapper (use 'RDFC' or 'FDFC'") + phasemap = phasemapper(mag_proj) + # Get mask from magdata: + phasemap.mask = mag_proj.get_mask()[0, ...] + # Return phase: + return phasemap diff --git a/pyramid/utils/reconstruction_2d_from_phasemap.py b/pyramid/utils/reconstruction_2d_from_phasemap.py index 9af92f8643b7ee60106a661af597d4215f22c100..0272303b0a4e84a77eb62f2db2e3bc22f8ff5800 100644 --- a/pyramid/utils/reconstruction_2d_from_phasemap.py +++ b/pyramid/utils/reconstruction_2d_from_phasemap.py @@ -1,107 +1,107 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Reconstruct a magnetization distributions from a single phase map.""" - -import logging - -import numpy as np - -from .. import reconstruction -from ..dataset import DataSet -from ..projector import SimpleProjector -from ..regularisator import FirstOrderRegularisator -from ..forwardmodel import ForwardModel -from ..costfunction import Costfunction -from .pm import pm - -__all__ = ['reconstruction_2d_from_phasemap'] -_log = logging.getLogger(__name__) - - -def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ramp_order=1, - plot_results=False, ar_dens=None, verbose=True): - """Convenience function for reconstructing a projected distribution from a single phasemap. - - Parameters - ---------- - phasemap: :class:`~PhaseMap` - The phasemap which is used for the reconstruction. - b_0 : float, optional - The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. - The default is 1. - lam : float - Regularisation parameter determining the weighting between measurements and regularisation. - max_iter : int, optional - The maximum number of iterations for the opimization. - ramp_order : int or None (default) - Polynomial order of the additional phase ramp which will be added to the phase maps. - All ramp parameters have to be at the end of the input vector and are split automatically. - Default is None (no ramps are added). - plot_results: boolean, optional - If True, the results are plotted after reconstruction. - ar_dens: int, optional - Number defining the arrow density which is plotted. A higher ar_dens number skips more - arrows (a number of 2 plots every second arrow). Default is 1. - verbose: bool, optional - If set to True, information like a progressbar is displayed during reconstruction. - The default is False. - - Returns - ------- - magdata_rec, cost: :class:`~.VectorData`, :class:`~.Costfunction` - The reconstructed magnetisation distribution and the used costfunction. - - """ - _log.debug('Calling reconstruction_2d_from_phasemap') - # Construct DataSet, Regularisator, ForwardModel and Costfunction: - dim = (1,) + phasemap.dim_uv - data = DataSet(phasemap.a, dim, b_0) - data.append(phasemap, SimpleProjector(dim)) - data.set_3d_mask() - fwd_model = ForwardModel(data, ramp_order) - reg = FirstOrderRegularisator(data.mask, lam, add_params=fwd_model.ramp.n) - cost = Costfunction(fwd_model, reg) - # Reconstruct: - magdata_rec = reconstruction.optimize_linear(cost, max_iter=max_iter, verbose=verbose) - param_cache = cost.fwd_model.ramp.param_cache - if ramp_order is None: - offset, ramp = 0, (0, 0) - elif ramp_order >= 1: - offset, ramp = param_cache[0][0], (param_cache[1][0], param_cache[2][0]) - elif ramp_order == 0: - offset, ramp = param_cache[0][0], (0, 0) - else: - raise ValueError('ramp_order has to be a positive integer or None!') - # Plot stuff: - if plot_results: - if ar_dens is None: - ar_dens = np.max([1, np.max(dim) // 64]) - magdata_rec.plot_quiver_field(note='Reconstructed Distribution', - ar_dens=ar_dens, figsize=(16, 16)) - phasemap_rec = pm(magdata_rec) - gain = 4 * 2 * np.pi / (np.abs(phasemap_rec.phase).max() + 1E-30) - gain = round(gain, -int(np.floor(np.log10(abs(gain))))) - vmin = phasemap_rec.phase.min() - vmax = phasemap_rec.phase.max() - phasemap.plot_combined(note='Input Phase', gain=gain) - phasemap -= fwd_model.ramp(index=0) - phasemap.plot_combined(note='Input Phase (ramp corrected)', gain=gain, vmin=vmin, vmax=vmax) - title = 'Reconstructed Phase' - if ramp_order is not None: - if ramp_order >= 0: - print('offset:', offset) - title += ', fitted Offset: {:.2g} [rad]'.format(offset) - if ramp_order >= 1: - print('ramp:', ramp) - title += ', (Fitted Ramp: (u:{:.2g}, v:{:.2g}) [rad/nm]'.format(*ramp) - phasemap_rec.plot_combined(note=title, gain=gain, vmin=vmin, vmax=vmax) - diff = (phasemap_rec - phasemap) - diff_name = 'Difference (RMS: {:.2g} rad)'.format(np.sqrt(np.mean(diff.phase) ** 2)) - diff.plot_phase_with_hist(note=diff_name, sigma_clip=3) - if ramp_order is not None: - ramp = fwd_model.ramp(0) - ramp.plot_phase(note='Fitted Ramp') - # Return reconstructed magnetisation distribution and cost function: - return magdata_rec, cost +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Reconstruct a magnetization distributions from a single phase map.""" + +import logging + +import numpy as np + +from .. import reconstruction +from ..dataset import DataSet +from ..projector import SimpleProjector +from ..regularisator import FirstOrderRegularisator +from ..forwardmodel import ForwardModel +from ..costfunction import Costfunction +from .pm import pm + +__all__ = ['reconstruction_2d_from_phasemap'] +_log = logging.getLogger(__name__) + + +def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ramp_order=1, + plot_results=False, ar_dens=None, verbose=True): + """Convenience function for reconstructing a projected distribution from a single phasemap. + + Parameters + ---------- + phasemap: :class:`~PhaseMap` + The phasemap which is used for the reconstruction. + b_0 : float, optional + The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. + The default is 1. + lam : float + Regularisation parameter determining the weighting between measurements and regularisation. + max_iter : int, optional + The maximum number of iterations for the opimization. + ramp_order : int or None (default) + Polynomial order of the additional phase ramp which will be added to the phase maps. + All ramp parameters have to be at the end of the input vector and are split automatically. + Default is None (no ramps are added). + plot_results: boolean, optional + If True, the results are plotted after reconstruction. + ar_dens: int, optional + Number defining the arrow density which is plotted. A higher ar_dens number skips more + arrows (a number of 2 plots every second arrow). Default is 1. + verbose: bool, optional + If set to True, information like a progressbar is displayed during reconstruction. + The default is False. + + Returns + ------- + magdata_rec, cost: :class:`~.VectorData`, :class:`~.Costfunction` + The reconstructed magnetisation distribution and the used costfunction. + + """ + _log.debug('Calling reconstruction_2d_from_phasemap') + # Construct DataSet, Regularisator, ForwardModel and Costfunction: + dim = (1,) + phasemap.dim_uv + data = DataSet(phasemap.a, dim, b_0) + data.append(phasemap, SimpleProjector(dim)) + data.set_3d_mask() + fwd_model = ForwardModel(data, ramp_order) + reg = FirstOrderRegularisator(data.mask, lam, add_params=fwd_model.ramp.n) + cost = Costfunction(fwd_model, reg) + # Reconstruct: + magdata_rec = reconstruction.optimize_linear(cost, max_iter=max_iter, verbose=verbose) + param_cache = cost.fwd_model.ramp.param_cache + if ramp_order is None: + offset, ramp = 0, (0, 0) + elif ramp_order >= 1: + offset, ramp = param_cache[0][0], (param_cache[1][0], param_cache[2][0]) + elif ramp_order == 0: + offset, ramp = param_cache[0][0], (0, 0) + else: + raise ValueError('ramp_order has to be a positive integer or None!') + # Plot stuff: + if plot_results: + if ar_dens is None: + ar_dens = np.max([1, np.max(dim) // 64]) + magdata_rec.plot_quiver_field(note='Reconstructed Distribution', + ar_dens=ar_dens, figsize=(16, 16)) + phasemap_rec = pm(magdata_rec) + gain = 4 * 2 * np.pi / (np.abs(phasemap_rec.phase).max() + 1E-30) + gain = round(gain, -int(np.floor(np.log10(abs(gain))))) + vmin = phasemap_rec.phase.min() + vmax = phasemap_rec.phase.max() + phasemap.plot_combined(note='Input Phase', gain=gain) + phasemap -= fwd_model.ramp(index=0) + phasemap.plot_combined(note='Input Phase (ramp corrected)', gain=gain, vmin=vmin, vmax=vmax) + title = 'Reconstructed Phase' + if ramp_order is not None: + if ramp_order >= 0: + print('offset:', offset) + title += ', fitted Offset: {:.2g} [rad]'.format(offset) + if ramp_order >= 1: + print('ramp:', ramp) + title += ', (Fitted Ramp: (u:{:.2g}, v:{:.2g}) [rad/nm]'.format(*ramp) + phasemap_rec.plot_combined(note=title, gain=gain, vmin=vmin, vmax=vmax) + diff = (phasemap_rec - phasemap) + diff_name = 'Difference (RMS: {:.2g} rad)'.format(np.sqrt(np.mean(diff.phase) ** 2)) + diff.plot_phase_with_hist(note=diff_name, sigma_clip=3) + if ramp_order is not None: + ramp = fwd_model.ramp(0) + ramp.plot_phase(note='Fitted Ramp') + # Return reconstructed magnetisation distribution and cost function: + return magdata_rec, cost diff --git a/pyramid/utils/reconstruction_3d_from_magdata.py b/pyramid/utils/reconstruction_3d_from_magdata.py index b0d1bd3082f6450ed9f10bcfc53c102917f3a507..db10f9253e55a9cdb93b7907fbc3b2da2394f52c 100644 --- a/pyramid/utils/reconstruction_3d_from_magdata.py +++ b/pyramid/utils/reconstruction_3d_from_magdata.py @@ -1,151 +1,151 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 by Forschungszentrum Juelich GmbH -# Author: J. Caron -# -"""Reconstruct a magnetization distributions from phase maps created from it.""" - -import logging - -import numpy as np - -import multiprocessing as mp - -from .. import reconstruction -from ..dataset import DataSet -from ..projector import XTiltProjector, YTiltProjector -from ..ramp import Ramp -from ..regularisator import FirstOrderRegularisator -from ..forwardmodel import ForwardModel, DistributedForwardModel -from ..costfunction import Costfunction -from ..phasemapper import PhaseMapperRDFC -from ..kernel import Kernel - -__all__ = ['reconstruction_3d_from_magdata'] -_log = logging.getLogger(__name__) - - -def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_order=1, - angles=np.linspace(-90, 90, num=19), dim_uv=None, - axes=(True, True), noise=0, offset_max=0, ramp_max=0, - use_internal_mask=True, plot_results=False, plot_input=False, - ar_dens=None, multicore=False, verbose=True): - """Convenience function for reconstructing a projected distribution from a single phasemap. - - Parameters - ---------- - magdata: :class:`~.VectorData` - The magnetisation distribution which should be used for the reconstruction. - b_0 : float, optional - The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. - The default is 1. - lam : float - Regularisation parameter determining the weighting between measurements and regularisation. - max_iter : int, optional - The maximum number of iterations for the opimization. - ramp_order : int or None (default) - Polynomial order of the additional phase ramp which will be added to the phase maps. - All ramp parameters have to be at the end of the input vector and are split automatically. - Default is None (no ramps are added). - angles: :class:`~numpy.ndarray` (N=1), optional - Numpy array determining the angles which should be used for the projectors in x- and - y-direction. This implicitly sets the number of images per rotation axis. Defaults to a - range from -90° to 90° degrees, in 10° steps. - dim_uv: int or None (default) - Determines if the phasemaps should be padded to a certain size while calculating. - axes: tuple of booleans (N=2), optional - Determines if both tilt axes should be calculated. The order is (x, y), both are True by - default. - noise: float, optional - If this is not zero, random gaussian noise with this as a maximum value will be applied - to all calculated phasemaps. The default is 0. - offset_max: float, optional - if this is not zero, a random offset with this as a maximum value will be applied to all - calculated phasemaps. The default is 0. - ramp_max: float, optional - if this is not zero, a random linear ramp with this as a maximum value will be applied - to both axes of all calculated phasemaps. The default is 0. - use_internal_mask: boolean, optional - If True, the mask from the input magnetization distribution is taken for the - reconstruction. If False, the mask is calculated via logic backprojection from the 2D-masks - of the input phasemaps. - plot_results: boolean, optional - If True, the results are plotted after reconstruction. - plot_input: - If True, the input phasemaps are plotted after reconstruction. - ar_dens: int, optional - Number defining the arrow density which is plotted. A higher ar_dens number skips more - arrows (a number of 2 plots every second arrow). Default is 1. - multicore: boolean, optional - Determines if multiprocessing should be used. Default is True. Phasemap calculations - will be divided onto the separate cores. - verbose: bool, optional - If set to True, information like a progressbar is displayed during reconstruction. - The default is False. - - Returns - ------- - magdata_rec, cost: :class:`~.VectorData`, :class:`~.Costfunction` - The reconstructed magnetisation distribution and the used costfunction. - - """ - _log.debug('Calling reconstruction_3d_from_magdata') - # Construct DataSet: - dim = magdata.dim - if ar_dens is None: - ar_dens = np.max([1, np.max(dim) // 128]) - data = DataSet(magdata.a, magdata.dim, b_0) - # Construct projectors: - projectors = [] - # Construct data set and regularisator: - for angle in angles: - angle_rad = angle * np.pi / 180 - if axes[0]: - projectors.append(XTiltProjector(magdata.dim, angle_rad, dim_uv)) - if axes[1]: - projectors.append(YTiltProjector(magdata.dim, angle_rad, dim_uv)) - # Add pairs of projectors and according phasemaps to the DataSet: - for projector in projectors: - mag_proj = projector(magdata) - phasemap = PhaseMapperRDFC(Kernel(magdata.a, projector.dim_uv, b_0))(mag_proj) - phasemap.mask = mag_proj.get_mask()[0, ...] - data.append(phasemap, projector) - # Add offset and ramp if necessary: - for i, phasemap in enumerate(data.phasemaps): - offset = np.random.uniform(-offset_max, offset_max) - ramp_u = np.random.uniform(-ramp_max, ramp_max) - ramp_v = np.random.uniform(-ramp_max, ramp_max) - phasemap += Ramp.create_ramp(phasemap.a, phasemap.dim_uv, (offset, ramp_u, ramp_v)) - # Add noise if necessary: - if noise != 0: - for i, phasemap in enumerate(data.phasemaps): - phasemap.phase += np.random.normal(0, noise, phasemap.dim_uv) - data.phasemaps[i] = phasemap - # Construct mask: - if use_internal_mask: - data.mask = magdata.get_mask() # Use perfect mask from magdata! - else: - data.set_3d_mask() # Construct mask from 2D phase masks! - # Construct regularisator, forward model and costfunction: - if multicore: - mp.freeze_support() - fwd_model = DistributedForwardModel(data, ramp_order=ramp_order, nprocs=mp.cpu_count()) - else: - fwd_model = ForwardModel(data, ramp_order=ramp_order) - reg = FirstOrderRegularisator(data.mask, lam, add_params=fwd_model.ramp.n) - cost = Costfunction(fwd_model, reg) - # Reconstruct and save: - magdata_rec = reconstruction.optimize_linear(cost, max_iter=max_iter, verbose=verbose) - # Finalize ForwardModel (returns workers if multicore): - fwd_model.finalize() - # Plot input: - if plot_input: - data.plot_phasemaps() - # Plot results: - if plot_results: - data.plot_mask(ar_dens=ar_dens) - magdata.plot_quiver3d('Original Distribution', ar_dens=ar_dens) - magdata_rec.plot_quiver3d('Reconstructed Distribution (angle)', ar_dens=ar_dens) - magdata_rec.plot_quiver3d('Reconstructed Distribution (amplitude)', - ar_dens=ar_dens, coloring='amplitude') - # Return reconstructed magnetisation distribution and cost function: - return magdata_rec, cost +# -*- coding: utf-8 -*- +# Copyright 2016 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Reconstruct a magnetization distributions from phase maps created from it.""" + +import logging + +import numpy as np + +import multiprocessing as mp + +from .. import reconstruction +from ..dataset import DataSet +from ..projector import XTiltProjector, YTiltProjector +from ..ramp import Ramp +from ..regularisator import FirstOrderRegularisator +from ..forwardmodel import ForwardModel, DistributedForwardModel +from ..costfunction import Costfunction +from ..phasemapper import PhaseMapperRDFC +from ..kernel import Kernel + +__all__ = ['reconstruction_3d_from_magdata'] +_log = logging.getLogger(__name__) + + +def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_order=1, + angles=np.linspace(-90, 90, num=19), dim_uv=None, + axes=(True, True), noise=0, offset_max=0, ramp_max=0, + use_internal_mask=True, plot_results=False, plot_input=False, + ar_dens=None, multicore=False, verbose=True): + """Convenience function for reconstructing a projected distribution from a single phasemap. + + Parameters + ---------- + magdata: :class:`~.VectorData` + The magnetisation distribution which should be used for the reconstruction. + b_0 : float, optional + The magnetic induction corresponding to a magnetization `M`\ :sub:`0` in T. + The default is 1. + lam : float + Regularisation parameter determining the weighting between measurements and regularisation. + max_iter : int, optional + The maximum number of iterations for the opimization. + ramp_order : int or None (default) + Polynomial order of the additional phase ramp which will be added to the phase maps. + All ramp parameters have to be at the end of the input vector and are split automatically. + Default is None (no ramps are added). + angles: :class:`~numpy.ndarray` (N=1), optional + Numpy array determining the angles which should be used for the projectors in x- and + y-direction. This implicitly sets the number of images per rotation axis. Defaults to a + range from -90° to 90° degrees, in 10° steps. + dim_uv: int or None (default) + Determines if the phasemaps should be padded to a certain size while calculating. + axes: tuple of booleans (N=2), optional + Determines if both tilt axes should be calculated. The order is (x, y), both are True by + default. + noise: float, optional + If this is not zero, random gaussian noise with this as a maximum value will be applied + to all calculated phasemaps. The default is 0. + offset_max: float, optional + if this is not zero, a random offset with this as a maximum value will be applied to all + calculated phasemaps. The default is 0. + ramp_max: float, optional + if this is not zero, a random linear ramp with this as a maximum value will be applied + to both axes of all calculated phasemaps. The default is 0. + use_internal_mask: boolean, optional + If True, the mask from the input magnetization distribution is taken for the + reconstruction. If False, the mask is calculated via logic backprojection from the 2D-masks + of the input phasemaps. + plot_results: boolean, optional + If True, the results are plotted after reconstruction. + plot_input: + If True, the input phasemaps are plotted after reconstruction. + ar_dens: int, optional + Number defining the arrow density which is plotted. A higher ar_dens number skips more + arrows (a number of 2 plots every second arrow). Default is 1. + multicore: boolean, optional + Determines if multiprocessing should be used. Default is True. Phasemap calculations + will be divided onto the separate cores. + verbose: bool, optional + If set to True, information like a progressbar is displayed during reconstruction. + The default is False. + + Returns + ------- + magdata_rec, cost: :class:`~.VectorData`, :class:`~.Costfunction` + The reconstructed magnetisation distribution and the used costfunction. + + """ + _log.debug('Calling reconstruction_3d_from_magdata') + # Construct DataSet: + dim = magdata.dim + if ar_dens is None: + ar_dens = np.max([1, np.max(dim) // 128]) + data = DataSet(magdata.a, magdata.dim, b_0) + # Construct projectors: + projectors = [] + # Construct data set and regularisator: + for angle in angles: + angle_rad = angle * np.pi / 180 + if axes[0]: + projectors.append(XTiltProjector(magdata.dim, angle_rad, dim_uv)) + if axes[1]: + projectors.append(YTiltProjector(magdata.dim, angle_rad, dim_uv)) + # Add pairs of projectors and according phasemaps to the DataSet: + for projector in projectors: + mag_proj = projector(magdata) + phasemap = PhaseMapperRDFC(Kernel(magdata.a, projector.dim_uv, b_0))(mag_proj) + phasemap.mask = mag_proj.get_mask()[0, ...] + data.append(phasemap, projector) + # Add offset and ramp if necessary: + for i, phasemap in enumerate(data.phasemaps): + offset = np.random.uniform(-offset_max, offset_max) + ramp_u = np.random.uniform(-ramp_max, ramp_max) + ramp_v = np.random.uniform(-ramp_max, ramp_max) + phasemap += Ramp.create_ramp(phasemap.a, phasemap.dim_uv, (offset, ramp_u, ramp_v)) + # Add noise if necessary: + if noise != 0: + for i, phasemap in enumerate(data.phasemaps): + phasemap.phase += np.random.normal(0, noise, phasemap.dim_uv) + data.phasemaps[i] = phasemap + # Construct mask: + if use_internal_mask: + data.mask = magdata.get_mask() # Use perfect mask from magdata! + else: + data.set_3d_mask() # Construct mask from 2D phase masks! + # Construct regularisator, forward model and costfunction: + if multicore: + mp.freeze_support() + fwd_model = DistributedForwardModel(data, ramp_order=ramp_order, nprocs=mp.cpu_count()) + else: + fwd_model = ForwardModel(data, ramp_order=ramp_order) + reg = FirstOrderRegularisator(data.mask, lam, add_params=fwd_model.ramp.n) + cost = Costfunction(fwd_model, reg) + # Reconstruct and save: + magdata_rec = reconstruction.optimize_linear(cost, max_iter=max_iter, verbose=verbose) + # Finalize ForwardModel (returns workers if multicore): + fwd_model.finalize() + # Plot input: + if plot_input: + data.plot_phasemaps() + # Plot results: + if plot_results: + data.plot_mask(ar_dens=ar_dens) + magdata.plot_quiver3d('Original Distribution', ar_dens=ar_dens) + magdata_rec.plot_quiver3d('Reconstructed Distribution (angle)', ar_dens=ar_dens) + magdata_rec.plot_quiver3d('Reconstructed Distribution (amplitude)', + ar_dens=ar_dens, coloring='amplitude') + # Return reconstructed magnetisation distribution and cost function: + return magdata_rec, cost diff --git a/pyramid/version.py b/pyramid/version.py new file mode 100644 index 0000000000000000000000000000000000000000..873cdeeadad16271c5e3ff7ee1499dd70a6d640d --- /dev/null +++ b/pyramid/version.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +""""This file is generated automatically by the Pyramid `setup.py`""" +version = "0.1.0-dev" +hg_revision = "???" # TODO: Now uses git! Maybe delete alltogether? See setup.py!