diff --git a/pyramid/diagnostics.py b/pyramid/diagnostics.py index 8acfcf840fcb6258a9f77f718ea42bb4d019e405..1e84c3f7da5e338a4793dbd8b26a5124831e1387 100644 --- a/pyramid/diagnostics.py +++ b/pyramid/diagnostics.py @@ -9,6 +9,8 @@ import os import logging +import pickle + from pyramid.forwardmodel import ForwardModel from pyramid.costfunction import Costfunction from pyramid.regularisator import FirstOrderRegularisator @@ -26,6 +28,14 @@ import numpy as np import jutil +try: + if type(get_ipython()).__name__ == 'ZMQInteractiveShell': # IPython Notebook! + from tqdm import tqdm_notebook as tqdm + else: # IPython, but not a Notebook (e.g. terminal) + from tqdm import tqdm +except NameError: + from tqdm import tqdm + __all__ = ['Diagnostics', 'LCurve','get_vector_field_errors'] # TODO: should be subpackage, distribute methods and classes to separate modules! @@ -143,7 +153,7 @@ class Diagnostics(object): self._updated_avrg_kern_row = False self._updated_measure_contribution = False - def __init__(self, magdata, cost, max_iter=1000, verbose=False): + def __init__(self, magdata, cost, max_iter=1000, verbose=False): # TODO: verbose True default self._log.debug('Calling __init__') self.magdata = magdata self.cost = cost @@ -390,7 +400,7 @@ class Diagnostics(object): artist = axis.add_patch(patches.Ellipse(xy, width, height, fill=False, edgecolor='w', linewidth=2, alpha=0.5)) artist.set_path_effects([patheffects.withStroke(linewidth=4, foreground='k', alpha=0.5)]) - + # TODO: Return axis on every plot? def plot_avrg_kern_field3d(self, pos=None, mask=True, ellipsoid=True, **kwargs): avrg_kern_field = self.get_avrg_kern_field(pos) @@ -481,84 +491,101 @@ class LCurve(object): _log = logging.getLogger(__name__ + '.FieldData') - def __init__(self, fwd_model, max_iter=0, verbose=False, save_dir=None): + def __init__(self, fwd_model, max_iter=0, verbose=True, save_dir='lcurve'): self._log.debug('Calling __init__') assert isinstance(fwd_model, ForwardModel), 'Input has to be a costfunction' self.fwd_model = fwd_model self.max_iter = max_iter self.verbose = verbose - self.lams = [] - self.chisq_a = [] - self.chisq_m = [] - if save_dir is not None: - assert os.path.isdir(save_dir), 'save_dir has to be None or a valid directory!' - self.save_dir = save_dir # TODO: Use save_dir!!! + self.l_dict = {} + self.save_dir = save_dir + if self.save_dir is not None: + if not os.path.isdir(self.save_dir): # Create directory if it does not exist: + os.makedirs(self.save_dir) + if os.path.isfile('{}/lcurve.pkl'.format(self.save_dir)): # Load file if it exists: + self._load() + else: # Create file: + self._save() self._log.debug('Created ' + str(self)) - def calculate(self, lam): + # TODO: Methods for saving and loading l_dict's!!! + def _save(self): + with open('{}/lcurve.pkl'.format(self.save_dir), 'wb') as f: + pickle.dump(self.l_dict, f, pickle.HIGHEST_PROTOCOL) + + def _load(self): + with open('{}/lcurve.pkl'.format(self.save_dir), 'rb') as f: + self.l_dict = pickle.load(f) + + def calculate(self, lambdas, overwrite=False): # TODO: Docstring! - if lam not in self.lams: - # Create new regularisator: # TODO: Not hardcoding FirstOrderRegularisator! - reg = FirstOrderRegularisator(self.fwd_model.data_set.mask, lam, - add_params=self.fwd_model.ramp.n) - cost = Costfunction(fwd_model=self.fwd_model, regularisator=reg) - # Reconstruct: - magdata_rec = reconstruction.optimize_linear(cost, max_iter=self.max_iter, - verbose=self.verbose) - # Save magdata_rec if necessary: - if self.save_dir is not None: - filename = 'magdata_rec_lam{:.0e}.hdf5'.format(lam) - magdata_rec.save(os.path.join(self.save_dir, filename), overwrite=True) - # Append new values: - self.lams.append(lam) - chisq_m, chisq_a = cost.chisq_m[-1], cost.chisq_a[-1] # TODO: is chisq_m list or not? - self.chisq_m.append(chisq_m) - self.chisq_a.append(chisq_a / lam) # TODO: lambda out of regularisator? - # Sort lists according to lambdas: - tuples = zip(*sorted(zip(self.lams, self.chisq_m, self.chisq_a))) - self.lams, self.chisq_m, self.chisq_a = (list(l) for l in tuples) - self._log.info(lam, ' --> m:', chisq_m, ' a:', chisq_a / lam) - return chisq_m, chisq_a / lam + lams = np.atleast_1d(lambdas) + for lam in tqdm(lams, disable=not self.verbose): + if lam not in self.l_dict.keys() or overwrite: + # Create new regularisator and costfunction: # TODO: Not hardcoding FirstOrder! + # TODO: Not necessary if lambda can be extracted from regularisator? self.cost? + reg = FirstOrderRegularisator(self.fwd_model.data_set.mask, lam, + add_params=self.fwd_model.ramp.n) + cost = Costfunction(fwd_model=self.fwd_model, regularisator=reg) + # Reconstruct: + magdata_rec = reconstruction.optimize_linear(cost, max_iter=self.max_iter, + verbose=self.verbose) + # Add new values to dictionary: + chisq_m, chisq_a = cost.chisq_m[-1], cost.chisq_a[-1] # TODO: chisq_m list or not? + self.l_dict[lam] = (chisq_m, chisq_a) + self._log.info(lam, ' --> m:', chisq_m, ' a:', chisq_a) + # Save magdata_rec and dictionary if necessary: + if self.save_dir is not None: + filename = 'magdata_rec_lam{:.0e}.hdf5'.format(lam) + magdata_rec.save(os.path.join(self.save_dir, filename), overwrite=True) + self._save() + + def calculate_auto(self, lam_start=1E-18, lam_end=1E5, online_axis=False): + raise NotImplementedError() + # TODO: Docstring! + # TODO: IMPLEMENT!!! + # # Calculate new cost terms: + # log_m_s, log_a_s = np.log10(self.calculate(lam_start)) + # log_m_e, log_a_e = np.log10(self.calculate(lam_end)) + # # Calculate new lambda: + # log_lam_s, log_lam_e = np.log10(lam_start), np.log10(lam_end) + # log_lam_new = np.mean((log_lam_s, log_lam_e)) # logarithmic mean to find middle on L! + # sign_exp = np.floor(log_lam_new) + # last_sign_digit = np.round(10 ** (log_lam_new - sign_exp)) + # lam_new = last_sign_digit * 10 ** sign_exp + # # Calculate cost terms for new lambda: + # log_m_new, log_a_new = np.log10(self.calculate(lam_new)) + # if online_axis: # Update plot if necessary: + # self.plot(axis=online_axis) + # from IPython import display + # display.clear_output(wait=True) + # display.display(plt.gcf()) + # # Calculate distances from origin and find new interval: + # dist_s, dist_e = np.hypot(log_m_s, log_a_s), np.hypot(log_m_e, log_a_e) + # dist_new = np.hypot(log_m_new, log_a_new) + # print(lam_start, lam_end, lam_new) + # print(dist_s, dist_e, dist_new) + # # if dist_new + # TODO: slope has to be normalised, scale of axes is not equal!!! + # TODO: get rid of right flank (do Not use right points with slope steeper than -45° # TODO: Implement else, return saved values! # TODO: Make this work with batch, sort lambdas at the end! # TODO: After sorting, calculate the CURVATURE for each lambda! (needs 3 points?) - # TODO: Use finite difference methods (forward/backward/central, depending on location)! + # TODO: Use finite difference methods (forward/backward/central, depends on location)! # TODO: Investigate points around highest curvature further. # TODO: Make sure to update ALL curvatures and search for new best EVERYWHERE! # TODO: Distinguish regions of the L-Curve. - def calculate_auto(self, lam_start=1E-18, lam_end=1E5, online_axis=False): - # TODO: Docstring! - # Calculate new cost terms: - log_m_s, log_a_s = np.log10(self.calculate(lam_start)) - log_m_e, log_a_e = np.log10(self.calculate(lam_end)) - # Calculate new lambda: - log_lam_s, log_lam_e = np.log10(lam_start), np.log10(lam_end) - log_lam_new = np.mean((log_lam_s, log_lam_e)) # logarithmic mean to find middle on L! - sign_exp = np.floor(log_lam_new) - last_sign_digit = np.round(10 ** (log_lam_new - sign_exp)) - lam_new = last_sign_digit * 10 ** sign_exp - # Calculate cost terms for new lambda: - log_m_new, log_a_new = np.log10(self.calculate(lam_new)) - if online_axis: # Update plot if necessary: - self.plot(axis=online_axis) - from IPython import display - display.clear_output(wait=True) - display.display(plt.gcf()) - - # TODO: slope has to be normalised, scale of axes is not equal!!! - # TODO: get rid of right flank (do Not use right points with slope steeper than -45° - # Calculate distances from origin and find new interval: - dist_s, dist_e = np.hypot(log_m_s, log_a_s), np.hypot(log_m_e, log_a_e) - dist_new = np.hypot(log_m_new, log_a_new) - print(lam_start, lam_end, lam_new) - print(dist_s, dist_e, dist_new) - #if dist_new - - - def plot(self, axis=None, figsize=None): + def plot(self, lambdas=None, axis=None, figsize=None): # TODO: Docstring! + # Sort lists according to lambdas: + if lambdas is None: + lambdas = sorted(self.l_dict.keys()) + x, y = [], [] + for lam in lambdas: + x.append(self.l_dict[lam][0]) + y.append(self.l_dict[lam][1] / lam) if figsize is None: figsize = plottools.FIGSIZE_DEFAULT if axis is None: @@ -567,23 +594,20 @@ class LCurve(object): axis = fig.add_subplot(1, 1, 1) axis.set_yscale("log", nonposx='clip') axis.set_xscale("log", nonposx='clip') - axis.plot(self.chisq_m, self.chisq_a, 'k-', linewidth=3, zorder=1) - sc = axis.scatter(x=self.chisq_m, y=self.chisq_a, c=self.lams, marker='o', s=100, - zorder=2, cmap='nipy_spectral', norm=LogNorm()) + axis.plot(x, y, 'k-', linewidth=3, zorder=1) + sc = axis.scatter(x, y, c=lambdas, marker='o', s=100, zorder=2, + cmap='nipy_spectral', norm=LogNorm()) plt.colorbar(mappable=sc, label='regularisation parameter $\lambda$') - #plottools.add_cbar(axis, mappable=sc, label='regularisation parameter $\lambda$') - #axis.get_xaxis().get_major_formatter().labelOnlyBase = False axis.set_xlabel( r'$\Vert\mathbf{F}(\mathbf{x})-\mathbf{y}\Vert_{\mathbf{S}_{\epsilon}^{-1}}^{2}$', fontsize=22, labelpad=-5) axis.set_ylabel(r'$\frac{1}{\lambda}\Vert\mathbf{x}\Vert_{\mathbf{S}_{a}^{-1}}^{2}$', fontsize=22) - #axis.set_xlim(3E3, 2E5) - #axis.set_ylim(1E2, 1E9) axis.xaxis.label.set_color('firebrick') axis.yaxis.label.set_color('seagreen') axis.tick_params(axis='both', which='major') axis.grid() + return axis # TODO: Don't plot the steep part on the right... diff --git a/pyramid/fielddata.py b/pyramid/fielddata.py index 7fa4e1eb9165cfc7ace1e75fdf83d69826316419..22cb378ab91260adb71237153ca98264ab76bbfb 100644 --- a/pyramid/fielddata.py +++ b/pyramid/fielddata.py @@ -8,6 +8,7 @@ import abc import logging import os import tempfile +from scipy.ndimage.interpolation import rotate from numbers import Number import numpy as np @@ -21,6 +22,7 @@ import cmocean from . import colors from . import plottools +from .quaternion import Quaternion __all__ = ['VectorData', 'ScalarData'] @@ -323,7 +325,8 @@ class FieldData(object, metaclass=abc.ABCMeta): Only possible, if each axis length is a power of 2! """ - pass + pass # TODO: NotImplementedError instead? See that all classes have the same interface! + # TODO: This means that all common functions of Scalar and Vector have to be abstract here! @abc.abstractmethod def scale_up(self, n, order): @@ -467,6 +470,56 @@ class VectorData(FieldData): def __getitem__(self, item): return self.__class__(self.a, self.field[item]) + def get_vector(self, mask): + """Returns the vector field components arranged in a vector, specified by a mask. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (N=3, boolean) + Masks the pixels from which the components should be taken. + + Returns + ------- + vector : :class:`~numpy.ndarray` (N=1) + The vector containing vector field components of the specified pixels. + Order is: first all `x`-, then all `y`-, then all `z`-components. + + """ + self._log.debug('Calling get_vector') + if mask is not None: + return np.reshape([self.field[0][mask], + self.field[1][mask], + self.field[2][mask]], -1) + else: + return self.field_vec + + def set_vector(self, vector, mask=None): + """Set the field components of the masked pixels to the values specified by `vector`. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (N=3, boolean), optional + Masks the pixels from which the components should be taken. + vector : :class:`~numpy.ndarray` (N=1) + The vector containing vector field components of the specified pixels. + Order is: first all `x`-, then all `y-, then all `z`-components. + + Returns + ------- + None + + """ + self._log.debug('Calling set_vector') + assert np.size(vector) % 3 == 0, 'Vector has to contain all 3 components for every pixel!' + count = np.size(vector) // 3 + if mask is not None: + self.field[0][mask] = vector[:count] # x-component + self.field[1][mask] = vector[count:2 * count] # y-component + self.field[2][mask] = vector[2 * count:] # z-component + else: + self.field_vec = vector + + # TODO: scale_down and scale_up should not work in place (also in ScalarData)! def scale_down(self, n=1): """Scale down the field distribution by averaging over two pixels along each axis. @@ -487,16 +540,18 @@ class VectorData(FieldData): """ self._log.debug('Calling scale_down') assert n > 0 and isinstance(n, int), 'n must be a positive integer!' - self.a *= 2 ** n + a_new = self.a * 2 ** n + field_new = self.field for t in range(n): # Pad if necessary: - pz, py, px = self.dim[0] % 2, self.dim[1] % 2, self.dim[2] % 2 + dim = field_new.shape[1:] + pz, py, px = dim[0] % 2, dim[1] % 2, dim[2] % 2 if pz != 0 or py != 0 or px != 0: - self.field = np.pad(self.field, ((0, 0), (0, pz), (0, py), (0, px)), - mode='constant') + field_new = np.pad(field_new, ((0, 0), (0, pz), (0, py), (0, px)), mode='constant') # Create coarser grid for the vector field: - shape_4d = (3, self.dim[0] // 2, 2, self.dim[1] // 2, 2, self.dim[2] // 2, 2) - self.field = self.field.reshape(shape_4d).mean(axis=(6, 4, 2)) + shape_4d = (3, dim[0] // 2, 2, dim[1] // 2, 2, dim[2] // 2, 2) + field_new = field_new.reshape(shape_4d).mean(axis=(6, 4, 2)) + return VectorData(a_new, field_new) def scale_up(self, n=1, order=0): """Scale up the field distribution using spline interpolation of the requested order. @@ -522,10 +577,11 @@ class VectorData(FieldData): assert n > 0 and isinstance(n, int), 'n must be a positive integer!' assert 5 > order >= 0 and isinstance(order, int), \ 'order must be a positive integer between 0 and 5!' - self.a /= 2 ** n - self.field = np.array((zoom(self.field[0], zoom=2 ** n, order=order), - zoom(self.field[1], zoom=2 ** n, order=order), - zoom(self.field[2], zoom=2 ** n, order=order))) + a_new = self.a / 2 ** n + field_new = np.array((zoom(self.field[0], zoom=2 ** n, order=order), + zoom(self.field[1], zoom=2 ** n, order=order), + zoom(self.field[2], zoom=2 ** n, order=order))) + return VectorData(a_new, field_new) def pad(self, pad_values): """Pad the current field distribution with zeros for each individual axis. @@ -551,8 +607,9 @@ class VectorData(FieldData): for i, values in enumerate(pad_values): assert np.shape(values) in [(), (2,)], 'Only one or two values per axis can be given!' pv[2 * i:2 * (i + 1)] = values - self.field = np.pad(self.field, ((0, 0), (pv[0], pv[1]), (pv[2], pv[3]), (pv[4], pv[5])), - mode='constant') + field_pad = np.pad(self.field, ((0, 0), (pv[0], pv[1]), (pv[2], pv[3]), (pv[4], pv[5])), + mode='constant') + return VectorData(self.a, field_pad) def crop(self, crop_values): # TODO: bad copy&paste from pad? """Crop the current field distribution with zeros for each individual axis. @@ -580,56 +637,8 @@ class VectorData(FieldData): cv[2 * i:2 * (i + 1)] = values cv *= np.resize([1, -1], len(cv)) cv = np.where(cv == 0, None, cv) - self.field = self.field[:, cv[0]:cv[1], cv[2]:cv[3], cv[4]:cv[5]] - - def get_vector(self, mask): - """Returns the vector field components arranged in a vector, specified by a mask. - - Parameters - ---------- - mask : :class:`~numpy.ndarray` (N=3, boolean) - Masks the pixels from which the components should be taken. - - Returns - ------- - vector : :class:`~numpy.ndarray` (N=1) - The vector containing vector field components of the specified pixels. - Order is: first all `x`-, then all `y`-, then all `z`-components. - - """ - self._log.debug('Calling get_vector') - if mask is not None: - return np.reshape([self.field[0][mask], - self.field[1][mask], - self.field[2][mask]], -1) - else: - return self.field_vec - - def set_vector(self, vector, mask=None): - """Set the field components of the masked pixels to the values specified by `vector`. - - Parameters - ---------- - mask : :class:`~numpy.ndarray` (N=3, boolean), optional - Masks the pixels from which the components should be taken. - vector : :class:`~numpy.ndarray` (N=1) - The vector containing vector field components of the specified pixels. - Order is: first all `x`-, then all `y-, then all `z`-components. - - Returns - ------- - None - - """ - self._log.debug('Calling set_vector') - assert np.size(vector) % 3 == 0, 'Vector has to contain all 3 components for every pixel!' - count = np.size(vector) // 3 - if mask is not None: - self.field[0][mask] = vector[:count] # x-component - self.field[1][mask] = vector[count:2 * count] # y-component - self.field[2][mask] = vector[2 * count:] # z-component - else: - self.field_vec = vector + field_crop = self.field[:, cv[0]:cv[1], cv[2]:cv[3], cv[4]:cv[5]] + return VectorData(self.a, field_crop) def flip(self, axis='x'): """Flip/mirror the vector field around the specified axis. @@ -659,6 +668,17 @@ class VectorData(FieldData): raise ValueError("Wrong input! 'x', 'y', 'z' allowed!") return VectorData(self.a, field_flip) + def rotate(self, angle, axis='z', reshape=False, **kwargs): + # TODO: Docstring! + # Define axes of rotation plane (axis 0 for vector component!) and rotate coord. system: + axes = {'x': (1, 2), 'y': (1, 3), 'z': (2, 3)}[axis] + field_coord_rot = rotate(self.field, angle, axes=axes, reshape=reshape, **kwargs) + # Rotate vectors inside the voxels (- signs determined by scipy rotate axes in firt step): + vec_dict = {'x': (-1, 0, 0), 'y': (0, 1, 0), 'z': (0, 0, -1)} + quat = Quaternion.from_axisangle(vec_dict[axis], np.deg2rad(angle)) + field_rot = quat.matrix.dot(field_coord_rot.reshape(3, -1)).reshape(field_coord_rot.shape) + return VectorData(self.a, field_rot) + def rot90(self, axis='x'): """Rotate the vector field 90° around the specified axis (right hand rotation). @@ -674,28 +694,65 @@ class VectorData(FieldData): """ self._log.debug('Calling rot90') - if axis == 'x': + if axis == 'x': # RotMatrix for 90°: [[1, 0, 0], [0, 0, -1], [0, 1, 0]] field_rot = np.zeros((3, self.dim[1], self.dim[0], self.dim[2])) for i in range(self.dim[2]): mag_x, mag_y, mag_z = self.field[:, :, :, i] - mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z) - field_rot[:, :, :, i] = np.array((mag_xrot, mag_zrot, -mag_yrot)) - elif axis == 'y': + mag_xrot = np.rot90(mag_x) + mag_yrot = np.rot90(mag_z) + mag_zrot = -np.rot90(mag_y) + field_rot[:, :, :, i] = np.array((mag_xrot, mag_yrot, mag_zrot)) + elif axis == 'y': # RotMatrix for 90°: [[0, 0, 1], [0, 1, 0], [-1, 0, 0]] field_rot = np.zeros((3, self.dim[2], self.dim[1], self.dim[0])) for i in range(self.dim[1]): mag_x, mag_y, mag_z = self.field[:, :, i, :] - mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z) - field_rot[:, :, i, :] = np.array((mag_zrot, mag_yrot, -mag_xrot)) - elif axis == 'z': + mag_xrot = np.rot90(mag_z) + mag_yrot = np.rot90(mag_y) + mag_zrot = -np.rot90(mag_x) + field_rot[:, :, i, :] = np.array((mag_xrot, mag_yrot, mag_zrot)) + elif axis == 'z': # RotMatrix for 90°: [[0, -1, 0], [1, 0, 0], [0, 0, 1]] field_rot = np.zeros((3, self.dim[0], self.dim[2], self.dim[1])) for i in range(self.dim[0]): mag_x, mag_y, mag_z = self.field[:, i, :, :] - mag_xrot, mag_yrot, mag_zrot = np.rot90(mag_x), np.rot90(mag_y), np.rot90(mag_z) - field_rot[:, i, :, :] = np.array((mag_yrot, -mag_xrot, mag_zrot)) + mag_xrot = np.rot90(mag_y) + mag_yrot = -np.rot90(mag_x) + mag_zrot = np.rot90(mag_z) + field_rot[:, i, :, :] = np.array((mag_xrot, mag_yrot, mag_zrot)) else: raise ValueError("Wrong input! 'x', 'y', 'z' allowed!") return VectorData(self.a, field_rot) + def roll(self, shift, axis='x'): # TODO: Make sure both classes have all manipulator methods! + """Rotate the scalar field 90° around the specified axis (right hand rotation). + + Parameters + ---------- + axis: {'x', 'y', 'z'}, optional + The axis around which the vector field is rotated. + + Returns + ------- + scalardata_rot: :class:`~.ScalarData` + A rotated copy of the :class:`~.ScalarData` object. + + """ + self._log.debug('Calling roll') + ax = {'x': 2, 'y': 1, 'z': 0}[axis] + mag_x_roll = np.roll(self.field[0, ...], shift, ax) + mag_y_roll = np.roll(self.field[1, ...], shift, ax) + mag_z_roll = np.roll(self.field[2, ...], shift, ax) + return VectorData(self.a, np.asarray((mag_x_roll, mag_y_roll, mag_z_roll))) + + def clip_amp(self, threshold): + # TODO: Docstring! + # TODO: 'mag' should be 'vec' everywhere!!! + mag_x, mag_y, mag_z = self.field + scaling = np.where(self.field_amp > threshold, threshold/self.field_amp, 1) + mag_x = mag_x * scaling + mag_y = mag_y * scaling + mag_z = mag_z * scaling + return VectorData(self.a, np.asarray((mag_x, mag_y, mag_z))) + def get_slice(self, ax_slice=None, proj_axis='z'): # TODO: return x y z instead of u v w (to color fields consistent with xyz!) """Extract a slice from the :class:`~.VectorData` object. @@ -1363,6 +1420,47 @@ class ScalarData(FieldData): 'Vector has to match field shape! {} {}'.format(c_vec.shape, np.prod(self.shape)) self.field = c_vec.reshape(self.dim) + def get_vector(self, mask): + """Returns the field as a vector, specified by a mask. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (N=3, boolean) + Masks the pixels from which the components should be taken. + + Returns + ------- + vector : :class:`~numpy.ndarray` (N=1) + The vector containing the field of the specified pixels. + + """ + self._log.debug('Calling get_vector') + if mask is not None: + return np.reshape(self.field[mask], -1) + else: + return self.field_vec + + def set_vector(self, vector, mask=None): + """Set the field components of the masked pixels to the values specified by `vector`. + + Parameters + ---------- + mask : :class:`~numpy.ndarray` (N=3, boolean), optional + Masks the pixels from which the components should be taken. + vector : :class:`~numpy.ndarray` (N=1) + The vector containing the field of the specified pixels. + + Returns + ------- + None + + """ + self._log.debug('Calling set_vector') + if mask is not None: + self.field[mask] = vector + else: + self.field_vec = vector + def scale_down(self, n=1): """Scale down the field distribution by averaging over two pixels along each axis. @@ -1383,15 +1481,18 @@ class ScalarData(FieldData): """ self._log.debug('Calling scale_down') assert n > 0 and isinstance(n, int), 'n must be a positive integer!' - self.a *= 2 ** n + a_new = self.a * 2 ** n + field_new = self.field for t in range(n): # Pad if necessary: - pz, py, px = self.dim[0] % 2, self.dim[1] % 2, self.dim[2] % 2 + dim = field_new.shape + pz, py, px = dim[0] % 2, dim[1] % 2, dim[2] % 2 if pz != 0 or py != 0 or px != 0: - self.field = np.pad(self.field, ((0, pz), (0, py), (0, px)), mode='constant') + field_new = np.pad(field_new, ((0, pz), (0, py), (0, px)), mode='constant') # Create coarser grid for the field: - shape_4d = (self.dim[0] // 2, 2, self.dim[1] // 2, 2, self.dim[2] // 2, 2) - self.field = self.field.reshape(shape_4d).mean(axis=(5, 3, 1)) + shape_3d = (dim[0] // 2, 2, dim[1] // 2, 2, dim[2] // 2, 2) + field_new = field_new.reshape(shape_3d).mean(axis=(5, 3, 1)) + return ScalarData(a_new, field_new) def scale_up(self, n=1, order=0): """Scale up the field distribution using spline interpolation of the requested order. @@ -1417,49 +1518,17 @@ class ScalarData(FieldData): assert n > 0 and isinstance(n, int), 'n must be a positive integer!' assert 5 > order >= 0 and isinstance(order, int), \ 'order must be a positive integer between 0 and 5!' - self.a /= 2 ** n - self.field = zoom(self.field, zoom=2 ** n, order=order) + a_new = self.a / 2 ** n + field_new = zoom(self.field, zoom=2 ** n, order=order) + return ScalarData(a_new, field_new) - def get_vector(self, mask): - """Returns the field as a vector, specified by a mask. - - Parameters - ---------- - mask : :class:`~numpy.ndarray` (N=3, boolean) - Masks the pixels from which the components should be taken. - - Returns - ------- - vector : :class:`~numpy.ndarray` (N=1) - The vector containing the field of the specified pixels. - - """ - self._log.debug('Calling get_vector') - if mask is not None: - return np.reshape(self.field[mask], -1) - else: - return self.field_vec - - def set_vector(self, vector, mask=None): - """Set the field components of the masked pixels to the values specified by `vector`. + # TODO: flip! - Parameters - ---------- - mask : :class:`~numpy.ndarray` (N=3, boolean), optional - Masks the pixels from which the components should be taken. - vector : :class:`~numpy.ndarray` (N=1) - The vector containing the field of the specified pixels. - - Returns - ------- - None - - """ - self._log.debug('Calling set_vector') - if mask is not None: - self.field[mask] = vector - else: - self.field_vec = vector + def rotate(self, angle, axis='z', reshape=False, **kwargs): + # TODO: Docstring! + axes = {'x': (0, 1), 'y': (0, 2), 'z': (1, 2)}[axis] # Defines axes of plane of rotation! + field_rot = rotate(self.field[...], angle, axes=axes, reshape=reshape, **kwargs) + return ScalarData(self.a, field_rot) def rot90(self, axis='x'): """Rotate the scalar field 90° around the specified axis (right hand rotation). @@ -1492,6 +1561,24 @@ class ScalarData(FieldData): raise ValueError("Wrong input! 'x', 'y', 'z' allowed!") return ScalarData(self.a, field_rot) + def roll(self, shift, axis='x'): # TODO: Make sure both classes have all manipulator methods! + """Rotate the scalar field 90° around the specified axis (right hand rotation). + + Parameters + ---------- + axis: {'x', 'y', 'z'}, optional + The axis around which the vector field is rotated. + + Returns + ------- + scalardata_rot: :class:`~.ScalarData` + A rotated copy of the :class:`~.ScalarData` object. + + """ + self._log.debug('Calling roll') + ax = {'x': 2, 'y': 1, 'z': 0}[axis] + return ScalarData(self.a, np.roll(self.field, shift, ax)) + def to_signal(self): """Convert :class:`~.ScalarData` data into a HyperSpy signal. diff --git a/pyramid/phasemap.py b/pyramid/phasemap.py index 0beb4f1db98d45ef18c3650159d3065d112a8640..f799f40c029d4ac70d2b8e0bbe6576804a1e31ee 100644 --- a/pyramid/phasemap.py +++ b/pyramid/phasemap.py @@ -262,6 +262,8 @@ class PhaseMap(object): return PhaseMap(self.a, self.phase.copy(), self.mask.copy(), self.confidence.copy()) + # TODO: ALL NOT IN PLACE!!! + def scale_down(self, n=1): """Scale down the phase map by averaging over two pixels along each axis. @@ -671,10 +673,14 @@ class PhaseMap(object): if symmetric: vmin, vmax = np.min([vmin, -0]), np.max([0, vmax]) # Ensure zero is present! limit = np.max(np.abs([vmin, vmax])) - start = (vmin + limit) / (2 * limit) - end = (vmax + limit) / (2 * limit) + # A symmetric colormap only has zero at white (the symmetry point) if the values + # of the corresponding mappable go from -limit to +limit (symmetric bounds)! + # Calculate the colors of this symmetric colormap for the range vmin to vmax: + start = 0.5 + vmin/(2*limit) # 0 for symmetric bounds, >0: unused colors at lower end! + end = 0.5 + vmax/(2*limit) # 1 for symmetric bounds, <1: unused colors at upper end! cmap_colors = cmap(np.linspace(start, end, 256)) - cmap = LinearSegmentedColormap.from_list('Symmetric', cmap_colors) + # Use calculated colors to create custom (asymmetric) colormap with white at zero: + cmap = LinearSegmentedColormap.from_list('custom', cmap_colors) # If no axis is specified, a new figure is created: if axis is None: fig = plt.figure(figsize=figsize) @@ -878,11 +884,14 @@ class PhaseMap(object): # Plot histogram: hist_axis = fig.add_subplot(1, 2, 1) vec = self.phase_vec * self.UNITDICT[unit] # Take units into consideration: + # TODO: This is bad! Discard low confidence values completely instead! Otherwise peak at 0! + # TODO: Set to nan and then discard with np.isnan()? vec *= np.where(self.confidence > 0.5, 1, 0).ravel() # Discard low confidence points! hist_axis.hist(vec, bins=bins, histtype='stepfilled', color='g') # Format histogram: x0, x1 = hist_axis.get_xlim() y0, y1 = hist_axis.get_ylim() + # TODO: Why the next line? Seems bad if you want to change things later´! hist_axis.set(aspect=np.abs(x1 - x0) / np.abs(y1 - y0) * 0.94) # Last value because cbar! fontsize = kwargs.get('fontsize', 16) hist_axis.tick_params(axis='both', which='major', labelsize=fontsize) diff --git a/pyramid/projector.py b/pyramid/projector.py index ae0bd98c072307a3dfec8f7ed3b124869d13c89b..99ff190ce9d2f52f3db64a669a6dae22a18f0ca7 100644 --- a/pyramid/projector.py +++ b/pyramid/projector.py @@ -411,6 +411,12 @@ class RotTiltProjector(Projector): self._log.debug('Calling __init__') self.rotation = rotation self.tilt = tilt + # Create tilt, rotation and combined quaternion, careful: Quaternion(w,x,y,z), not (z,y,x): + quat_z_n = Quaternion.from_axisangle((0, 0, 1), -rotation) # Rotate around z-axis + quat_x = Quaternion.from_axisangle((1, 0, 0), tilt) # Tilt around x-axis + quat_z_p = Quaternion.from_axisangle((0, 0, 1), rotation) # Rotate around z-axis + # Combined quaternion (first rotate around z, then tilt around x, rotate back around z): + quat = quat_z_n * quat_x * quat_z_p # p: positive rotation, m: negative rotation # Determine dimensions: dim_z, dim_y, dim_x = dim center = (dim_z / 2., dim_y / 2., dim_x / 2.) @@ -421,18 +427,17 @@ class RotTiltProjector(Projector): dim_v, dim_u = dim_uv # Creating coordinate list of all voxels: voxels = list(itertools.product(range(dim_z), range(dim_y), range(dim_x))) - # Calculate vectors to voxels relative to rotation center: - voxel_vecs = (np.asarray(voxels) + 0.5 - np.asarray(center)).T - # Create tilt, rotation and combined quaternion, careful: Quaternion(w,x,y,z), not (z,y,x): - quat_z_n = Quaternion.from_axisangle((0, 0, 1), -rotation) # Rotate around z-axis - quat_x = Quaternion.from_axisangle((1, 0, 0), tilt) # Tilt around x-axis - quat_z_p = Quaternion.from_axisangle((0, 0, 1), rotation) # Rotate around z-axis - # Combined quaternion (first rotate around z, then tilt around x, rotate back around z): - quat = quat_z_n * quat_x * quat_z_p # p: positive rotation, m: negative rotation - # Calculate impact positions on the projected pixel coordinate grid (flip because quat.): - impacts = np.flipud(quat.matrix[:2, :].dot(np.flipud(voxel_vecs))) # only care for x/y - impacts[1, :] += dim_u / 2. # Shift back to normal indices - impacts[0, :] += dim_v / 2. # Shift back to normal indices + # Calculate vectors to voxels relative to rotation center (each row contains (z, y, x)): + voxel_vecs = np.asarray(voxels) + 0.5 - np.asarray(center) + # Change to coordinate order of quaternions (x, y, z) instead of (z, y, x): + voxel_vecs = np.fliplr(voxel_vecs) + # Calculate impact positions (x, y) on the projected pixel coordinate grid (z is dropped): + impacts = quat.matrix[:2, :].dot(voxel_vecs.T).T # Only x and y row of matrix is used! + # Reverse transpose and change back coordinate order from (x, y) to (y, x)/(v, u): + impacts = np.fliplr(impacts.T) # Now contains rows with (v, u) entries as desired! + # First index: voxel, second index: 0 -> v, 1 -> u! + impacts[:, 1] += dim_u / 2. # Shift back to normal indices + impacts[:, 0] += dim_v / 2. # Shift back to normal indices # Calculate equivalence radius: R = (3 / (4 * np.pi)) ** (1 / 3.) # Prepare weight matrix calculation: @@ -446,7 +451,7 @@ class RotTiltProjector(Projector): for i, voxel in enumerate(tqdm(voxels, disable=disable, leave=False, desc='Set up projector')): column_index = voxel[0] * dim_y * dim_x + voxel[1] * dim_x + voxel[2] - remainder, impact = np.modf(impacts[:, i]) # split index of impact and remainder! + remainder, impact = np.modf(impacts[i, :]) # split index of impact and remainder! sub_pixel = (remainder * subcount).astype(dtype=np.int) # sub_pixel inside impact px. # Go over all influenced pixels (impact and neighbours, indices are [0, 1, 2]!): for px_ind in list(itertools.product(range(3), range(3))): diff --git a/pyramid/quaternion.py b/pyramid/quaternion.py index 568e2676af8bf7c117729f3e4a801fadec3d55cd..7aeb79574e7c579e7b527470440eff912f00c1ed 100644 --- a/pyramid/quaternion.py +++ b/pyramid/quaternion.py @@ -64,11 +64,18 @@ class Quaternion(object): self._log.debug('Calling __mul__') if isinstance(other, Quaternion): # Quaternion multiplication return self.dot_quat(self, other) - elif len(other) == 3: # vector multiplication + elif len(other) == 3: # vector multiplication (Caution: normalises!) q_vec = Quaternion((0,) + tuple(other)) q = self.dot_quat(self.dot_quat(self, q_vec), self.conj) return q.values[1:] + def _normalize(self): + self._log.debug('Calling _normalize') + mag2 = np.sum(n ** 2 for n in self.values) + if abs(mag2 - 1.0) > self.NORM_TOLERANCE: + mag = np.sqrt(mag2) + self.values = tuple(n / mag for n in self.values) + def dot_quat(self, q1, q2): """Multiply two :class:`~.Quaternion` objects to create a new one (always normalized). @@ -92,13 +99,6 @@ class Quaternion(object): z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 return Quaternion((w, x, y, z)) - def _normalize(self): - self._log.debug('Calling _normalize') - mag2 = np.sum(n ** 2 for n in self.values) - if abs(mag2 - 1.0) > self.NORM_TOLERANCE: - mag = np.sqrt(mag2) - self.values = tuple(n / mag for n in self.values) - @classmethod def from_axisangle(cls, vector, theta): """Create a quaternion from an axis-angle representation diff --git a/pyramid/utils/lorentz.py b/pyramid/utils/lorentz.py new file mode 100644 index 0000000000000000000000000000000000000000..22a2b7e230d70e9fa79df298a5d54f198d8afe57 --- /dev/null +++ b/pyramid/utils/lorentz.py @@ -0,0 +1,45 @@ +from scipy import constants +import numpy as np + +def electron_wavelength(ht): + """ + Returns electron wavelenght in nm. + Parameters + ---------- + ht : float + High tension in kV. + """ + momentum = 2 * constants.m_e * constants.elementary_charge * ht * 1000 * ( + 1 + constants.elementary_charge * ht * 1000 / (2 * constants.m_e * constants.c ** 2)) + wavelength = constants.h / np.sqrt(momentum) * 1e9 # in nm + return wavelength + + +def aberration_function(w, aber_dict, v_acc): + # TODO: Taken from Florian! Use dictionary! + w_cc = np.conjugate(w) + chi_i = {'C1': aber_dict['C1'] * w * w_cc / 2} + chi_sum = np.zeros_like(w) + for key in aber_dict.keys(): + chi_sum += chi_i[key] + return (2 * np.pi / electron_wavelength(v_acc)) * np.real(chi_sum) + + +def apply_aberrations(phasemap, aber_dict, v_acc): + + # Define complex scattering angle w + f_freq_v = np.fft.fftfreq(phasemap.dim_uv[0], phasemap.a) + f_freq_u = np.fft.fftfreq(phasemap.dim_uv[1], phasemap.a) + f_freq_mesh = np.meshgrid(f_freq_u, f_freq_v) + + w = f_freq_mesh[0] + 1j * f_freq_mesh[1] + w *= electron_wavelength(v_acc) + + chi = aberration_function(w, aber_dict, v_acc) + + wave = np.exp(1j * phasemap.phase) + wave_fft = np.fft.fftn(wave) / np.prod(phasemap.dim_uv) + wave_fft *= np.exp(-1j * chi) + wave = np.fft.ifftn(wave_fft) * np.prod(phasemap.dim_uv) + + return wave, chi diff --git a/tests/test_fielddata.py b/tests/test_fielddata.py index 8ca3b77443ec3bcd548b17bd5a036befc229c8f3..c68e614915743f2b950852e0d9404d2b08f4fddd 100644 --- a/tests/test_fielddata.py +++ b/tests/test_fielddata.py @@ -29,20 +29,20 @@ class TestCaseVectorData(unittest.TestCase): assert magdata_copy != self.magdata, 'Unexpected behaviour in copy()!' def test_scale_down(self): - self.magdata.scale_down() + magdata_test = self.magdata.scale_down() reference = 1 / 8. * np.ones((3, 2, 2, 2)) - assert_allclose(self.magdata.field, reference, + assert_allclose(magdata_test.field, reference, err_msg='Unexpected behavior in scale_down()!') - assert_allclose(self.magdata.a, 20, + assert_allclose(magdata_test.a, 20, err_msg='Unexpected behavior in scale_down()!') def test_scale_up(self): - self.magdata.scale_up() + magdata_test = self.magdata.scale_up() reference = np.zeros((3, 8, 8, 8)) reference[:, 2:6, 2:6, 2:6] = 1 - assert_allclose(self.magdata.field, reference, + assert_allclose(magdata_test.field, reference, err_msg='Unexpected behavior in scale_down()!') - assert_allclose(self.magdata.a, 5, + assert_allclose(magdata_test.a, 5, err_msg='Unexpected behavior in scale_down()!') def test_pad(self):