Skip to content
Snippets Groups Projects
Commit 549254a6 authored by Jan Caron's avatar Jan Caron
Browse files

Added utils subpackage and much more!

utils: Has quaternion and misc submodules
    misc has levi_civita and interp_to_regular_grid (used later for io)
field: Added HyperSpy support (get/set_signal)
    added copy/rotate/rot90/set_vector/get_vector methods
parent 0db0c54b
No related branches found
No related tags found
No related merge requests found
......@@ -27,7 +27,7 @@ EMPyRe is available on the `Python Package Index <http://pypi.python.org/pypi>`_
Per default, only the strictly required libraries are installed, but there are a few additional dependencies that will unlock additional capabilites of EMPYRE.
* ``io`` will install the `HyperSpy <https://hyperspy.org/>`_ package that is used for loading and saving additional file formats.
* ``io`` will install the `tvtk <http://docs.enthought.com/mayavi/tvtk/README.html>`_ & `HyperSpy <https://hyperspy.org/>`_ packages that are used for loading and saving additional file formats.
.. warning::
Due to this `issue <https://github.com/hyperspy/hyperspy/issues/2315>`_, a pip install of hyperspy is currently not possible. Please use
......@@ -51,17 +51,19 @@ You can choose these settings by using, *e.g.*:
Structure
---------
EMPyRe has several dedicated submodules which are fully documented `here <https://empyre.iffgit.fz-juelich.de/empyre/>`_!
EMPyRe has several dedicated modules which are fully documented `here <https://empyre.iffgit.fz-juelich.de/empyre/>`_!
* The ``fields`` submodule provides the ``Field`` container class for multidimensional scalar or vector fields and is the fundamental data structure used in EMPyRe.
* The ``fields`` module provides the ``Field`` container class for multidimensional scalar or vector fields and is the fundamental data structure used in EMPyRe.
* The ``vis`` submodule enables the plotting of ``Field`` objects, based on and similar in syntax to the commonly known `matplotlib <https://matplotlib.org/>`_ framework.
* The ``vis`` module enables the plotting of ``Field`` objects, based on and similar in syntax to the commonly known `matplotlib <https://matplotlib.org/>`_ framework.
* The ``models`` submodule provides tools for constructing forward models that describe processes in Electron Microscopy.
* The ``models`` module provides tools for constructing forward models that describe processes in Electron Microscopy.
* The ``reconstruct`` submodule is a collection of tools for solving the inverse problems corresponding to the constructed forward models and diagnostic tools for their assessment.
* The ``reconstruct`` module is a collection of tools for solving the inverse problems corresponding to the constructed forward models and diagnostic tools for their assessment.
* The ``io`` submodule is used to load and save ``Field`` objects and the models generated by the ``models`` subpackage.
* The ``io`` module is used to load and save ``Field`` objects and the models generated by the ``models`` subpackage.
* The ``utils`` module, which houses utility functionality used throughout EMPyRe.
......
......@@ -4,24 +4,24 @@ The Field container class
General empyre.fields docu here!
The field submodule
-------------------
The field module
----------------
.. automodule:: empyre.fields.field
:members:
:show-inheritance:
The shapes submodule
--------------------
The shapes module
-----------------
.. automodule:: empyre.fields.shapes
:members:
:show-inheritance:
The vectors submodule
---------------------
The vectors module
------------------
.. automodule:: empyre.fields.vectors
:members:
......
......@@ -4,32 +4,32 @@ The vis visualization submodule
General empyre.vis docu here!
The plot2d submodule
--------------------
The plot2d module
-----------------
.. automodule:: empyre.vis.plot2d
:members:
:show-inheritance:
The decorators submodule
------------------------
The decorators module
---------------------
.. automodule:: empyre.vis.decorators
:members:
:show-inheritance:
The colors submodule
--------------------
The colors module
-----------------
.. automodule:: empyre.vis.colors
:members:
:show-inheritance:
The tools submodule
-------------------
The tools module
----------------
.. automodule:: empyre.vis.tools
:members:
......
......@@ -40,12 +40,13 @@ dependencies:
# Documentation:
- sphinx=2.4
- numpydoc=0.9
- sphinx_rtd_theme=0.4 # TODO: not needed?
- sphinx_rtd_theme=0.4
# IPython and notebooks:
- ipython=7.7
- jupyter=1.0
- nb_conda=2.2
#- ptvsd=4.3 # Cell debugging in VS Code
- rope=0.16 # for refactoring in VS Code
# TODO: - ipywidgets
# TODO: Add back GUI dependencies!
# TODO: Get Jutil from gitlab (currently doesn't work, git and cygwin don't play nice,...
......
......@@ -56,6 +56,7 @@ where = src
[options.extras_require]
io =
hyperspy
tvtk
fftw =
pyfftw
colors =
......
......@@ -4,11 +4,13 @@
#
"""Subpackage containing functionality for visualisation of multidimensional fields."""
from . import fields
from . import io
from . import models
from . import reconstruct
from . import vis
from . import utils
from .version import version as __version__
from .version import git_revision as __git_revision__
......@@ -18,7 +20,7 @@ _log.info(f'Imported EMPyRe V-{__version__} GIT-{__git_revision__}')
del logging
__all__ = ['fields', 'io', 'models', 'reconstruct', 'vis']
__all__ = ['fields', 'io', 'models', 'reconstruct', 'vis', 'utils']
del version
......@@ -4,6 +4,7 @@
#
"""Subpackage containing container classes for multidimensional fields and ways to create them."""
from .field import *
from .shapes import *
from .vectors import *
......
......@@ -13,6 +13,8 @@ import numpy as np
from numpy.core import numeric
from scipy.ndimage import interpolation
from ..utils import Quaternion
__all__ = ['Field']
......@@ -95,7 +97,8 @@ class Field(NDArrayOperatorsMixin):
if isinstance(scale, Number): # Scale is the same for each dimension!
self.__scale = (scale,) * len(self.dim)
elif isinstance(scale, tuple):
assert len(scale) == len(self.dim), f'Each dimension {self.dim} needs a scale, but {scale} was given!'
ndim = len(self.dim)
assert len(scale) == ndim, f'Each of the {ndim} dimensions needs a scale, but {scale} was given!'
self.__scale = scale
else:
raise AssertionError('Scaling has to be a number or a tuple of numbers!')
......@@ -288,6 +291,256 @@ class Field(NDArrayOperatorsMixin):
assert all(item.scale == scalar_list[0].scale for item in scalar_list), 'Scales of fields must match!'
return cls(np.stack(scalar_list, axis=-1), scalar_list[0].scale, vector=True)
@classmethod
def from_signal(cls, signal, scale=None, vector=False):
"""Convert a :class:`~hyperspy.signals.Signal` object to a :class:`~.Field` object.
Parameters
----------
signal: :class:`~hyperspy.signals.Signal`
The :class:`~hyperspy.signals.Signal` object which should be converted to Field.
Returns
-------
field: :class:`~.Field`
A :class:`~.Field` object containing the loaded data.
scale: tuple of float, optional
Scaling along the dimensions of the underlying data. If not provided, will be read from the axes_manager.
vector: boolean, optional
If set to True, forces the signal to be interpreted as a vector field. EMPyRe will check if the first axis
is named 'vector components' (EMPyRe saves vector fields like this). If this is the case, vector will be
automatically set to True and the signal will also be interpreted as a vector field.
Notes
-----
Signals recquire the hyperspy package!
"""
cls._log.debug('Calling from_signal')
data = signal.data
if signal.axes_manager[0].name == 'vector components':
vector = True # Automatic detection!
if scale is None: # If not provided, try to read from axes_manager:
scale = [signal.axes_manager[i].scale for i in range(len(data.shape) - vector)] # One less axis if vector!
return cls(data, scale, vector)
def to_signal(self):
"""Convert :class:`~.Field` data into a HyperSpy signal.
Returns
-------
signal: :class:`~hyperspy.signals.Signal`
Representation of the :class:`~.Field` object as a HyperSpy Signal.
Notes
-----
This method recquires the hyperspy package!
"""
self._log.debug('Calling to_signal')
try: # Try importing HyperSpy:
import hyperspy.api as hs
except ImportError:
self._log.error('This method recquires the hyperspy package!')
return
# Create signal:
signal = hs.signals.BaseSignal(self.data) # All axes are signal axes!
# Set axes:
if self.vector:
signal.axes_manager[0].name = 'vector components'
for i in range(len(self.dim)):
ax = i+1 if self.vector else i # take component axis into account if vector!
num = ['x', 'y', 'z'][i] if len(self.dim) <= 3 else i
signal.axes_manager[ax].name = f'axis {num}'
signal.axes_manager[ax].scale = self.scale[i]
signal.axes_manager[ax].units = 'nm'
signal.metadata.Signal.title = f"EMPyRe {'vector' if self.vector else 'scalar'} Field"
return signal
def copy(self):
"""Returns a copy of the :class:`~.Field` object.
Returns
-------
field: :class:`~.Field`
A copy of the :class:`~.Field`.
"""
self._log.debug('Calling copy')
return Field(self.data.copy(), self.scale, self.vector)
def rotate(self, angle, axis='z', **kwargs):
"""Rotate the :class:`~.Field`, based on :function:`~scipy.ndimage.interpolation.rotate`.
Rotation direction is from the first towards the second axis. Works for 2D and 3D Fields.
Parameters
----------
angle : float
The rotation angle in degrees.
axis: {'x', 'y', 'z'}, optional
The axis around which the vector field is rotated. Default is 'z'. Ignored for 2D Fields.
Returns
-------
field: :class:`~.Field`
The rotated :class:`~.Field`.
Notes
-----
All additional kwargs are passed through to :function:`~scipy.ndimage.interpolation.rotate`.
The `reshape` parameter, controlling if the output shape is adapted so that the input array is contained
completely in the output is False per default, contrary to :function:`~scipy.ndimage.interpolation.rotate`,
where it is True.
"""
self._log.debug('Calling rotate')
assert len(self.dim) in (2, 3), 'rotate is currently only defined for 2D and 3D Fields!'
kwargs.setdefault('reshape', False) # Default here is no reshaping!
if len(self.dim) == 2: # For 2D, there are only 2 axes:
axis = 'z' # Overwrite potential argument if 2D!
axes = (0, 1) # y and x
else: # 3D case:
axes = {'x': (0, 1), 'y': (0, 2), 'z': (1, 2)}[axis]
if axis == 'z': # TODO: Somehow needed... don't know why, maybe because origin='lower' in imshow?
angle *= -1
sc_0, sc_1 = self.scale[axes[0]], self.scale[axes[1]]
assert sc_0 == sc_1, f'rotate needs the scales in the rotation plane to be equal (they are {sc_0} & {sc_1})!'
if not self.vector: # Scalar field:
data_new = interpolation.rotate(self.data, angle, axes=axes, **kwargs)
else: # Vector field:
# Rotate coordinate system:
comps = [np.asarray(comp) for comp in self.comp]
if self.ncomp == 3:
data_new = np.stack([interpolation.rotate(c, angle, axes=axes, **kwargs) for c in comps], axis=-1)
# Up till now, only the coordinates are rotated, now we need to rotate the vectors inside the voxels:
rot_axis = [i for i in (0, 1, 2) if i not in axes][0] # [0] because list!
i, j, k = axes[0], axes[1], rot_axis # next line only works if i != j != k
levi_civita = int((j-i) * (k-i) * (k-j) / (np.abs(j-i) * np.abs(k-i) * np.abs(k-j)))
# Create Quaternion, note that they have (x, y, z) order instead of (z, y, x):
# TODO: Again a minus sign needed before levi_civita, still not sure why (see angle above)!
quat_axis = tuple([-levi_civita if i == rot_axis else 0 for i in (2, 1, 0)])
quat = Quaternion.from_axisangle(quat_axis, np.deg2rad(angle))
data_new = quat.matrix.dot(data_new.reshape((-1, 3)).T).T.reshape(self.shape) # T needed b.c. ordering!
elif self.ncomp == 2:
u_comp, v_comp = comps
u_rot = interpolation.rotate(u_comp, angle, axes=axes, **kwargs)
v_rot = interpolation.rotate(v_comp, angle, axes=axes, **kwargs)
# Up till now, only the coordinates are rotated, now we need to rotate the vectors inside the voxels:
# TODO: Again a minus sign needed for vector rotation, still not sure why (see angle above)!
ang_rad = np.deg2rad(-angle)
u_mix = np.cos(ang_rad)*u_rot - np.sin(ang_rad)*v_rot
v_mix = np.sin(ang_rad)*u_rot + np.cos(ang_rad)*v_rot
data_new = np.stack((u_mix, v_mix), axis=-1)
else:
raise ValueError('rotate currently only works for 2- or 3-component vector fields!')
return Field(data_new, self.scale, self.vector)
def rot90(self, k=1, axis='z'):
"""Rotate the :class:`~.Field` 90° around the specified axis (right hand rotation).
Parameters
----------
k : integer
Number of times the array is rotated by 90 degrees.
axis: {'x', 'y', 'z'}, optional
The axis around which the vector field is rotated. Default is 'z'. Ignored for 2D Fields.
Returns
-------
field: :class:`~.Field`
The rotated :class:`~.Field`.
"""
self._log.debug('Calling rot90')
assert axis in ('x', 'y', 'z'), "Wrong input! 'x', 'y', 'z' allowed!"
assert len(self.dim) in (2, 3), 'rotate is currently only defined for 2D and 3D Fields!'
if len(self.dim) == 2: # For 2D, there are only 2 axes:
axis = 'z' # Overwrite potential argument if 2D!
axes = (0, 1) # y and x
else: # 3D case:
axes = {'x': (0, 1), 'y': (0, 2), 'z': (1, 2)}[axis]
sc_0, sc_1 = self.scale[axes[0]], self.scale[axes[1]]
assert sc_0 == sc_1, f'rot90 needs the scales in the rotation plane to be equal (they are {sc_0} & {sc_1})!'
# TODO: Later on, rotation could also flip the scale (not implemented here, yet)!
if not self.vector: # Scalar Field:
data_new = np.rot90(self.data, k=k, axes=axes)
else: # Vector Field:
if len(self.dim) == 3: # 3D:
assert self.ncomp == 3, 'rot90 currently only works for vector fields with 3 components in 3D!'
comp_x, comp_y, comp_z = self.comp
if axis == 'x': # RotMatrix for 90°: [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
comp_x_rot = np.rot90(comp_x, k=k, axes=axes)
comp_y_rot = np.rot90(comp_z, k=k, axes=axes)
comp_z_rot = -np.rot90(comp_y, k=k, axes=axes)
elif axis == 'y': # RotMatrix for 90°: [[0, 0, 1], [0, 1, 0], [-1, 0, 0]]
comp_x_rot = np.rot90(comp_z, k=k, axes=axes)
comp_y_rot = np.rot90(comp_y, k=k, axes=axes)
comp_z_rot = -np.rot90(comp_x, k=k, axes=axes)
elif axis == 'z': # RotMatrix for 90°: [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
comp_x_rot = np.rot90(comp_y, k=k, axes=axes)
comp_y_rot = np.rot90(comp_x, k=k, axes=axes)
comp_z_rot = np.rot90(comp_z, k=k, axes=axes)
data_new = np.stack((comp_x_rot, comp_y_rot, comp_z_rot), axis=-1)
if len(self.dim) == 2: # 2D:
assert self.ncomp == 2, 'rot90 currently only works for vector fields with 2 components in 2D!'
comp_x, comp_y = self.comp
comp_x_rot = np.rot90(comp_y, k=k, axes=axes)
comp_y_rot = np.rot90(comp_x, k=k, axes=axes)
data_new = np.stack((comp_x_rot, comp_y_rot), axis=-1)
# Return result:
return Field(data_new, self.scale, self.vector)
def get_vector(self, mask=None):
"""Returns the field as a vector, specified by a mask.
Parameters
----------
mask : :class:`~numpy.ndarray` (boolean)
Masks the pixels from which the entries should be taken. Must be a numpy array for indexing to work!
Returns
-------
vector : :class:`~numpy.ndarray` (N=1)
The vector containing the field of the specified pixels. If the field is a vector field, components are
first vectorised, then concatenated!
"""
self._log.debug('Calling get_vector')
if mask is None: # If not given, take full data:
mask = np.full(self.dim, True)
if self.vector: # Vector field:
return np.ravel([comp.data[mask] for comp in self.comp])
else: # Scalar field:
return np.ravel(self.data[mask])
def set_vector(self, vector, mask=None):
"""Set the field of the masked pixels to the values specified by `vector`.
Parameters
----------
mask : :class:`~numpy.ndarray` (boolean), optional
Masks the pixels from which the field should be taken.
vector : :class:`~numpy.ndarray` (N=1)
The vector containing the field of the specified pixels.
Returns
-------
None
"""
self._log.debug('Calling set_vector')
if mask is None: # If not given, set full data:
mask = np.full(self.dim, True)
if self.vector: # Vector field:
assert np.size(vector) % self.ncomp == 0, 'Vector has to contain all components for every pixel!'
count = np.size(vector) // self.ncomp
for i in range(self.ncomp):
sl = slice(i*count, (i+1)*count)
self.data[..., i][mask] = vector[sl]
else: # Scalar field:
self.data[mask] = vector
def squeeze(self):
"""Squeeze the `Field` object to remove single-dimensional entries in the shape.
......
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Subpackage containing utility functionality."""
from .quaternion import *
__all__ = []
__all__.extend(quaternion.__all__)
del quaternion
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides the miscellaneous helper functions."""
import logging
import itertools
from time import time
import numpy as np
from tqdm import tqdm
from scipy.spatial import cKDTree, qhull
from scipy.interpolate import LinearNDInterpolator
__all__ = ['levi_civita', 'interp_to_regular_grid']
_log = logging.getLogger(__name__)
def levi_civita(i, j, k):
_log.debug('Calling levi_civita')
return (j-i) * (k-i) * (k-j) / (np.abs(j-i) * np.abs(k-i) * np.abs(k-j))
def interp_to_regular_grid(points, values, scale, scale_factor=1, step=1, convex=True):
"""Interpolate values on points to regular grid.
Parameters
----------
points : np.ndarray, (N, 3)
Array of points, describing the location of the values that should be interpolated. Three columns x, y, z!
values : np.ndarray, (N, c)
Array of values that should be interpolated to the new grid. `c` is the number of components (`1` for scalar
fields, `3` for normal 3D vector fields).
scale : tuple of 3 ints
Scale along each of the 3 spatial dimensions. Usually given in nm.
scale_factor : float, optional
Additional scaling factor that should be used if the original points are not described on a nm-scale. Use this
to convert from nm of `scale` to the unit of the points. By default 1.
step : int, optional
If this is bigger than 1 (the default), only every `step` point is taken into account. Can speed up calculation,
but you'll lose accuracy of your interpolation.
convex : bool, optional
By default True. If this is set to False, additional measures are taken to find holes in the point cloud.
WARNING: this is an experimental feature that should be used with caution!
Returns
-------
interpolation: np.ndarray
Interpolated grid with shape `(zdim, ydim, xdim)` for scalar and `(zdim, ydim, xdim, ncomp)` for vector field.
"""
_log.debug('Calling interpolate_to_regular_grid')
z_uniq = np.unique(points[:, 2])
_log.info(f'unique positions along z: {len(z_uniq)}')
# Translate scale in local units (not necessarily nm), taken care of with `scale_factor`:
scale = tuple([s * scale_factor for s in scale])
# Determine the size of the point cloud of irregular coordinates:
x_min, x_max = points[:, 0].min(), points[:, 0].max()
y_min, y_max = points[:, 1].min(), points[:, 1].max()
z_min, z_max = points[:, 2].min(), points[:, 2].max()
x_diff, y_diff, z_diff = np.ptp(points[:, 0]), np.ptp(points[:, 1]), np.ptp(points[:, 2])
_log.info(f'x-range: {x_min:.2g} <-> {x_max:.2g} ({x_diff:.2g})')
_log.info(f'y-range: {y_min:.2g} <-> {y_max:.2g} ({y_diff:.2g})')
_log.info(f'z-range: {z_min:.2g} <-> {z_max:.2g} ({z_diff:.2g})')
# Determine dimensions from given grid spacing a:
dim = tuple(np.round(np.asarray((z_diff/scale, y_diff/scale, x_diff/scale), dtype=int)))
x = x_min + scale * (np.arange(dim[2]) + 0.5) # +0.5: shift to pixel center!
y = y_min + scale * (np.arange(dim[1]) + 0.5) # +0.5: shift to pixel center!
z = z_min + scale * (np.arange(dim[0]) + 0.5) # +0.5: shift to pixel center!
# Create points for new Euclidian grid; fliplr for (x, y, z) order:
points_euc = np.fliplr(np.asarray(list(itertools.product(z, y, x))))
# Make values 2D (if not already); double .T so that a new axis is added at the END (n, 1):
values = np.atleast_2d(values.T).T
# Prepare interpolated grid:
interpolation = np.empty(dim+(values.shape[-1],), dtype=np.float)
_log.info(f'Dimensions of new grid: {interpolation.shape}')
# Calculate the Delaunay triangulation (same for every component of multidim./vector fields):
_log.info('Start Delaunay triangulation...')
tick = time()
triangulation = qhull.Delaunay(points[::step])
tock = time()
_log.info(f'Delaunay triangulation complete (took {tock-tick:.2f} s)!')
# Perform the interpolation for each column of `values`:
for i in tqdm(range(values.shape[-1])):
# Create interpolator for the given triangulation and the values of the current column:
interpolator = LinearNDInterpolator(triangulation, values[::step, i], fill_value=0)
# Interpolate:
interpolation[..., i] = interpolator(points_euc).reshape(dim)
# If NOT convex, we have to check for additional holes in the structure (EXPERIMENTAL):
if not convex: # Only necessary if the user expects holes in the (-> nonconvex) distribution:
# Create k-dimensional tree for queries:
tree = cKDTree(points)
# Query the tree for nearest neighbors, x: points to query, k: number of neighbors, p: norm
# to use (here: 2 - Euclidean), distance_upper_bound: maximum distance that is searched!
data, leafsize = tree.query(x=points_euc, k=1, p=2, distance_upper_bound=2*np.mean(scale))
# Create boolean mask that determines which interpolation points have no neighbor near enough:
mask = np.isinf(data).reshape(dim) # Points further away than upper bound were marked 'inf'!
for i in tqdm(range(values.shape[-1])): # Set these points to zero (NOTE: This can take a looooong time):
interpolation[..., i].ravel()[mask.ravel()] = 0
return np.squeeze(interpolation)
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides the :class:`~.Quaternion` class which can be used for rotations."""
import logging
import numpy as np
__all__ = ['Quaternion']
class Quaternion(object):
R"""Class representing a rotation expressed by a quaternion.
A quaternion is a four-dimensional description of a rotation which can also be described by
a rotation vector (`v1`, `v2`, `v3`) and a rotation angle :math:`\theta`. The four components
are calculated to:
.. math::
w = \cos(\theta/2)
x = v_1 \cdot \sin(\theta/2)
y = v_2 \cdot \sin(\theta/2)
z = v_3 \cdot \sin(\theta/2)
Use the :func:`~.from_axisangle` and :func:`~.to_axisangle` to convert to axis-angle
representation and vice versa. Quaternions can be multiplied by other quaternions, which
results in a new rotation or with a vector, which results in a rotated vector.
Attributes
----------
values : float
The four quaternion values `w`, `x`, `y`, `z`.
"""
NORM_TOLERANCE = 1E-6
_log = logging.getLogger(__name__ + '.Quaternion')
@property
def conj(self):
"""The conjugate of the quaternion, representing a tilt in opposite direction."""
w, x, y, z = self.values
return Quaternion((w, -x, -y, -z))
@property
def matrix(self):
"""The rotation matrix representation of the quaternion."""
w, x, y, z = self.values
return np.array([[1 - 2 * (y ** 2 + z ** 2), 2 * (x * y - w * z), 2 * (x * z + w * y)],
[2 * (x * y + w * z), 1 - 2 * (x ** 2 + z ** 2), 2 * (y * z - w * x)],
[2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x ** 2 + y ** 2)]])
def __init__(self, values):
self._log.debug('Calling __init__')
self.values = values
self._normalize()
self._log.debug('Created ' + str(self))
def __mul__(self, other): # self * other
self._log.debug('Calling __mul__')
if isinstance(other, Quaternion): # Quaternion multiplication
return self.dot_quat(self, other)
elif len(other) == 3: # vector multiplication (Caution: normalises!)
q_vec = Quaternion((0,) + tuple(other))
q = self.dot_quat(self.dot_quat(self, q_vec), self.conj)
return q.values[1:]
def _normalize(self):
self._log.debug('Calling _normalize')
mag2 = np.sum(n ** 2 for n in self.values)
if abs(mag2 - 1.0) > self.NORM_TOLERANCE:
mag = np.sqrt(mag2)
self.values = tuple(n / mag for n in self.values)
def dot_quat(self, q1, q2):
"""Multiply two :class:`~.Quaternion` objects to create a new one (always normalized).
Parameters
----------
q1, q2 : :class:`~.Quaternion`
The quaternion which should be multiplied.
Returns
-------
quaternion : :class:`~.Quaternion`
The resulting quaternion.
"""
self._log.debug('Calling dot_quat')
w1, x1, y1, z1 = q1.values
w2, x2, y2, z2 = q2.values
w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
return Quaternion((w, x, y, z))
@classmethod
def from_axisangle(cls, vector, theta):
"""Create a quaternion from an axis-angle representation
Parameters
----------
vector : :class:`~numpy.ndarray` (N=3)
Vector around which the rotation is executed.
theta : float
Rotation angle.
Returns
-------
quaternion : :class:`~.Quaternion`
The resulting quaternion.
"""
cls._log.debug('Calling from_axisangle')
x, y, z = vector
theta /= 2.
w = np.cos(theta)
x *= np.sin(theta)
y *= np.sin(theta)
z *= np.sin(theta)
return cls((w, x, y, z))
def to_axisangle(self):
"""Convert the quaternion to axis-angle-representation.
Returns
-------
vector, theta : :class:`~numpy.ndarray` (N=3), float
Vector around which the rotation is executed and rotation angle.
"""
self._log.debug('Calling to_axisangle')
w, x, y, z = self.values
theta = 2.0 * np.arccos(w)
return np.array((x, y, z)), theta
......@@ -2,7 +2,7 @@
# Copyright 2019 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides a functions for 2D plots that often wrap functions from `maptlotlib.pyplot`."""
"""This module provides functions for 2D plots that often wrap functions from `maptlotlib.pyplot`."""
import logging
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment