From 549254a621dd31a6b579f8b9810c9406bccf595b Mon Sep 17 00:00:00 2001 From: Jan Caron <j.caron@fz-juelich.de> Date: Fri, 20 Mar 2020 11:16:12 +0100 Subject: [PATCH] Added utils subpackage and much more! utils: Has quaternion and misc submodules misc has levi_civita and interp_to_regular_grid (used later for io) field: Added HyperSpy support (get/set_signal) added copy/rotate/rot90/set_vector/get_vector methods --- README.rst | 16 ++- docs/fields.rst | 12 +- docs/vis.rst | 16 +-- environment.yml | 3 +- setup.cfg | 1 + src/empyre/__init__.py | 4 +- src/empyre/fields/__init__.py | 1 + src/empyre/fields/field.py | 255 ++++++++++++++++++++++++++++++++- src/empyre/utils/__init__.py | 15 ++ src/empyre/utils/misc.py | 103 +++++++++++++ src/empyre/utils/quaternion.py | 142 ++++++++++++++++++ src/empyre/vis/plot2d.py | 2 +- 12 files changed, 545 insertions(+), 25 deletions(-) create mode 100644 src/empyre/utils/__init__.py create mode 100644 src/empyre/utils/misc.py create mode 100644 src/empyre/utils/quaternion.py diff --git a/README.rst b/README.rst index 9ab3167..66ef32b 100644 --- a/README.rst +++ b/README.rst @@ -27,7 +27,7 @@ EMPyRe is available on the `Python Package Index <http://pypi.python.org/pypi>`_ Per default, only the strictly required libraries are installed, but there are a few additional dependencies that will unlock additional capabilites of EMPYRE. -* ``io`` will install the `HyperSpy <https://hyperspy.org/>`_ package that is used for loading and saving additional file formats. +* ``io`` will install the `tvtk <http://docs.enthought.com/mayavi/tvtk/README.html>`_ & `HyperSpy <https://hyperspy.org/>`_ packages that are used for loading and saving additional file formats. .. warning:: Due to this `issue <https://github.com/hyperspy/hyperspy/issues/2315>`_, a pip install of hyperspy is currently not possible. Please use @@ -51,17 +51,19 @@ You can choose these settings by using, *e.g.*: Structure --------- -EMPyRe has several dedicated submodules which are fully documented `here <https://empyre.iffgit.fz-juelich.de/empyre/>`_! +EMPyRe has several dedicated modules which are fully documented `here <https://empyre.iffgit.fz-juelich.de/empyre/>`_! -* The ``fields`` submodule provides the ``Field`` container class for multidimensional scalar or vector fields and is the fundamental data structure used in EMPyRe. +* The ``fields`` module provides the ``Field`` container class for multidimensional scalar or vector fields and is the fundamental data structure used in EMPyRe. -* The ``vis`` submodule enables the plotting of ``Field`` objects, based on and similar in syntax to the commonly known `matplotlib <https://matplotlib.org/>`_ framework. +* The ``vis`` module enables the plotting of ``Field`` objects, based on and similar in syntax to the commonly known `matplotlib <https://matplotlib.org/>`_ framework. -* The ``models`` submodule provides tools for constructing forward models that describe processes in Electron Microscopy. +* The ``models`` module provides tools for constructing forward models that describe processes in Electron Microscopy. -* The ``reconstruct`` submodule is a collection of tools for solving the inverse problems corresponding to the constructed forward models and diagnostic tools for their assessment. +* The ``reconstruct`` module is a collection of tools for solving the inverse problems corresponding to the constructed forward models and diagnostic tools for their assessment. -* The ``io`` submodule is used to load and save ``Field`` objects and the models generated by the ``models`` subpackage. +* The ``io`` module is used to load and save ``Field`` objects and the models generated by the ``models`` subpackage. + +* The ``utils`` module, which houses utility functionality used throughout EMPyRe. diff --git a/docs/fields.rst b/docs/fields.rst index 4f07378..e5b3e9e 100644 --- a/docs/fields.rst +++ b/docs/fields.rst @@ -4,24 +4,24 @@ The Field container class General empyre.fields docu here! -The field submodule -------------------- +The field module +---------------- .. automodule:: empyre.fields.field :members: :show-inheritance: -The shapes submodule --------------------- +The shapes module +----------------- .. automodule:: empyre.fields.shapes :members: :show-inheritance: -The vectors submodule ---------------------- +The vectors module +------------------ .. automodule:: empyre.fields.vectors :members: diff --git a/docs/vis.rst b/docs/vis.rst index c882267..9d0d311 100644 --- a/docs/vis.rst +++ b/docs/vis.rst @@ -4,32 +4,32 @@ The vis visualization submodule General empyre.vis docu here! -The plot2d submodule --------------------- +The plot2d module +----------------- .. automodule:: empyre.vis.plot2d :members: :show-inheritance: -The decorators submodule ------------------------- +The decorators module +--------------------- .. automodule:: empyre.vis.decorators :members: :show-inheritance: -The colors submodule --------------------- +The colors module +----------------- .. automodule:: empyre.vis.colors :members: :show-inheritance: -The tools submodule -------------------- +The tools module +---------------- .. automodule:: empyre.vis.tools :members: diff --git a/environment.yml b/environment.yml index 7252c57..aa4e226 100644 --- a/environment.yml +++ b/environment.yml @@ -40,12 +40,13 @@ dependencies: # Documentation: - sphinx=2.4 - numpydoc=0.9 - - sphinx_rtd_theme=0.4 # TODO: not needed? + - sphinx_rtd_theme=0.4 # IPython and notebooks: - ipython=7.7 - jupyter=1.0 - nb_conda=2.2 #- ptvsd=4.3 # Cell debugging in VS Code + - rope=0.16 # for refactoring in VS Code # TODO: - ipywidgets # TODO: Add back GUI dependencies! # TODO: Get Jutil from gitlab (currently doesn't work, git and cygwin don't play nice,... diff --git a/setup.cfg b/setup.cfg index b687bae..4b40fcc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -56,6 +56,7 @@ where = src [options.extras_require] io = hyperspy + tvtk fftw = pyfftw colors = diff --git a/src/empyre/__init__.py b/src/empyre/__init__.py index cf7a732..5fd8188 100644 --- a/src/empyre/__init__.py +++ b/src/empyre/__init__.py @@ -4,11 +4,13 @@ # """Subpackage containing functionality for visualisation of multidimensional fields.""" + from . import fields from . import io from . import models from . import reconstruct from . import vis +from . import utils from .version import version as __version__ from .version import git_revision as __git_revision__ @@ -18,7 +20,7 @@ _log.info(f'Imported EMPyRe V-{__version__} GIT-{__git_revision__}') del logging -__all__ = ['fields', 'io', 'models', 'reconstruct', 'vis'] +__all__ = ['fields', 'io', 'models', 'reconstruct', 'vis', 'utils'] del version diff --git a/src/empyre/fields/__init__.py b/src/empyre/fields/__init__.py index 7e7e5cb..7376dcb 100644 --- a/src/empyre/fields/__init__.py +++ b/src/empyre/fields/__init__.py @@ -4,6 +4,7 @@ # """Subpackage containing container classes for multidimensional fields and ways to create them.""" + from .field import * from .shapes import * from .vectors import * diff --git a/src/empyre/fields/field.py b/src/empyre/fields/field.py index b94f32f..fe8ccf6 100644 --- a/src/empyre/fields/field.py +++ b/src/empyre/fields/field.py @@ -13,6 +13,8 @@ import numpy as np from numpy.core import numeric from scipy.ndimage import interpolation +from ..utils import Quaternion + __all__ = ['Field'] @@ -95,7 +97,8 @@ class Field(NDArrayOperatorsMixin): if isinstance(scale, Number): # Scale is the same for each dimension! self.__scale = (scale,) * len(self.dim) elif isinstance(scale, tuple): - assert len(scale) == len(self.dim), f'Each dimension {self.dim} needs a scale, but {scale} was given!' + ndim = len(self.dim) + assert len(scale) == ndim, f'Each of the {ndim} dimensions needs a scale, but {scale} was given!' self.__scale = scale else: raise AssertionError('Scaling has to be a number or a tuple of numbers!') @@ -288,6 +291,256 @@ class Field(NDArrayOperatorsMixin): assert all(item.scale == scalar_list[0].scale for item in scalar_list), 'Scales of fields must match!' return cls(np.stack(scalar_list, axis=-1), scalar_list[0].scale, vector=True) + @classmethod + def from_signal(cls, signal, scale=None, vector=False): + """Convert a :class:`~hyperspy.signals.Signal` object to a :class:`~.Field` object. + + Parameters + ---------- + signal: :class:`~hyperspy.signals.Signal` + The :class:`~hyperspy.signals.Signal` object which should be converted to Field. + + Returns + ------- + field: :class:`~.Field` + A :class:`~.Field` object containing the loaded data. + scale: tuple of float, optional + Scaling along the dimensions of the underlying data. If not provided, will be read from the axes_manager. + vector: boolean, optional + If set to True, forces the signal to be interpreted as a vector field. EMPyRe will check if the first axis + is named 'vector components' (EMPyRe saves vector fields like this). If this is the case, vector will be + automatically set to True and the signal will also be interpreted as a vector field. + + Notes + ----- + Signals recquire the hyperspy package! + + """ + cls._log.debug('Calling from_signal') + data = signal.data + if signal.axes_manager[0].name == 'vector components': + vector = True # Automatic detection! + if scale is None: # If not provided, try to read from axes_manager: + scale = [signal.axes_manager[i].scale for i in range(len(data.shape) - vector)] # One less axis if vector! + return cls(data, scale, vector) + + def to_signal(self): + """Convert :class:`~.Field` data into a HyperSpy signal. + + Returns + ------- + signal: :class:`~hyperspy.signals.Signal` + Representation of the :class:`~.Field` object as a HyperSpy Signal. + + Notes + ----- + This method recquires the hyperspy package! + + """ + self._log.debug('Calling to_signal') + try: # Try importing HyperSpy: + import hyperspy.api as hs + except ImportError: + self._log.error('This method recquires the hyperspy package!') + return + # Create signal: + signal = hs.signals.BaseSignal(self.data) # All axes are signal axes! + # Set axes: + if self.vector: + signal.axes_manager[0].name = 'vector components' + for i in range(len(self.dim)): + ax = i+1 if self.vector else i # take component axis into account if vector! + num = ['x', 'y', 'z'][i] if len(self.dim) <= 3 else i + signal.axes_manager[ax].name = f'axis {num}' + signal.axes_manager[ax].scale = self.scale[i] + signal.axes_manager[ax].units = 'nm' + signal.metadata.Signal.title = f"EMPyRe {'vector' if self.vector else 'scalar'} Field" + return signal + + def copy(self): + """Returns a copy of the :class:`~.Field` object. + + Returns + ------- + field: :class:`~.Field` + A copy of the :class:`~.Field`. + + """ + self._log.debug('Calling copy') + return Field(self.data.copy(), self.scale, self.vector) + + def rotate(self, angle, axis='z', **kwargs): + """Rotate the :class:`~.Field`, based on :function:`~scipy.ndimage.interpolation.rotate`. + + Rotation direction is from the first towards the second axis. Works for 2D and 3D Fields. + + Parameters + ---------- + angle : float + The rotation angle in degrees. + axis: {'x', 'y', 'z'}, optional + The axis around which the vector field is rotated. Default is 'z'. Ignored for 2D Fields. + + Returns + ------- + field: :class:`~.Field` + The rotated :class:`~.Field`. + + Notes + ----- + All additional kwargs are passed through to :function:`~scipy.ndimage.interpolation.rotate`. + The `reshape` parameter, controlling if the output shape is adapted so that the input array is contained + completely in the output is False per default, contrary to :function:`~scipy.ndimage.interpolation.rotate`, + where it is True. + + """ + self._log.debug('Calling rotate') + assert len(self.dim) in (2, 3), 'rotate is currently only defined for 2D and 3D Fields!' + kwargs.setdefault('reshape', False) # Default here is no reshaping! + if len(self.dim) == 2: # For 2D, there are only 2 axes: + axis = 'z' # Overwrite potential argument if 2D! + axes = (0, 1) # y and x + else: # 3D case: + axes = {'x': (0, 1), 'y': (0, 2), 'z': (1, 2)}[axis] + if axis == 'z': # TODO: Somehow needed... don't know why, maybe because origin='lower' in imshow? + angle *= -1 + sc_0, sc_1 = self.scale[axes[0]], self.scale[axes[1]] + assert sc_0 == sc_1, f'rotate needs the scales in the rotation plane to be equal (they are {sc_0} & {sc_1})!' + if not self.vector: # Scalar field: + data_new = interpolation.rotate(self.data, angle, axes=axes, **kwargs) + else: # Vector field: + # Rotate coordinate system: + comps = [np.asarray(comp) for comp in self.comp] + if self.ncomp == 3: + data_new = np.stack([interpolation.rotate(c, angle, axes=axes, **kwargs) for c in comps], axis=-1) + # Up till now, only the coordinates are rotated, now we need to rotate the vectors inside the voxels: + rot_axis = [i for i in (0, 1, 2) if i not in axes][0] # [0] because list! + i, j, k = axes[0], axes[1], rot_axis # next line only works if i != j != k + levi_civita = int((j-i) * (k-i) * (k-j) / (np.abs(j-i) * np.abs(k-i) * np.abs(k-j))) + # Create Quaternion, note that they have (x, y, z) order instead of (z, y, x): + # TODO: Again a minus sign needed before levi_civita, still not sure why (see angle above)! + quat_axis = tuple([-levi_civita if i == rot_axis else 0 for i in (2, 1, 0)]) + quat = Quaternion.from_axisangle(quat_axis, np.deg2rad(angle)) + data_new = quat.matrix.dot(data_new.reshape((-1, 3)).T).T.reshape(self.shape) # T needed b.c. ordering! + elif self.ncomp == 2: + u_comp, v_comp = comps + u_rot = interpolation.rotate(u_comp, angle, axes=axes, **kwargs) + v_rot = interpolation.rotate(v_comp, angle, axes=axes, **kwargs) + # Up till now, only the coordinates are rotated, now we need to rotate the vectors inside the voxels: + # TODO: Again a minus sign needed for vector rotation, still not sure why (see angle above)! + ang_rad = np.deg2rad(-angle) + u_mix = np.cos(ang_rad)*u_rot - np.sin(ang_rad)*v_rot + v_mix = np.sin(ang_rad)*u_rot + np.cos(ang_rad)*v_rot + data_new = np.stack((u_mix, v_mix), axis=-1) + else: + raise ValueError('rotate currently only works for 2- or 3-component vector fields!') + return Field(data_new, self.scale, self.vector) + + def rot90(self, k=1, axis='z'): + """Rotate the :class:`~.Field` 90° around the specified axis (right hand rotation). + + Parameters + ---------- + k : integer + Number of times the array is rotated by 90 degrees. + axis: {'x', 'y', 'z'}, optional + The axis around which the vector field is rotated. Default is 'z'. Ignored for 2D Fields. + + Returns + ------- + field: :class:`~.Field` + The rotated :class:`~.Field`. + + """ + self._log.debug('Calling rot90') + assert axis in ('x', 'y', 'z'), "Wrong input! 'x', 'y', 'z' allowed!" + assert len(self.dim) in (2, 3), 'rotate is currently only defined for 2D and 3D Fields!' + if len(self.dim) == 2: # For 2D, there are only 2 axes: + axis = 'z' # Overwrite potential argument if 2D! + axes = (0, 1) # y and x + else: # 3D case: + axes = {'x': (0, 1), 'y': (0, 2), 'z': (1, 2)}[axis] + sc_0, sc_1 = self.scale[axes[0]], self.scale[axes[1]] + assert sc_0 == sc_1, f'rot90 needs the scales in the rotation plane to be equal (they are {sc_0} & {sc_1})!' + # TODO: Later on, rotation could also flip the scale (not implemented here, yet)! + if not self.vector: # Scalar Field: + data_new = np.rot90(self.data, k=k, axes=axes) + else: # Vector Field: + if len(self.dim) == 3: # 3D: + assert self.ncomp == 3, 'rot90 currently only works for vector fields with 3 components in 3D!' + comp_x, comp_y, comp_z = self.comp + if axis == 'x': # RotMatrix for 90°: [[1, 0, 0], [0, 0, -1], [0, 1, 0]] + comp_x_rot = np.rot90(comp_x, k=k, axes=axes) + comp_y_rot = np.rot90(comp_z, k=k, axes=axes) + comp_z_rot = -np.rot90(comp_y, k=k, axes=axes) + elif axis == 'y': # RotMatrix for 90°: [[0, 0, 1], [0, 1, 0], [-1, 0, 0]] + comp_x_rot = np.rot90(comp_z, k=k, axes=axes) + comp_y_rot = np.rot90(comp_y, k=k, axes=axes) + comp_z_rot = -np.rot90(comp_x, k=k, axes=axes) + elif axis == 'z': # RotMatrix for 90°: [[0, -1, 0], [1, 0, 0], [0, 0, 1]] + comp_x_rot = np.rot90(comp_y, k=k, axes=axes) + comp_y_rot = np.rot90(comp_x, k=k, axes=axes) + comp_z_rot = np.rot90(comp_z, k=k, axes=axes) + data_new = np.stack((comp_x_rot, comp_y_rot, comp_z_rot), axis=-1) + if len(self.dim) == 2: # 2D: + assert self.ncomp == 2, 'rot90 currently only works for vector fields with 2 components in 2D!' + comp_x, comp_y = self.comp + comp_x_rot = np.rot90(comp_y, k=k, axes=axes) + comp_y_rot = np.rot90(comp_x, k=k, axes=axes) + data_new = np.stack((comp_x_rot, comp_y_rot), axis=-1) + # Return result: + return Field(data_new, self.scale, self.vector) + + def get_vector(self, mask=None): + """Returns the field as a vector, specified by a mask. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (boolean) + Masks the pixels from which the entries should be taken. Must be a numpy array for indexing to work! + + Returns + ------- + vector : :class:`~numpy.ndarray` (N=1) + The vector containing the field of the specified pixels. If the field is a vector field, components are + first vectorised, then concatenated! + + """ + self._log.debug('Calling get_vector') + if mask is None: # If not given, take full data: + mask = np.full(self.dim, True) + if self.vector: # Vector field: + return np.ravel([comp.data[mask] for comp in self.comp]) + else: # Scalar field: + return np.ravel(self.data[mask]) + + def set_vector(self, vector, mask=None): + """Set the field of the masked pixels to the values specified by `vector`. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (boolean), optional + Masks the pixels from which the field should be taken. + vector : :class:`~numpy.ndarray` (N=1) + The vector containing the field of the specified pixels. + + Returns + ------- + None + + """ + self._log.debug('Calling set_vector') + if mask is None: # If not given, set full data: + mask = np.full(self.dim, True) + if self.vector: # Vector field: + assert np.size(vector) % self.ncomp == 0, 'Vector has to contain all components for every pixel!' + count = np.size(vector) // self.ncomp + for i in range(self.ncomp): + sl = slice(i*count, (i+1)*count) + self.data[..., i][mask] = vector[sl] + else: # Scalar field: + self.data[mask] = vector + def squeeze(self): """Squeeze the `Field` object to remove single-dimensional entries in the shape. diff --git a/src/empyre/utils/__init__.py b/src/empyre/utils/__init__.py new file mode 100644 index 0000000..ce805ff --- /dev/null +++ b/src/empyre/utils/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""Subpackage containing utility functionality.""" + + +from .quaternion import * + + +__all__ = [] +__all__.extend(quaternion.__all__) + + +del quaternion diff --git a/src/empyre/utils/misc.py b/src/empyre/utils/misc.py new file mode 100644 index 0000000..138fe63 --- /dev/null +++ b/src/empyre/utils/misc.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides the miscellaneous helper functions.""" + + +import logging +import itertools +from time import time + +import numpy as np +from tqdm import tqdm +from scipy.spatial import cKDTree, qhull +from scipy.interpolate import LinearNDInterpolator + + +__all__ = ['levi_civita', 'interp_to_regular_grid'] +_log = logging.getLogger(__name__) + + +def levi_civita(i, j, k): + _log.debug('Calling levi_civita') + return (j-i) * (k-i) * (k-j) / (np.abs(j-i) * np.abs(k-i) * np.abs(k-j)) + + +def interp_to_regular_grid(points, values, scale, scale_factor=1, step=1, convex=True): + """Interpolate values on points to regular grid. + + Parameters + ---------- + points : np.ndarray, (N, 3) + Array of points, describing the location of the values that should be interpolated. Three columns x, y, z! + values : np.ndarray, (N, c) + Array of values that should be interpolated to the new grid. `c` is the number of components (`1` for scalar + fields, `3` for normal 3D vector fields). + scale : tuple of 3 ints + Scale along each of the 3 spatial dimensions. Usually given in nm. + scale_factor : float, optional + Additional scaling factor that should be used if the original points are not described on a nm-scale. Use this + to convert from nm of `scale` to the unit of the points. By default 1. + step : int, optional + If this is bigger than 1 (the default), only every `step` point is taken into account. Can speed up calculation, + but you'll lose accuracy of your interpolation. + convex : bool, optional + By default True. If this is set to False, additional measures are taken to find holes in the point cloud. + WARNING: this is an experimental feature that should be used with caution! + + Returns + ------- + interpolation: np.ndarray + Interpolated grid with shape `(zdim, ydim, xdim)` for scalar and `(zdim, ydim, xdim, ncomp)` for vector field. + + """ + _log.debug('Calling interpolate_to_regular_grid') + z_uniq = np.unique(points[:, 2]) + _log.info(f'unique positions along z: {len(z_uniq)}') + # Translate scale in local units (not necessarily nm), taken care of with `scale_factor`: + scale = tuple([s * scale_factor for s in scale]) + # Determine the size of the point cloud of irregular coordinates: + x_min, x_max = points[:, 0].min(), points[:, 0].max() + y_min, y_max = points[:, 1].min(), points[:, 1].max() + z_min, z_max = points[:, 2].min(), points[:, 2].max() + x_diff, y_diff, z_diff = np.ptp(points[:, 0]), np.ptp(points[:, 1]), np.ptp(points[:, 2]) + _log.info(f'x-range: {x_min:.2g} <-> {x_max:.2g} ({x_diff:.2g})') + _log.info(f'y-range: {y_min:.2g} <-> {y_max:.2g} ({y_diff:.2g})') + _log.info(f'z-range: {z_min:.2g} <-> {z_max:.2g} ({z_diff:.2g})') + # Determine dimensions from given grid spacing a: + dim = tuple(np.round(np.asarray((z_diff/scale, y_diff/scale, x_diff/scale), dtype=int))) + x = x_min + scale * (np.arange(dim[2]) + 0.5) # +0.5: shift to pixel center! + y = y_min + scale * (np.arange(dim[1]) + 0.5) # +0.5: shift to pixel center! + z = z_min + scale * (np.arange(dim[0]) + 0.5) # +0.5: shift to pixel center! + # Create points for new Euclidian grid; fliplr for (x, y, z) order: + points_euc = np.fliplr(np.asarray(list(itertools.product(z, y, x)))) + # Make values 2D (if not already); double .T so that a new axis is added at the END (n, 1): + values = np.atleast_2d(values.T).T + # Prepare interpolated grid: + interpolation = np.empty(dim+(values.shape[-1],), dtype=np.float) + _log.info(f'Dimensions of new grid: {interpolation.shape}') + # Calculate the Delaunay triangulation (same for every component of multidim./vector fields): + _log.info('Start Delaunay triangulation...') + tick = time() + triangulation = qhull.Delaunay(points[::step]) + tock = time() + _log.info(f'Delaunay triangulation complete (took {tock-tick:.2f} s)!') + # Perform the interpolation for each column of `values`: + for i in tqdm(range(values.shape[-1])): + # Create interpolator for the given triangulation and the values of the current column: + interpolator = LinearNDInterpolator(triangulation, values[::step, i], fill_value=0) + # Interpolate: + interpolation[..., i] = interpolator(points_euc).reshape(dim) + # If NOT convex, we have to check for additional holes in the structure (EXPERIMENTAL): + if not convex: # Only necessary if the user expects holes in the (-> nonconvex) distribution: + # Create k-dimensional tree for queries: + tree = cKDTree(points) + # Query the tree for nearest neighbors, x: points to query, k: number of neighbors, p: norm + # to use (here: 2 - Euclidean), distance_upper_bound: maximum distance that is searched! + data, leafsize = tree.query(x=points_euc, k=1, p=2, distance_upper_bound=2*np.mean(scale)) + # Create boolean mask that determines which interpolation points have no neighbor near enough: + mask = np.isinf(data).reshape(dim) # Points further away than upper bound were marked 'inf'! + for i in tqdm(range(values.shape[-1])): # Set these points to zero (NOTE: This can take a looooong time): + interpolation[..., i].ravel()[mask.ravel()] = 0 + return np.squeeze(interpolation) diff --git a/src/empyre/utils/quaternion.py b/src/empyre/utils/quaternion.py new file mode 100644 index 0000000..84ebaca --- /dev/null +++ b/src/empyre/utils/quaternion.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 by Forschungszentrum Juelich GmbH +# Author: J. Caron +# +"""This module provides the :class:`~.Quaternion` class which can be used for rotations.""" + + +import logging + +import numpy as np + + +__all__ = ['Quaternion'] + + +class Quaternion(object): + R"""Class representing a rotation expressed by a quaternion. + + A quaternion is a four-dimensional description of a rotation which can also be described by + a rotation vector (`v1`, `v2`, `v3`) and a rotation angle :math:`\theta`. The four components + are calculated to: + + .. math:: + + w = \cos(\theta/2) + x = v_1 \cdot \sin(\theta/2) + y = v_2 \cdot \sin(\theta/2) + z = v_3 \cdot \sin(\theta/2) + + Use the :func:`~.from_axisangle` and :func:`~.to_axisangle` to convert to axis-angle + representation and vice versa. Quaternions can be multiplied by other quaternions, which + results in a new rotation or with a vector, which results in a rotated vector. + + Attributes + ---------- + values : float + The four quaternion values `w`, `x`, `y`, `z`. + + """ + + NORM_TOLERANCE = 1E-6 + + _log = logging.getLogger(__name__ + '.Quaternion') + + @property + def conj(self): + """The conjugate of the quaternion, representing a tilt in opposite direction.""" + w, x, y, z = self.values + return Quaternion((w, -x, -y, -z)) + + @property + def matrix(self): + """The rotation matrix representation of the quaternion.""" + w, x, y, z = self.values + return np.array([[1 - 2 * (y ** 2 + z ** 2), 2 * (x * y - w * z), 2 * (x * z + w * y)], + [2 * (x * y + w * z), 1 - 2 * (x ** 2 + z ** 2), 2 * (y * z - w * x)], + [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x ** 2 + y ** 2)]]) + + def __init__(self, values): + self._log.debug('Calling __init__') + self.values = values + self._normalize() + self._log.debug('Created ' + str(self)) + + def __mul__(self, other): # self * other + self._log.debug('Calling __mul__') + if isinstance(other, Quaternion): # Quaternion multiplication + return self.dot_quat(self, other) + elif len(other) == 3: # vector multiplication (Caution: normalises!) + q_vec = Quaternion((0,) + tuple(other)) + q = self.dot_quat(self.dot_quat(self, q_vec), self.conj) + return q.values[1:] + + def _normalize(self): + self._log.debug('Calling _normalize') + mag2 = np.sum(n ** 2 for n in self.values) + if abs(mag2 - 1.0) > self.NORM_TOLERANCE: + mag = np.sqrt(mag2) + self.values = tuple(n / mag for n in self.values) + + def dot_quat(self, q1, q2): + """Multiply two :class:`~.Quaternion` objects to create a new one (always normalized). + + Parameters + ---------- + q1, q2 : :class:`~.Quaternion` + The quaternion which should be multiplied. + + Returns + ------- + quaternion : :class:`~.Quaternion` + The resulting quaternion. + + """ + self._log.debug('Calling dot_quat') + w1, x1, y1, z1 = q1.values + w2, x2, y2, z2 = q2.values + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + return Quaternion((w, x, y, z)) + + @classmethod + def from_axisangle(cls, vector, theta): + """Create a quaternion from an axis-angle representation + + Parameters + ---------- + vector : :class:`~numpy.ndarray` (N=3) + Vector around which the rotation is executed. + theta : float + Rotation angle. + + Returns + ------- + quaternion : :class:`~.Quaternion` + The resulting quaternion. + + """ + cls._log.debug('Calling from_axisangle') + x, y, z = vector + theta /= 2. + w = np.cos(theta) + x *= np.sin(theta) + y *= np.sin(theta) + z *= np.sin(theta) + return cls((w, x, y, z)) + + def to_axisangle(self): + """Convert the quaternion to axis-angle-representation. + + Returns + ------- + vector, theta : :class:`~numpy.ndarray` (N=3), float + Vector around which the rotation is executed and rotation angle. + + """ + self._log.debug('Calling to_axisangle') + w, x, y, z = self.values + theta = 2.0 * np.arccos(w) + return np.array((x, y, z)), theta diff --git a/src/empyre/vis/plot2d.py b/src/empyre/vis/plot2d.py index 2cbfa8e..c86ed0f 100644 --- a/src/empyre/vis/plot2d.py +++ b/src/empyre/vis/plot2d.py @@ -2,7 +2,7 @@ # Copyright 2019 by Forschungszentrum Juelich GmbH # Author: J. Caron # -"""This module provides a functions for 2D plots that often wrap functions from `maptlotlib.pyplot`.""" +"""This module provides functions for 2D plots that often wrap functions from `maptlotlib.pyplot`.""" import logging -- GitLab