Commit bda13a16 authored by Jan Caron

fielddata: Lots of stuff not in place anymore! Added rotate function!

projector: RotTiltProjector now ends in the correct coord system!
diagnostics: Refined the LCurve class (now saves stuff!)
lorentz: Added preliminary Lorentz imaging functions!
parent a6ac38e2
Pipeline #21851 failed
@@ -9,6 +9,8 @@ import os
import logging
import pickle
from pyramid.forwardmodel import ForwardModel
from pyramid.costfunction import Costfunction
from pyramid.regularisator import FirstOrderRegularisator
@@ -26,6 +28,14 @@ import numpy as np
import jutil
try:
if type(get_ipython()).__name__ == 'ZMQInteractiveShell':  # IPython Notebook!
from tqdm import tqdm_notebook as tqdm
else:  # IPython, but not a Notebook (e.g. terminal)
from tqdm import tqdm
except NameError:
from tqdm import tqdm
__all__ = ['Diagnostics', 'LCurve', 'get_vector_field_errors']
# TODO: should be subpackage, distribute methods and classes to separate modules!
@@ -143,7 +153,7 @@ class Diagnostics(object):
self._updated_avrg_kern_row = False
self._updated_measure_contribution = False
- def __init__(self, magdata, cost, max_iter=1000, verbose=False):
def __init__(self, magdata, cost, max_iter=1000, verbose=False):  # TODO: verbose True default
self._log.debug('Calling __init__')
self.magdata = magdata
self.cost = cost
@@ -390,7 +400,7 @@ class Diagnostics(object):
artist = axis.add_patch(patches.Ellipse(xy, width, height, fill=False, edgecolor='w',
linewidth=2, alpha=0.5))
artist.set_path_effects([patheffects.withStroke(linewidth=4, foreground='k', alpha=0.5)])
# TODO: Return axis on every plot?
def plot_avrg_kern_field3d(self, pos=None, mask=True, ellipsoid=True, **kwargs):
avrg_kern_field = self.get_avrg_kern_field(pos)
@@ -481,84 +491,101 @@ class LCurve(object):
_log = logging.getLogger(__name__ + '.FieldData')
- def __init__(self, fwd_model, max_iter=0, verbose=False, save_dir=None):
def __init__(self, fwd_model, max_iter=0, verbose=True, save_dir='lcurve'):
self._log.debug('Calling __init__')
assert isinstance(fwd_model, ForwardModel), 'Input has to be a costfunction'
self.fwd_model = fwd_model
self.max_iter = max_iter
self.verbose = verbose
- self.lams = []
- self.chisq_a = []
- self.chisq_m = []
- if save_dir is not None:
- assert os.path.isdir(save_dir), 'save_dir has to be None or a valid directory!'
- self.save_dir = save_dir  # TODO: Use save_dir!!!
self.l_dict = {}
self.save_dir = save_dir
if self.save_dir is not None:
if not os.path.isdir(self.save_dir):  # Create directory if it does not exist:
os.makedirs(self.save_dir)
if os.path.isfile('{}/lcurve.pkl'.format(self.save_dir)):  # Load file if it exists:
self._load()
else:  # Create file:
self._save()
self._log.debug('Created ' + str(self))
- def calculate(self, lam):
- # TODO: Docstring!
- if lam not in self.lams:
- # Create new regularisator:  # TODO: Not hardcoding FirstOrderRegularisator!
- reg = FirstOrderRegularisator(self.fwd_model.data_set.mask, lam,
- add_params=self.fwd_model.ramp.n)
- cost = Costfunction(fwd_model=self.fwd_model, regularisator=reg)
- # Reconstruct:
- magdata_rec = reconstruction.optimize_linear(cost, max_iter=self.max_iter,
- verbose=self.verbose)
- # Save magdata_rec if necessary:
- if self.save_dir is not None:
- filename = 'magdata_rec_lam{:.0e}.hdf5'.format(lam)
- magdata_rec.save(os.path.join(self.save_dir, filename), overwrite=True)
- # Append new values:
- self.lams.append(lam)
- chisq_m, chisq_a = cost.chisq_m[-1], cost.chisq_a[-1]  # TODO: is chisq_m list or not?
- self.chisq_m.append(chisq_m)
- self.chisq_a.append(chisq_a / lam)  # TODO: lambda out of regularisator?
- # Sort lists according to lambdas:
- tuples = zip(*sorted(zip(self.lams, self.chisq_m, self.chisq_a)))
- self.lams, self.chisq_m, self.chisq_a = (list(l) for l in tuples)
- self._log.info(lam, ' --> m:', chisq_m, ' a:', chisq_a / lam)
- return chisq_m, chisq_a / lam
# TODO: Methods for saving and loading l_dict's!!!
def _save(self):
with open('{}/lcurve.pkl'.format(self.save_dir), 'wb') as f:
pickle.dump(self.l_dict, f, pickle.HIGHEST_PROTOCOL)
def _load(self):
with open('{}/lcurve.pkl'.format(self.save_dir), 'rb') as f:
self.l_dict = pickle.load(f)
def calculate(self, lambdas, overwrite=False):
# TODO: Docstring!
lams = np.atleast_1d(lambdas)
for lam in tqdm(lams, disable=not self.verbose):
if lam not in self.l_dict.keys() or overwrite:
# Create new regularisator and costfunction:  # TODO: Not hardcoding FirstOrder!
# TODO: Not necessary if lambda can be extracted from regularisator? self.cost?
reg = FirstOrderRegularisator(self.fwd_model.data_set.mask, lam,
add_params=self.fwd_model.ramp.n)
cost = Costfunction(fwd_model=self.fwd_model, regularisator=reg)
# Reconstruct:
magdata_rec = reconstruction.optimize_linear(cost, max_iter=self.max_iter,
verbose=self.verbose)
# Add new values to dictionary:
chisq_m, chisq_a = cost.chisq_m[-1], cost.chisq_a[-1]  # TODO: chisq_m list or not?
self.l_dict[lam] = (chisq_m, chisq_a)
self._log.info(lam, ' --> m:', chisq_m, ' a:', chisq_a)
# Save magdata_rec and dictionary if necessary:
if self.save_dir is not None:
filename = 'magdata_rec_lam{:.0e}.hdf5'.format(lam)
magdata_rec.save(os.path.join(self.save_dir, filename), overwrite=True)
self._save()
def calculate_auto(self, lam_start=1E-18, lam_end=1E5, online_axis=False):
raise NotImplementedError()
# TODO: Docstring!
# TODO: IMPLEMENT!!!
# # Calculate new cost terms:
# log_m_s, log_a_s = np.log10(self.calculate(lam_start))
# log_m_e, log_a_e = np.log10(self.calculate(lam_end))
# # Calculate new lambda:
# log_lam_s, log_lam_e = np.log10(lam_start), np.log10(lam_end)
# log_lam_new = np.mean((log_lam_s, log_lam_e))  # logarithmic mean to find middle on L!
# sign_exp = np.floor(log_lam_new)
# last_sign_digit = np.round(10 ** (log_lam_new - sign_exp))
# lam_new = last_sign_digit * 10 ** sign_exp
# # Calculate cost terms for new lambda:
# log_m_new, log_a_new = np.log10(self.calculate(lam_new))
# if online_axis:  # Update plot if necessary:
# self.plot(axis=online_axis)
# from IPython import display
# display.clear_output(wait=True)
# display.display(plt.gcf())
# # Calculate distances from origin and find new interval:
# dist_s, dist_e = np.hypot(log_m_s, log_a_s), np.hypot(log_m_e, log_a_e)
# dist_new = np.hypot(log_m_new, log_a_new)
# print(lam_start, lam_end, lam_new)
# print(dist_s, dist_e, dist_new)
# # if dist_new
# TODO: slope has to be normalised, scale of axes is not equal!!!
# TODO: get rid of right flank (do Not use right points with slope steeper than -45°
# TODO: Implement else, return saved values!
# TODO: Make this work with batch, sort lambdas at the end!
# TODO: After sorting, calculate the CURVATURE for each lambda! (needs 3 points?)
- # TODO: Use finite difference methods (forward/backward/central, depending on location)!
# TODO: Use finite difference methods (forward/backward/central, depends on location)!
# TODO: Investigate points around highest curvature further.
# TODO: Make sure to update ALL curvatures and search for new best EVERYWHERE!
# TODO: Distinguish regions of the L-Curve.
- def calculate_auto(self, lam_start=1E-18, lam_end=1E5, online_axis=False):
- # TODO: Docstring!
- # Calculate new cost terms:
- log_m_s, log_a_s = np.log10(self.calculate(lam_start))
- log_m_e, log_a_e = np.log10(self.calculate(lam_end))
- # Calculate new lambda:
- log_lam_s, log_lam_e = np.log10(lam_start), np.log10(lam_end)
- log_lam_new = np.mean((log_lam_s, log_lam_e))  # logarithmic mean to find middle on L!
- sign_exp = np.floor(log_lam_new)
- last_sign_digit = np.round(10 ** (log_lam_new - sign_exp))
- lam_new = last_sign_digit * 10 ** sign_exp
- # Calculate cost terms for new lambda:
- log_m_new, log_a_new = np.log10(self.calculate(lam_new))
- if online_axis:  # Update plot if necessary:
- self.plot(axis=online_axis)
- from IPython import display
- display.clear_output(wait=True)
- display.display(plt.gcf())
- # TODO: slope has to be normalised, scale of axes is not equal!!!
- # TODO: get rid of right flank (do Not use right points with slope steeper than -45°
- # Calculate distances from origin and find new interval:
- dist_s, dist_e = np.hypot(log_m_s, log_a_s), np.hypot(log_m_e, log_a_e)
- dist_new = np.hypot(log_m_new, log_a_new)
- print(lam_start, lam_end, lam_new)
- print(dist_s, dist_e, dist_new)
- # if dist_new
- def plot(self, axis=None, figsize=None):
def plot(self, lambdas=None, axis=None, figsize=None):
# TODO: Docstring!
# Sort lists according to lambdas:
if lambdas is None:
lambdas = sorted(self.l_dict.keys())
x, y = [], []
for lam in lambdas:
x.append(self.l_dict[lam][0])
y.append(self.l_dict[lam][1] / lam)
if figsize is None:
figsize = plottools.FIGSIZE_DEFAULT
if axis is None:
@@ -567,23 +594,20 @@ class LCurve(object):
axis = fig.add_subplot(1, 1, 1)
axis.set_yscale("log", nonposx='clip')
axis.set_xscale("log", nonposx='clip')
- axis.plot(self.chisq_m, self.chisq_a, 'k-', linewidth=3, zorder=1)
- sc = axis.scatter(x=self.chisq_m, y=self.chisq_a, c=self.lams, marker='o', s=100,
- zorder=2, cmap='nipy_spectral', norm=LogNorm())
axis.plot(x, y, 'k-', linewidth=3, zorder=1)
sc = axis.scatter(x, y, c=lambdas, marker='o', s=100, zorder=2,
cmap='nipy_spectral', norm=LogNorm())
plt.colorbar(mappable=sc, label='regularisation parameter $\lambda$')
- #plottools.add_cbar(axis, mappable=sc, label='regularisation parameter $\lambda$')
- #axis.get_xaxis().get_major_formatter().labelOnlyBase = False
axis.set_xlabel(
r'$\Vert\mathbf{F}(\mathbf{x})-\mathbf{y}\Vert_{\mathbf{S}_{\epsilon}^{-1}}^{2}$',
fontsize=22, labelpad=-5)
axis.set_ylabel(r'$\frac{1}{\lambda}\Vert\mathbf{x}\Vert_{\mathbf{S}_{a}^{-1}}^{2}$',
fontsize=22)
- #axis.set_xlim(3E3, 2E5)
- #axis.set_ylim(1E2, 1E9)
axis.xaxis.label.set_color('firebrick')
axis.yaxis.label.set_color('seagreen')
axis.tick_params(axis='both', which='major')
axis.grid()
return axis
# TODO: Don't plot the steep part on the right...
......
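As orientation for the reworked LCurve class, here is a minimal usage sketch. Only the constructor and method signatures are taken from the diff above; the module path, the forward-model setup and the lambda range are assumptions.

import numpy as np
from pyramid.diagnostics import LCurve  # module path assumed

# fwd_model: an already configured pyramid ForwardModel instance (set up elsewhere).
lcurve = LCurve(fwd_model, max_iter=100, verbose=True, save_dir='lcurve')
lcurve.calculate(np.logspace(-6, 2, 9))  # reconstruct for a batch of regularisation parameters
# Results are cached in lcurve.l_dict and pickled to 'lcurve/lcurve.pkl' after every lambda;
# already computed lambdas are skipped unless overwrite=True is passed.
axis = lcurve.plot()  # log-log L-curve, colour-coded by lambda (plot now returns the axis)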
@@ -8,6 +8,7 @@ import abc
import logging
import os
import tempfile
from scipy.ndimage.interpolation import rotate
from numbers import Number
import numpy as np
@@ -21,6 +22,7 @@ import cmocean
from . import colors
from . import plottools
from .quaternion import Quaternion
__all__ = ['VectorData', 'ScalarData']
@@ -323,7 +325,8 @@ class FieldData(object, metaclass=abc.ABCMeta):
Only possible, if each axis length is a power of 2!
"""
- pass
pass  # TODO: NotImplementedError instead? See that all classes have the same interface!
# TODO: This means that all common functions of Scalar and Vector have to be abstract here!
@abc.abstractmethod
def scale_up(self, n, order):
@@ -467,6 +470,56 @@ class VectorData(FieldData):
def __getitem__(self, item):
return self.__class__(self.a, self.field[item])
def get_vector(self, mask):
"""Returns the vector field components arranged in a vector, specified by a mask.
Parameters
----------
mask : :class:`~numpy.ndarray` (N=3, boolean)
Masks the pixels from which the components should be taken.
Returns
-------
vector : :class:`~numpy.ndarray` (N=1)
The vector containing vector field components of the specified pixels.
Order is: first all `x`-, then all `y`-, then all `z`-components.
"""
self._log.debug('Calling get_vector')
if mask is not None:
return np.reshape([self.field[0][mask],
self.field[1][mask],
self.field[2][mask]], -1)
else:
return self.field_vec
def set_vector(self, vector, mask=None):
"""Set the field components of the masked pixels to the values specified by `vector`.
Parameters
----------
mask : :class:`~numpy.ndarray` (N=3, boolean), optional
Masks the pixels from which the components should be taken.
vector : :class:`~numpy.ndarray` (N=1)
The vector containing vector field components of the specified pixels.
Order is: first all `x`-, then all `y-, then all `z`-components.
Returns
-------
None
"""
self._log.debug('Calling set_vector')
assert np.size(vector) % 3 == 0, 'Vector has to contain all 3 components for every pixel!'
count = np.size(vector) // 3
if mask is not None:
self.field[0][mask] = vector[:count] # x-component
self.field[1][mask] = vector[count:2 * count] # y-component
self.field[2][mask] = vector[2 * count:] # z-component
else:
self.field_vec = vector
# TODO: scale_down and scale_up should not work in place (also in ScalarData)!
def scale_down(self, n=1):
"""Scale down the field distribution by averaging over two pixels along each axis.
@@ -487,16 +540,18 @@ class VectorData(FieldData):
"""
self._log.debug('Calling scale_down')
assert n > 0 and isinstance(n, int), 'n must be a positive integer!'
- self.a *= 2 ** n
a_new = self.a * 2 ** n
field_new = self.field
for t in range(n):
# Pad if necessary:
- pz, py, px = self.dim[0] % 2, self.dim[1] % 2, self.dim[2] % 2
dim = field_new.shape[1:]
pz, py, px = dim[0] % 2, dim[1] % 2, dim[2] % 2
if pz != 0 or py != 0 or px != 0:
- self.field = np.pad(self.field, ((0, 0), (0, pz), (0, py), (0, px)),
- mode='constant')
field_new = np.pad(field_new, ((0, 0), (0, pz), (0, py), (0, px)), mode='constant')
# Create coarser grid for the vector field:
- shape_4d = (3, self.dim[0] // 2, 2, self.dim[1] // 2, 2, self.dim[2] // 2, 2)
shape_4d = (3, dim[0] // 2, 2, dim[1] // 2, 2, dim[2] // 2, 2)
- self.field = self.field.reshape(shape_4d).mean(axis=(6, 4, 2))
field_new = field_new.reshape(shape_4d).mean(axis=(6, 4, 2))
return VectorData(a_new, field_new)
def scale_up(self, n=1, order=0):
"""Scale up the field distribution using spline interpolation of the requested order.
@@ -522,10 +577,11 @@ class VectorData(FieldData):
assert n > 0 and isinstance(n, int), 'n must be a positive integer!'
assert 5 > order >= 0 and isinstance(order, int), \
'order must be a positive integer between 0 and 5!'
- self.a /= 2 ** n
a_new = self.a / 2 ** n
- self.field = np.array((zoom(self.field[0], zoom=2 ** n, order=order),
field_new = np.array((zoom(self.field[0], zoom=2 ** n, order=order),
zoom(self.field[1], zoom=2 ** n, order=order),
zoom(self.field[2], zoom=2 ** n, order=order)))
return VectorData(a_new, field_new)
def pad(self, pad_values):
"""Pad the current field distribution with zeros for each individual axis.
@@ -551,8 +607,9 @@ class VectorData(FieldData):
for i, values in enumerate(pad_values):
assert np.shape(values) in [(), (2,)], 'Only one or two values per axis can be given!'
pv[2 * i:2 * (i + 1)] = values
- self.field = np.pad(self.field, ((0, 0), (pv[0], pv[1]), (pv[2], pv[3]), (pv[4], pv[5])),
field_pad = np.pad(self.field, ((0, 0), (pv[0], pv[1]), (pv[2], pv[3]), (pv[4], pv[5])),
mode='constant')
return VectorData(self.a, field_pad)
def crop(self, crop_values):  # TODO: bad copy&paste from pad?
"""Crop the current field distribution with zeros for each individual axis.
@@ -580,56 +637,8 @@ class VectorData(FieldData):
cv[2 * i:2 * (i + 1)] = values
cv *= np.resize([1, -1], len(cv))
cv = np.where(cv == 0, None, cv)
- self.field = self.field[:, cv[0]:cv[1], cv[2]:cv[3], cv[4]:cv[5]]
field_crop = self.field[:, cv[0]:cv[1], cv[2]:cv[3], cv[4]:cv[5]]
return VectorData(self.a, field_crop)
- def get_vector(self, mask):
- """Returns the vector field components arranged in a vector, specified by a mask.
- Parameters
- ----------
- mask : :class:`~numpy.ndarray` (N=3, boolean)
- Masks the pixels from which the components should be taken.
- Returns
- -------
- vector : :class:`~numpy.ndarray` (N=1)
- The vector containing vector field components of the specified pixels.
- Order is: first all `x`-, then all `y`-, then all `z`-components.
- """
- self._log.debug('Calling get_vector')
- if mask is not None:
- return np.reshape([self.field[0][mask],
- self.field[1][mask],
- self.field[2][mask]], -1)
- else:
- return self.field_vec
- def set_vector(self, vector, mask=None):
- """Set the field components of the masked pixels to the values specified by `vector`.
- Parameters
- ----------
- mask : :class:`~numpy.ndarray` (N=3, boolean), optional
- Masks the pixels from which the components should be taken.
- vector : :class:`~numpy.ndarray` (N=1)
- The vector containing vector field components of the specified pixels.
- Order is: first all `x`-, then all `y`-, then all `z`-components.
- Returns
- -------
- None
- """
- self._log.debug('Calling set_vector')
- assert np.size(vector) % 3 == 0, 'Vector has to contain all 3 components for every pixel!'
- count = np.size(vector) // 3
- if mask is not None:
- self.field[0][mask] = vector[:count]  # x-component
- self.field[1][mask] = vector[count:2 * count]  # y-component
- self.field[2][mask] = vector[2 * count:]  # z-component
- else:
- self.field_vec = vector
def flip(self, axis='x'):
"""Flip/mirror the vector field around the specified axis.
@@ -659,6 +668,17 @@ class VectorData(FieldData):
raise ValueError("Wrong input! 'x', 'y', 'z' allowed!")
return VectorData(self.a, field_flip)
def rotate(self, angle, axis='z', reshape=False, **kwargs):
# TODO: Docstring!
# Define axes of rotation plane (axis 0 for vector component!) and rotate coord. system:
axes = {'x': (1, 2), 'y': (1, 3), 'z': (2, 3)}[axis]
field_coord_rot = rotate(self.field, angle, axes=axes, reshape=reshape, **kwargs)
# Rotate vectors inside the voxels (- signs determined by scipy rotate axes in firt step):
vec_dict = {'x': (-1, 0, 0), 'y': (0, 1, 0), 'z': (0, 0, -1)}
quat = Quaternion.from_axisangle(vec_dict[axis], np.deg2rad(angle))
field_rot = quat.matrix.dot(field_coord_rot.reshape(3, -1)).reshape(field_coord_rot.shape)
return VectorData(self.a, field_rot)
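A hedged usage sketch for the new rotate method; magdata stands for any existing VectorData instance and the angles are arbitrary examples.

magdata_rot = magdata.rotate(45, axis='z')  # rotates the grid and the vectors inside the voxels
magdata_big = magdata.rotate(30, axis='x', reshape=True)  # reshape=True lets scipy grow the array
# The original magdata is left untouched; rotate returns a new VectorData object.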
def rot90(self, axis='x'):
"""Rotate the vector field 90° around the specified axis (right hand rotation).
@@ -674,28 +694,65 @@ class VectorData(FieldData):
"""
self._log.debug('Calling rot90')
- if axis == 'x':
if axis == 'x':  # RotMatrix for 90°: [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
field_rot = np.zeros((3, self.dim[1], self.dim[0], self.dim[2]))
for i in range(self.dim[2]):
mag_x, mag_y, mag_z = self.field[:, :, :, i]
- mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z)
- field_rot[:, :, :, i] = np.array((mag_xrot, mag_zrot, -mag_yrot))
mag_xrot = np.rot90(mag_x)
mag_yrot = np.rot90(mag_z)
mag_zrot = -np.rot90(mag_y)
field_rot[:, :, :, i] = np.array((mag_xrot, mag_yrot, mag_zrot))
- elif axis == 'y':
elif axis == 'y':  # RotMatrix for 90°: [[0, 0, 1], [0, 1, 0], [-1, 0, 0]]
field_rot = np.zeros((3, self.dim[2], self.dim[1], self.dim[0]))
for i in range(self.dim[1]):
mag_x, mag_y, mag_z = self.field[:, :, i, :]
- mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z)
- field_rot[:, :, i, :] = np.array((mag_zrot, mag_yrot, -mag_xrot))
mag_xrot = np.rot90(mag_z)
mag_yrot = np.rot90(mag_y)
mag_zrot = -np.rot90(mag_x)
field_rot[:, :, i, :] = np.array((mag_xrot, mag_yrot, mag_zrot))
- elif axis == 'z':
elif axis == 'z':  # RotMatrix for 90°: [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
field_rot = np.zeros((3, self.dim[0], self.dim[2], self.dim[1]))
for i in range(self.dim[0]):
mag_x, mag_y, mag_z = self.field[:, i, :, :]
- mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z)
- field_rot[:, i, :, :] = np.array((mag_yrot, -mag_xrot, mag_zrot))
mag_xrot = np.rot90(mag_y)
mag_yrot = -np.rot90(mag_x)
mag_zrot = np.rot90(mag_z)
field_rot[:, i, :, :] = np.array((mag_xrot, mag_yrot, mag_zrot))
else:
raise ValueError("Wrong input! 'x', 'y', 'z' allowed!")
return VectorData(self.a, field_rot)
def roll(self, shift, axis='x'): # TODO: Make sure both classes have all manipulator methods!
"""Rotate the scalar field 90° around the specified axis (right hand rotation).
Parameters
----------
axis: {'x', 'y', 'z'}, optional
The axis around which the vector field is rotated.
Returns
-------
scalardata_rot: :class:`~.ScalarData`
A rotated copy of the :class:`~.ScalarData` object.
"""
self._log.debug('Calling roll')
ax = {'x': 2, 'y': 1, 'z': 0}[axis]
mag_x_roll = np.roll(self.field[0, ...], shift, ax)
mag_y_roll = np.roll(self.field[1, ...], shift, ax)
mag_z_roll = np.roll(self.field[2, ...], shift, ax)
return VectorData(self.a, np.asarray((mag_x_roll, mag_y_roll, mag_z_roll)))
def clip_amp(self, threshold):
# TODO: Docstring!
# TODO: 'mag' should be 'vec' everywhere!!!
mag_x, mag_y, mag_z = self.field
scaling = np.where(self.field_amp > threshold, threshold/self.field_amp, 1)
mag_x = mag_x * scaling
mag_y = mag_y * scaling
mag_z = mag_z * scaling
return VectorData(self.a, np.asarray((mag_x, mag_y, mag_z)))
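The scaling used in clip_amp can be illustrated on a bare numpy array; this standalone sketch is independent of pyramid and only mirrors the np.where logic above.

import numpy as np

field = np.random.randn(3, 8, 8, 8)                      # (component, z, y, x) vector field
amp = np.sqrt(np.sum(field ** 2, axis=0))                # amplitude per voxel
threshold = 1.0
scaling = np.where(amp > threshold, threshold / amp, 1)  # shrink only the over-long vectors
field_clipped = field * scaling                          # broadcasts over the component axis
assert np.sqrt(np.sum(field_clipped ** 2, axis=0)).max() <= threshold + 1e-9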
def get_slice(self, ax_slice=None, proj_axis='z'):
# TODO: return x y z instead of u v w (to color fields consistent with xyz!)
"""Extract a slice from the :class:`~.VectorData` object.
@@ -1363,6 +1420,47 @@ class ScalarData(FieldData):
'Vector has to match field shape! {} {}'.format(c_vec.shape, np.prod(self.shape))
self.field = c_vec.reshape(self.dim)
def get_vector(self, mask):
"""Returns the field as a vector, specified by a mask.
Parameters
----------
mask : :class:`~numpy.ndarray` (N=3, boolean)
Masks the pixels from which the components should be taken.
Returns
-------
vector : :class:`~numpy.ndarray` (N=1)
The vector containing the field of the specified pixels.
"""
self._log.debug('Calling get_vector')
if mask is not None:
return np.reshape(self.field[mask], -1)
else:
return self.field_vec
def set_vector(self, vector, mask=None):
"""Set the field components of the masked pixels to the values specified by `vector`.
Parameters
----------
mask : :class:`~numpy.ndarray` (N=3, boolean), optional
Masks the pixels from which the components should be taken.
vector : :class:`~numpy.ndarray` (N=1)
The vector containing the field of the specified pixels.
Returns
-------
None
"""
self._log.debug('Calling set_vector')
if mask is not None:
self.field[mask] = vector
else:
self.field_vec = vector
def scale_down(self, n=1):
"""Scale down the field distribution by averaging over two pixels along each axis.
@@ -1383,15 +1481,18 @@ class ScalarData(FieldData):
"""
self._log.debug('Calling scale_down')
assert n > 0 and isinstance(n, int), 'n must be a positive integer!'
- self.a *= 2 ** n
a_new = self.a * 2 ** n
field_new = self.field
for t in range(n):
# Pad if necessary:
- pz, py, px = self.dim[0] % 2, self.dim[1] % 2, self.dim[2] % 2
dim = field_new.shape
pz, py, px = dim[0] % 2, dim[1] % 2, dim[2] % 2
if pz != 0 or py != 0 or px != 0:
- self.field = np.pad(self.field, ((0, pz), (0, py), (0, px)), mode='constant')
field_new = np.pad(field_new, ((0, pz), (0, py), (0, px)), mode='constant')
# Create coarser grid for the field:
- shape_4d = (self.dim[0] // 2, 2, self.dim[1] // 2, 2, self.dim[2] // 2, 2)
shape_3d = (dim[0] // 2, 2, dim[1] // 2, 2, dim[2] // 2, 2)
- self.field = self.field.reshape(shape_4d).mean(axis=(5, 3, 1))
field_new = field_new.reshape(shape_3d).mean(axis=(5, 3, 1))
return ScalarData(a_new, field_new)
def scale_up(self, n=1, order=0):
"""Scale up the field distribution using spline interpolation of the requested order.
@@ -1417,49 +1518,17 @@ class ScalarData(FieldData):
assert n > 0 and isinstance(n, int), 'n must be a positive integer!'
assert 5 > order >= 0 and isinstance(order, int), \
'order must be a positive integer between 0 and 5!'
- self.a /= 2 ** n
a_new = self.a / 2 ** n
- self.field = zoom(self.field, zoom=2 ** n, order=order)
field_new = zoom(self.field, zoom=2 ** n, order=order)
return ScalarData(a_new, field_new)
# TODO: flip!
def rotate(self, angle, axis='z', reshape=False, **kwargs):
# TODO: Docstring!
axes = {'x': (0, 1), 'y': (0, 2), 'z': (1, 2)}[axis]  # Defines axes of plane of rotation!
field_rot = rotate(self.field[...], angle, axes=axes, reshape=reshape, **kwargs)
return ScalarData(self.a, field_rot)
- def get_vector(self, mask):
- """Returns the field as a vector, specified by a mask.
- Parameters
- ----------
- mask : :class:`~numpy.ndarray` (N=3, boolean)
- Masks the pixels from which the components should be taken.
- Returns
- -------
- vector : :class:`~numpy.ndarray` (N=1)
- The vector containing the field of the specified pixels.
- """
- self._log.debug('Calling get_vector')
- if mask is not None:
- return np.reshape(self.field[mask], -1)
- else:
- return self.field_vec
- def set_vector(self, vector, mask=None):
- """Set the field components of the masked pixels to the values specified by `vector`.
- Parameters
- ----------
- mask : :class:`~numpy.ndarray` (N=3, boolean), optional
- Masks the pixels from which the components should be taken.
- vector : :class:`~numpy.ndarray` (N=1)
- The vector containing the field of the specified pixels.
- Returns
- -------
- None
- """
- self._log.debug('Calling set_vector')
- if mask is not None:
- self.field[mask] = vector
- else:
- self.field_vec = vector
def rot90(self, axis='x'):
"""Rotate the scalar field 90° around the specified axis (right hand rotation).
@@ -1492,6 +1561,24 @@ class ScalarData(FieldData):
raise ValueError("Wrong input! 'x', 'y', 'z' allowed!")
return ScalarData(self.a, field_rot)
def roll(self, shift, axis='x'): # TODO: Make sure both classes have all manipulator methods!
"""Rotate the scalar field 90° around the specified axis (right hand rotation).
Parameters
----------
axis: {'x', 'y', 'z'}, optional
The axis around which the vector field is rotated.
Returns
-------
scalardata_rot: :class:`~.ScalarData`
A rotated copy of the :class:`~.ScalarData` object.
"""
self._log.debug('Calling roll')
ax = {'x': 2, 'y': 1, 'z': 0}[axis]
return ScalarData(self.a, np.roll(self.field, shift, ax))
def to_signal(self):
"""Convert :class:`~.ScalarData` data into a HyperSpy signal.
......
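Per the commit message, the TODOs above and the updated tests at the bottom of this commit, the resize/reshape helpers of VectorData and ScalarData no longer work in place but return new objects, so callers need to keep the return value. A sketch (pad/crop values chosen arbitrarily, one value or a (before, after) pair per axis as in the docstrings):

magdata = magdata.scale_down()          # was in place before this commit
magdata = magdata.scale_up(n=1, order=0)
magdata = magdata.pad((0, (2, 2), (2, 2)))
magdata = magdata.crop((0, (2, 2), (2, 2)))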
@@ -262,6 +262,8 @@ class PhaseMap(object):
return PhaseMap(self.a, self.phase.copy(), self.mask.copy(),
self.confidence.copy())
# TODO: ALL NOT IN PLACE!!!
def scale_down(self, n=1):
"""Scale down the phase map by averaging over two pixels along each axis.
@@ -671,10 +673,14 @@ class PhaseMap(object):
if symmetric:
vmin, vmax = np.min([vmin, -0]), np.max([0, vmax])  # Ensure zero is present!
limit = np.max(np.abs([vmin, vmax]))
- start = (vmin + limit) / (2 * limit)
- end = (vmax + limit) / (2 * limit)
# A symmetric colormap only has zero at white (the symmetry point) if the values
# of the corresponding mappable go from -limit to +limit (symmetric bounds)!
# Calculate the colors of this symmetric colormap for the range vmin to vmax:
start = 0.5 + vmin/(2*limit)  # 0 for symmetric bounds, >0: unused colors at lower end!
end = 0.5 + vmax/(2*limit)  # 1 for symmetric bounds, <1: unused colors at upper end!
cmap_colors = cmap(np.linspace(start, end, 256))
- cmap = LinearSegmentedColormap.from_list('Symmetric', cmap_colors)
# Use calculated colors to create custom (asymmetric) colormap with white at zero:
cmap = LinearSegmentedColormap.from_list('custom', cmap_colors)
# If no axis is specified, a new figure is created:
if axis is None:
fig = plt.figure(figsize=figsize)
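The start/end arithmetic above is a general matplotlib recipe for keeping the midpoint colour of a diverging colormap at zero when vmin and vmax are not symmetric; a standalone sketch with an arbitrary colormap and data range:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

vmin, vmax = -1.0, 3.0                            # asymmetric data range
limit = max(abs(vmin), abs(vmax))
cmap_full = plt.get_cmap('RdBu')
start = 0.5 + vmin / (2 * limit)                  # fraction of the colormap actually used
end = 0.5 + vmax / (2 * limit)
cmap_cut = LinearSegmentedColormap.from_list('custom', cmap_full(np.linspace(start, end, 256)))
data = np.linspace(vmin, vmax, 256).reshape(16, 16)
plt.imshow(data, cmap=cmap_cut, vmin=vmin, vmax=vmax)  # zero stays on the white midpoint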
@@ -878,11 +884,14 @@ class PhaseMap(object):
# Plot histogram:
hist_axis = fig.add_subplot(1, 2, 1)
vec = self.phase_vec * self.UNITDICT[unit]  # Take units into consideration:
# TODO: This is bad! Discard low confidence values completely instead! Otherwise peak at 0!
# TODO: Set to nan and then discard with np.isnan()?
vec *= np.where(self.confidence > 0.5, 1, 0).ravel()  # Discard low confidence points!
hist_axis.hist(vec, bins=bins, histtype='stepfilled', color='g')
# Format histogram:
x0, x1 = hist_axis.get_xlim()
y0, y1 = hist_axis.get_ylim()
# TODO: Why the next line? Seems bad if you want to change things later!
hist_axis.set(aspect=np.abs(x1 - x0) / np.abs(y1 - y0) * 0.94)  # Last value because cbar!
fontsize = kwargs.get('fontsize', 16)
hist_axis.tick_params(axis='both', which='major', labelsize=fontsize)
......
@@ -411,6 +411,12 @@ class RotTiltProjector(Projector):
self._log.debug('Calling __init__')
self.rotation = rotation
self.tilt = tilt
# Create tilt, rotation and combined quaternion, careful: Quaternion(w,x,y,z), not (z,y,x):
quat_z_n = Quaternion.from_axisangle((0, 0, 1), -rotation)  # Rotate around z-axis
quat_x = Quaternion.from_axisangle((1, 0, 0), tilt)  # Tilt around x-axis
quat_z_p = Quaternion.from_axisangle((0, 0, 1), rotation)  # Rotate around z-axis
# Combined quaternion (first rotate around z, then tilt around x, rotate back around z):
quat = quat_z_n * quat_x * quat_z_p  # p: positive rotation, m: negative rotation
# Determine dimensions:
dim_z, dim_y, dim_x = dim
center = (dim_z / 2., dim_y / 2., dim_x / 2.)
@@ -421,18 +427,17 @@ class RotTiltProjector(Projector):
dim_v, dim_u = dim_uv
# Creating coordinate list of all voxels:
voxels = list(itertools.product(range(dim_z), range(dim_y), range(dim_x)))
- # Calculate vectors to voxels relative to rotation center:
- voxel_vecs = (np.asarray(voxels) + 0.5 - np.asarray(center)).T
- # Create tilt, rotation and combined quaternion, careful: Quaternion(w,x,y,z), not (z,y,x):
- quat_z_n = Quaternion.from_axisangle((0, 0, 1), -rotation)  # Rotate around z-axis
- quat_x = Quaternion.from_axisangle((1, 0, 0), tilt)  # Tilt around x-axis
- quat_z_p = Quaternion.from_axisangle((0, 0, 1), rotation)  # Rotate around z-axis
- # Combined quaternion (first rotate around z, then tilt around x, rotate back around z):
- quat = quat_z_n * quat_x * quat_z_p  # p: positive rotation, m: negative rotation
- # Calculate impact positions on the projected pixel coordinate grid (flip because quat.):
- impacts = np.flipud(quat.matrix[:2, :].dot(np.flipud(voxel_vecs)))  # only care for x/y
- impacts[1, :] += dim_u / 2.  # Shift back to normal indices
- impacts[0, :] += dim_v / 2.  # Shift back to normal indices
# Calculate vectors to voxels relative to rotation center (each row contains (z, y, x)):
voxel_vecs = np.asarray(voxels) + 0.5 - np.asarray(center)
# Change to coordinate order of quaternions (x, y, z) instead of (z, y, x):
voxel_vecs = np.fliplr(voxel_vecs)
# Calculate impact positions (x, y) on the projected pixel coordinate grid (z is dropped):
impacts = quat.matrix[:2, :].dot(voxel_vecs.T).T  # Only x and y row of matrix is used!
# Reverse transpose and change back coordinate order from (x, y) to (y, x)/(v, u):
impacts = np.fliplr(impacts.T)  # Now contains rows with (v, u) entries as desired!
# First index: voxel, second index: 0 -> v, 1 -> u!
impacts[:, 1] += dim_u / 2.  # Shift back to normal indices
impacts[:, 0] += dim_v / 2.  # Shift back to normal indices
# Calculate equivalence radius:
R = (3 / (4 * np.pi)) ** (1 / 3.)
# Prepare weight matrix calculation:
@@ -446,7 +451,7 @@ class RotTiltProjector(Projector):
for i, voxel in enumerate(tqdm(voxels, disable=disable, leave=False,
desc='Set up projector')):
column_index = voxel[0] * dim_y * dim_x + voxel[1] * dim_x + voxel[2]
- remainder, impact = np.modf(impacts[:, i])  # split index of impact and remainder!
remainder, impact = np.modf(impacts[i, :])  # split index of impact and remainder!
sub_pixel = (remainder * subcount).astype(dtype=np.int)  # sub_pixel inside impact px.
# Go over all influenced pixels (impact and neighbours, indices are [0, 1, 2]!):
for px_ind in list(itertools.product(range(3), range(3))):
......
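The combined quaternion above implements a rotate-tilt-rotate-back sequence whose x/y rows project voxel centres onto the detector grid. The same geometry can be cross-checked with scipy's Rotation class; this is an independent analogue, not the pyramid Quaternion API, and the angles are arbitrary.

import numpy as np
from scipy.spatial.transform import Rotation as R

rotation, tilt = np.deg2rad(30), np.deg2rad(60)
# The rightmost rotation acts first: rotate around z, tilt around x, rotate back around z.
combined = (R.from_rotvec([0, 0, -rotation]) *
            R.from_rotvec([tilt, 0, 0]) *
            R.from_rotvec([0, 0, rotation]))
voxel_vecs_xyz = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])  # rows of (x, y, z)
impacts_xy = combined.apply(voxel_vecs_xyz)[:, :2]  # keep x and y, drop z (the beam direction)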
@@ -64,11 +64,18 @@ class Quaternion(object):
self._log.debug('Calling __mul__')
if isinstance(other, Quaternion):  # Quaternion multiplication
return self.dot_quat(self, other)
- elif len(other) == 3:  # vector multiplication
elif len(other) == 3:  # vector multiplication (Caution: normalises!)
q_vec = Quaternion((0,) + tuple(other))
q = self.dot_quat(self.dot_quat(self, q_vec), self.conj)
return q.values[1:]
def _normalize(self):
self._log.debug('Calling _normalize')
mag2 = np.sum(n ** 2 for n in self.values)
if abs(mag2 - 1.0) > self.NORM_TOLERANCE:
mag = np.sqrt(mag2)
self.values = tuple(n / mag for n in self.values)
def dot_quat(self, q1, q2):
"""Multiply two :class:`~.Quaternion` objects to create a new one (always normalized).
@@ -92,13 +99,6 @@ class Quaternion(object):
z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
return Quaternion((w, x, y, z))
- def _normalize(self):
- self._log.debug('Calling _normalize')
- mag2 = np.sum(n ** 2 for n in self.values)
- if abs(mag2 - 1.0) > self.NORM_TOLERANCE:
- mag = np.sqrt(mag2)
- self.values = tuple(n / mag for n in self.values)
@classmethod
def from_axisangle(cls, vector, theta):
"""Create a quaternion from an axis-angle representation
......
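For reference, from_axisangle presumably follows the standard axis-angle convention q = (cos(theta/2), sin(theta/2)*n) with (w, x, y, z) ordering; a standalone cross-check against scipy, which stores quaternions as (x, y, z, w) and provides as_matrix from version 1.4 on:

import numpy as np
from scipy.spatial.transform import Rotation as R

axis = np.array([0., 0., 1.])
theta = np.deg2rad(30)
w = np.cos(theta / 2)
xyz = np.sin(theta / 2) * axis / np.linalg.norm(axis)
same = np.allclose(R.from_quat([*xyz, w]).as_matrix(),
                   R.from_rotvec(theta * axis).as_matrix())
print(same)  # True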
from scipy import constants
import numpy as np
def electron_wavelength(ht):
"""
Returns electron wavelenght in nm.
Parameters
----------
ht : float
High tension in kV.
"""
momentum = 2 * constants.m_e * constants.elementary_charge * ht * 1000 * (
1 + constants.elementary_charge * ht * 1000 / (2 * constants.m_e * constants.c ** 2))
wavelength = constants.h / np.sqrt(momentum) * 1e9 # in nm
return wavelength
def aberration_function(w, aber_dict, v_acc):
# TODO: Taken from Florian! Use dictionary!
w_cc = np.conjugate(w)
chi_i = {'C1': aber_dict['C1'] * w * w_cc / 2}
chi_sum = np.zeros_like(w)
for key in aber_dict.keys():
chi_sum += chi_i[key]
return (2 * np.pi / electron_wavelength(v_acc)) * np.real(chi_sum)
def apply_aberrations(phasemap, aber_dict, v_acc):
# Define complex scattering angle w
f_freq_v = np.fft.fftfreq(phasemap.dim_uv[0], phasemap.a)
f_freq_u = np.fft.fftfreq(phasemap.dim_uv[1], phasemap.a)
f_freq_mesh = np.meshgrid(f_freq_u, f_freq_v)
w = f_freq_mesh[0] + 1j * f_freq_mesh[1]
w *= electron_wavelength(v_acc)
chi = aberration_function(w, aber_dict, v_acc)
wave = np.exp(1j * phasemap.phase)
wave_fft = np.fft.fftn(wave) / np.prod(phasemap.dim_uv)
wave_fft *= np.exp(-1j * chi)
wave = np.fft.ifftn(wave_fft) * np.prod(phasemap.dim_uv)
return wave, chi
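A short usage sketch for the new Lorentz helpers; phasemap stands for any pyramid PhaseMap, and the C1 (defocus) value and its units are placeholders rather than values from this commit. The printed wavelengths are the standard relativistic ones (about 2.51 pm at 200 kV and 1.97 pm at 300 kV) and double as a sanity check of electron_wavelength.

import numpy as np

print(electron_wavelength(200))  # ~0.00251 nm = 2.51 pm
print(electron_wavelength(300))  # ~0.00197 nm = 1.97 pm

aber_dict = {'C1': 100.}                                       # defocus only; value assumed
wave, chi = apply_aberrations(phasemap, aber_dict, v_acc=300)  # exit wave with aberrations applied
intensity = np.abs(wave) ** 2                                  # simulated Fresnel image intensity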
@@ -29,20 +29,20 @@ class TestCaseVectorData(unittest.TestCase):
assert magdata_copy != self.magdata, 'Unexpected behaviour in copy()!'
def test_scale_down(self):
- self.magdata.scale_down()
magdata_test = self.magdata.scale_down()
reference = 1 / 8. * np.ones((3, 2, 2, 2))
- assert_allclose(self.magdata.field, reference,
assert_allclose(magdata_test.field, reference,
err_msg='Unexpected behavior in scale_down()!')
- assert_allclose(self.magdata.a, 20,
assert_allclose(magdata_test.a, 20,
err_msg='Unexpected behavior in scale_down()!')
def test_scale_up(self):
- self.magdata.scale_up()
magdata_test = self.magdata.scale_up()
reference = np.zeros((3, 8, 8, 8))
reference[:, 2:6, 2:6, 2:6] = 1
- assert_allclose(self.magdata.field, reference,
assert_allclose(magdata_test.field, reference,
err_msg='Unexpected behavior in scale_down()!')
- assert_allclose(self.magdata.a, 5,
assert_allclose(magdata_test.a, 5,
err_msg='Unexpected behavior in scale_down()!')
def test_pad(self):
......