diff --git a/pyramid/colors.py b/pyramid/colors.py index a2fc3821b54bf6bd9d3d5af9e503d074dcc2b186..35a6ad75a9b30d7c3cc6f368fbd7c4fe8f18326c 100644 --- a/pyramid/colors.py +++ b/pyramid/colors.py @@ -37,6 +37,7 @@ import abc from . import plottools +# TODO: categorize colormaps as sequential, divergent, or cyclic! __all__ = ['Colormap3D', 'ColormapCubehelix', 'ColormapPerception', 'ColormapHLS', 'ColormapClassic', 'ColormapTransparent', 'cmaps', 'CMAP_CIRCULAR_DEFAULT', diff --git a/pyramid/costfunction.py b/pyramid/costfunction.py index c74c0c5d42eac51cf7483b696d9bc2d1e4bbca9e..e92d3c3d28937502082c901e3cca5f75da502b2f 100644 --- a/pyramid/costfunction.py +++ b/pyramid/costfunction.py @@ -47,7 +47,7 @@ class Costfunction(object): _log = logging.getLogger(__name__ + '.Costfunction') - def __init__(self, fwd_model, regularisator=None): + def __init__(self, fwd_model, regularisator=None, track_cost_iterations=10): self._log.debug('Calling __init__') self.fwd_model = fwd_model if regularisator is None: @@ -59,6 +59,10 @@ class Costfunction(object): self.n = self.fwd_model.n self.m = self.fwd_model.m self.Se_inv = self.fwd_model.Se_inv + self.chisq_m = [] + self.chisq_a = [] + self.track_cost_iterations = track_cost_iterations + self.cnt_hess_dot = 0 self._log.debug('Created ' + str(self)) def __repr__(self): @@ -72,13 +76,19 @@ class Costfunction(object): (self.fwd_model, self.fwd_model, self.regularisator) def __call__(self, x): - delta_y = self.fwd_model(x) - self.y - self.chisq_m = delta_y.dot(self.Se_inv.dot(delta_y)) - self.chisq_a = self.regularisator(x) - self.chisq = self.chisq_m + self.chisq_a + self.calculate_costs(x) + self.chisq = self.chisq_m[-1] + self.chisq_a[-1] return self.chisq + def calculate_costs(self, x): + # TODO: Docstring! + delta_y = self.fwd_model(x) - self.y + self.chisq_m.append(delta_y.dot(self.Se_inv.dot(delta_y))) + self.chisq_a.append(self.regularisator(x)) + + def init(self, x): + # TODO: Ask Jörn, why this exists! """Initialise the costfunction by calculating the different cost terms. Parameters @@ -131,10 +141,16 @@ class Costfunction(object): Product of the input `vector` with the Hessian matrix of the costfunction. """ + # TODO: Tracking better as decorator function? Useful for other things? + self.cnt_hess_dot += 1 # TODO: Ask Jörn if this belongs here or in CountingCostFunction! + if self.track_cost_iterations > 0 and self.cnt_hess_dot % self.track_cost_iterations == 0: + self.calculate_costs(vector) + #print(self.cnt_hess_dot, len(self.chisq_a)) # TODO:!!! return (2 * self.fwd_model.jac_T_dot(x, self.Se_inv.dot(self.fwd_model.jac_dot(x, vector))) + self.regularisator.hess_dot(x, vector)) def hess_diag(self, _): + # TODO: needed for preconditioner? """ Return the diagonal of the Hessian. Parameters diff --git a/pyramid/diagnostics.py b/pyramid/diagnostics.py index 047b45aff795205733770f68c59f246dfae28d04..25531dbda7b8acc86ae489f14f01c00613e7cb53 100644 --- a/pyramid/diagnostics.py +++ b/pyramid/diagnostics.py @@ -5,21 +5,30 @@ """This module provides the :class:`~.Diagnostics` class for the calculation of diagnostics of a specified costfunction for a fixed magnetization distribution.""" +import os + import logging +from pyramid.forwardmodel import ForwardModel +from pyramid.costfunction import Costfunction +from pyramid.regularisator import FirstOrderRegularisator from pyramid.fielddata import VectorData from pyramid.phasemap import PhaseMap +from pyramid import reconstruction +from pyramid import plottools import matplotlib.pyplot as plt from matplotlib import patches from matplotlib import patheffects from matplotlib.ticker import FuncFormatter +from matplotlib.colors import LogNorm import numpy as np import jutil -__all__ = ['Diagnostics', 'get_vector_field_errors'] +__all__ = ['Diagnostics', 'LCurve','get_vector_field_errors'] +# TODO: should be subpackage, distribute methods and classes to separate modules! class Diagnostics(object): """Class for calculating diagnostic properties of a specified costfunction. @@ -464,6 +473,120 @@ class Diagnostics(object): return plottools.format_axis(axis, hideaxes=True, scalebar=False) +class LCurve(object): + + # TODO: Docstring! + + # TODO: save magdata_rec! + + _log = logging.getLogger(__name__ + '.FieldData') + + def __init__(self, fwd_model, max_iter=0, verbose=False, save_dir=None): + self._log.debug('Calling __init__') + assert isinstance(fwd_model, ForwardModel), 'Input has to be a costfunction' + self.fwd_model = fwd_model + self.max_iter = max_iter + self.verbose = verbose + self.lams = [] + self.chisq_a = [] + self.chisq_m = [] + if save_dir is not None: + assert os.path.isdir(save_dir), 'save_dir has to be None or a valid directory!' + self.save_dir = save_dir + self._log.debug('Created ' + str(self)) + + def calculate(self, lam): + # TODO: Docstring! + if lam not in self.lams: + # Create new regularisator: # TODO: Not hardcoding FirstOrderRegularisator! + reg = FirstOrderRegularisator(self.fwd_model.data_set.mask, lam, + add_params=self.fwd_model.ramp.n) + cost = Costfunction(fwd_model=self.fwd_model, regularisator=reg) + # Reconstruct: + magdata_rec = reconstruction.optimize_linear(cost, max_iter=self.max_iter, + verbose=self.verbose) + # Save magdata_rec if necessary: + if self.save_dir is not None: + filename = 'magdata_rec_lam{:.0e}.hdf5'.format(lam) + magdata_rec.save(os.path.join(self.save_dir, filename), overwrite=True) + # Append new values: + self.lams.append(lam) + chisq_m, chisq_a = cost.chisq_m[-1], cost.chisq_a[-1] # TODO: is chisq_m list or not? + self.chisq_m.append(chisq_m) + self.chisq_a.append(chisq_a / lam) # TODO: lambda out of regularisator? + # Sort lists according to lambdas: + tuples = zip(*sorted(zip(self.lams, self.chisq_m, self.chisq_a))) + self.lams, self.chisq_m, self.chisq_a = (list(l) for l in tuples) + self._log.info(lam, ' --> m:', chisq_m, ' a:', chisq_a / lam) + return chisq_m, chisq_a / lam + # TODO: Implement else, return saved values! + # TODO: Make this work with batch, sort lambdas at the end! + # TODO: After sorting, calculate the CURVATURE for each lambda! (needs 3 points?) + # TODO: Use finite difference methods (forward/backward/central, depending on location)! + # TODO: Investigate points around highest curvature further. + # TODO: Make sure to update ALL curvatures and search for new best EVERYWHERE! + # TODO: Distinguish regions of the L-Curve. + + def calculate_auto(self, lam_start=1E-18, lam_end=1E5, online_axis=False): + # TODO: Docstring! + # Calculate new cost terms: + log_m_s, log_a_s = np.log10(self.calculate(lam_start)) + log_m_e, log_a_e = np.log10(self.calculate(lam_end)) + # Calculate new lambda: + log_lam_s, log_lam_e = np.log10(lam_start), np.log10(lam_end) + log_lam_new = np.mean((log_lam_s, log_lam_e)) # logarithmic mean to find middle on L! + sign_exp = np.floor(log_lam_new) + last_sign_digit = np.round(10 ** (log_lam_new - sign_exp)) + lam_new = last_sign_digit * 10 ** sign_exp + # Calculate cost terms for new lambda: + log_m_new, log_a_new = np.log10(self.calculate(lam_new)) + if online_axis: # Update plot if necessary: + self.plot(axis=online_axis) + from IPython import display + display.clear_output(wait=True) + display.display(plt.gcf()) + + # TODO: slope has to be normalised, scale of axes is not equal!!! + # TODO: get rid of right flank (do Not use right points with slope steeper than -45° + # Calculate distances from origin and find new interval: + dist_s, dist_e = np.hypot(log_m_s, log_a_s), np.hypot(log_m_e, log_a_e) + dist_new = np.hypot(log_m_new, log_a_new) + print(lam_start, lam_end, lam_new) + print(dist_s, dist_e, dist_new) + #if dist_new + + + + def plot(self, axis=None, figsize=None): + # TODO: Docstring! + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT + if axis is None: + self._log.debug('axis is None') + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1) + axis.set_yscale("log", nonposx='clip') + axis.set_xscale("log", nonposx='clip') + axis.plot(self.chisq_m, self.chisq_a, 'k-', linewidth=3, zorder=1) + sc = axis.scatter(x=self.chisq_m, y=self.chisq_a, c=self.lams, marker='o', s=100, + zorder=2, cmap='nipy_spectral', norm=LogNorm()) + plt.colorbar(mappable=sc, label='regularisation parameter $\lambda$') + #plottools.add_cbar(axis, mappable=sc, label='regularisation parameter $\lambda$') + #axis.get_xaxis().get_major_formatter().labelOnlyBase = False + axis.set_xlabel( + r'$\Vert\mathbf{F}(\mathbf{x})-\mathbf{y}\Vert_{\mathbf{S}_{\epsilon}^{-1}}^{2}$', + fontsize=22, labelpad=-5) + axis.set_ylabel(r'$\frac{1}{\lambda}\Vert\mathbf{x}\Vert_{\mathbf{S}_{a}^{-1}}^{2}$', + fontsize=22) + #axis.set_xlim(3E3, 2E5) + #axis.set_ylim(1E2, 1E9) + axis.xaxis.label.set_color('firebrick') + axis.yaxis.label.set_color('seagreen') + axis.tick_params(axis='both', which='major') + axis.grid() + # TODO: Don't plot the steep part on the right... + + def get_vector_field_errors(vector_data, vector_data_ref): """After Kemp et. al.: Analysis of noise-induced errors in vector-field electron tomography""" v, vr = vector_data.field, vector_data_ref.field @@ -479,3 +602,55 @@ def get_vector_field_errors(vector_data, vector_data_ref): rms_mag = np.sqrt(np.nansum((va - vra)**2) / np.nansum(vra**2)) # Return results as tuple: return rms_tot, rms_dir, rms_mag + + +# TODO: SVD as function for magnetic distributions! +# TODO: Plot only singular vectors, nullspace, or both! +# TODO: Jörn fragen, warum der Nullraum nur mit Maske eingeht!! +# from matplotlib.ticker import MultipleLocator +# n = 32 +# dim_uv = (n, n) +# mapper = pr.PhaseMapperRDFC(kernel=pr.Kernel(a=1, dim_uv=dim_uv)) +# mat = np.asarray([mapper.jac_dot(np.eye(1, 2*n**2, k=k).T) for k in range(2*n**2)]).T +# u, s, vh = sp.linalg.svd(mat, full_matrices=True) +# +# mag_hal = pr.magcreator.examples.smooth_vortex_disc(dim=(1,n,n)) +# phasemap = pr.utils.pm(mag_hal) +# #phasemap.mask = np.ones_like(phasemap.phase, dtype=bool) +# mag_hal_null, cost = pr.utils.reconstruction_2d_from_phasemap(phasemap, max_iter=5000, lam=1E-30) +# +# mag_hal_null.plot_quiver_field(scalebar=False, hideaxes=True, b_0=1) +# (mag_hal_null-mag_hal).plot_quiver_field(scalebar=False, hideaxes=True, b_0=1) +# pr.utils.pm(mag_hal_null).plot_phase() +# mag_hal_vec = mag_hal_null.field_vec[:2*n**2] # Discard z +# coeffs = vh.dot(mag_hal_vec) +# +# fig, axis = plt.subplots(1, 1) +# axis.plot(range(1, len(coeffs)+1), coeffs, 'bo', markersize=4) +# axis.axvline(x=n**2, color='k', linestyle='--') +# axis.set_xlim(0, 2*n**2) +# axis.set_ylim(-1.5, 2.6) +# axis.xaxis.set_major_locator(MultipleLocator(base=512)) +# axis.set_ylabel('Coefficient') +# axis.set_xlabel('Index of column vector') +# plt.text(200, 2.35, 'singular vectors', fontdict={'fontsize':18}) +# plt.text(200+n**2, 2.35, 'null space basis', fontdict={'fontsize':18}) +# +# coeffs_new = np.copy(coeffs) +# coeffs_new[:n**2] = 0 +# mag_hal_new = vh.T.dot(coeffs_new) +# mag_hal_range = mag_hal.copy() +# mag_hal_range.set_vector(np.concatenate((mag_hal_new, np.zeros(n**2)))) +# mag_hal_range.plot_quiver_field(b_0=1) +# pr.utils.pm(mag_hal_range).plot_phase() +# +# fig, axis = plt.subplots(1, 1) +# axis.plot(range(1, len(coeffs)+1), coeffs_new, 'bo', markersize=4) +# axis.axvline(x=n**2, color='k', linestyle='--') +# axis.set_xlim(0, 2*n**2) +# axis.set_ylim(-1.5, 2.6) +# axis.xaxis.set_major_locator(MultipleLocator(base=512)) +# axis.set_ylabel('Coefficient') +# axis.set_xlabel('Index of column vector') +# plt.text(200, 2.35, 'singular vectors', fontdict={'fontsize':18}) +# plt.text(200+n**2, 2.35, 'null space basis', fontdict={'fontsize':18}) diff --git a/pyramid/fielddata.py b/pyramid/fielddata.py index 22fc4be5298dbc53482a25874f6b27322644d8f1..d9db4e9e116deed9e48c3b2e895b7628ab18d936 100644 --- a/pyramid/fielddata.py +++ b/pyramid/fielddata.py @@ -14,9 +14,11 @@ import numpy as np from PIL import Image from matplotlib import patheffects from matplotlib import pyplot as plt -from matplotlib.colors import ListedColormap +from matplotlib.colors import ListedColormap, LinearSegmentedColormap from scipy.ndimage.interpolation import zoom +import cmocean + from . import colors from . import plottools @@ -86,7 +88,18 @@ class FieldData(object, metaclass=abc.ABCMeta): if len(self.shape) == 4: return np.sqrt(np.sum(self.field ** 2, axis=0)) else: - return self.field + return np.abs(self.field) + + @property + def field_vec(self): + """Vector containing the vector field distribution.""" + return np.reshape(self.field, -1) + + @field_vec.setter + def field_vec(self, mag_vec): + assert np.size(mag_vec) == np.prod(self.shape), \ + 'Vector has to match field shape! {} {}'.format(mag_vec.shape, np.prod(self.shape)) + self.field = mag_vec.reshape((3,) + self.dim) def __init__(self, a, field): self._log.debug('Calling __init__') @@ -839,7 +852,7 @@ class VectorData(FieldData): if figsize is None: figsize = plottools.FIGSIZE_DEFAULT assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ - 'Axis has to be x, y or z (as string).' + "Axis has to be 'x', 'y' or 'z'." if ax_slice is None: ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 # Extract slice and mask: @@ -986,7 +999,7 @@ class VectorData(FieldData): if figsize is None: figsize = plottools.FIGSIZE_DEFAULT assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ - 'Axis has to be x, y or z (as string).' + "Axis has to be 'x', 'y' or 'z'." if ax_slice is None: ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 # Extract slice and mask: @@ -1306,17 +1319,6 @@ class VectorData(FieldData): kwargs.setdefault('hideaxes', True) return plottools.format_axis(axis, hideaxes=True, scalebar=False) - @property - def field_vec(self): - """Vector containing the vector field distribution.""" - return np.reshape(self.field, -1) - - @field_vec.setter - def field_vec(self, mag_vec): - assert np.size(mag_vec) == np.prod(self.shape), \ - 'Vector has to match field shape! {} {}'.format(mag_vec.shape, np.prod(self.shape)) - self.field = mag_vec.reshape((3,) + self.dim) - class ScalarData(FieldData): """Class for storing scalar field data. @@ -1339,6 +1341,17 @@ class ScalarData(FieldData): """ _log = logging.getLogger(__name__ + '.ScalarData') + @property + def field_vec(self): + """Vector containing the scalar field distribution.""" + return np.reshape(self.field, -1) + + @field_vec.setter + def field_vec(self, c_vec): + assert np.size(c_vec) == np.prod(self.shape), \ + 'Vector has to match field shape! {} {}'.format(c_vec.shape, np.prod(self.shape)) + self.field = c_vec.reshape(self.dim) + def scale_down(self, n=1): """Scale down the field distribution by averaging over two pixels along each axis. @@ -1513,15 +1526,124 @@ class ScalarData(FieldData): from .file_io.io_scalardata import save_scalardata save_scalardata(self, filename, **kwargs) - @property - def field_vec(self): - """Vector containing the scalar field distribution.""" - return np.reshape(self.field, -1) + def get_slice(self, ax_slice=None, proj_axis='z'): + # TODO: Docstring! + """Extract a slice from the :class:`~.VectorData` object. - @field_vec.setter - def field_vec(self, c_vec): - assert np.size(c_vec) == np.prod(self.shape), \ - 'Vector has to match field shape! {} {}'.format(c_vec.shape, np.prod(self.shape)) - self.field = c_vec.reshape(self.dim) + Parameters + ---------- + proj_axis : {'z', 'y', 'x'}, optional + The axis, from which the slice is taken. The default is 'z'. + ax_slice : None or int, optional + The slice-index of the axis specified in `proj_axis`. Defaults to the center slice. + + Returns + ------- + u_mag, v_mag, w_mag, submask : :class:`~numpy.ndarray` (N=2) + The extracted vector field components in plane perpendicular to the `proj_axis` and + the perpendicular component. + + """ + self._log.debug('Calling get_slice') + # Find slice: + assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ + 'Axis has to be x, y or z (as string).' + if ax_slice is None: + ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 + if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice + self._log.debug('proj_axis == z') + scalar_slice = np.copy(self.field[ax_slice, ...]) + elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice + self._log.debug('proj_axis == y') + scalar_slice = np.copy(self.field[:, ax_slice, :]) + elif proj_axis == 'x': # Slice of the zy-plane with x = ax_slice + self._log.debug('proj_axis == x') + scalar_slice = np.copy(self.field[..., ax_slice]) + else: + raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) + return scalar_slice + + + def plot_field(self, proj_axis='z', ax_slice=None, show_mask=True, bgcolor=None, axis=None, + figsize=None, vmin=None, vmax=None, symmetric=False, cmap=None, cbar=True, + **kwargs): + # TODO: Docstring! + """Plot a slice of the scalar field as an imshow plot. + + Parameters + ---------- + proj_axis : {'z', 'y', 'x'}, optional + The axis, from which a slice is plotted. The default is 'z'. + ax_slice : int, optional + The slice-index of the axis specified in `proj_axis`. Is set to the center of + `proj_axis` if not specified. + show_mask: boolean + Default is True. Shows the outlines of the mask slice if available. + bgcolor: {'white', 'black'}, optional + Determines the background color of the plot. + axis : :class:`~matplotlib.axes.AxesSubplot`, optional + Axis on which the graph is plotted. Creates a new figure if none is specified. + figsize : tuple of floats (N=2) + Size of the plot figure. + + Returns + ------- + axis: :class:`~matplotlib.axes.AxesSubplot` + The axis on which the graph is plotted. + + Notes + ----- + Uses :func:`~.plottools.format_axis` at the end. According keywords can also be given here. + + """ + self._log.debug('Calling plot_field') + a = self.a + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT + assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \ + "Axis has to be 'x', 'y' or 'z'." + if ax_slice is None: + ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 + # Extract slice and mask: + field_slice = self.get_slice(ax_slice, proj_axis) + submask = np.where(np.abs(field_slice) > 0, True, False) + dim_uv = field_slice.shape + # If no axis is specified, a new figure is created: + if axis is None: + self._log.debug('axis is None') + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1) + tight = True + else: + tight = False + axis.set_aspect('equal') + # Configure colormap and fix center to zero if colormap is symmetric: + if cmap is None: + cmap = cmocean.cm.thermal + elif isinstance(cmap, str): # Get colormap if given as string: + cmap = plt.get_cmap(cmap) + if symmetric: # TODO: symmetric should be called divergent (see cmocean)! + vmin, vmax = np.min([vmin, -0]), np.max([0, vmax]) # Ensure zero is present! + limit = np.max(np.abs([vmin, vmax])) + start = (vmin + limit) / (2 * limit) + end = (vmax + limit) / (2 * limit) + cmap_colors = cmap(np.linspace(start, end, 256)) + cmap = LinearSegmentedColormap.from_list('Symmetric', cmap_colors) + # Plot the field: + im = axis.imshow(field_slice, cmap=cmap, vmin=vmin, vmax=vmax, origin='lower', + interpolation='none', extent=(0, dim_uv[1], 0, dim_uv[0])) + if show_mask and not np.all(submask): # Plot mask if desired and not trivial! + vv, uu = np.indices(dim_uv) + 0.5 + axis.contour(uu, vv, submask, levels=[0.5], colors='k', + linestyles='dotted', linewidths=2) + if bgcolor is not None: + pass#axis.set_facecolor(bgcolor) # TODO: Activate for matplotlib 2.0! + # Determine colorbar title: + cbar_mappable = None + if cbar: + cbar_mappable = im + # Return formatted axis: + return plottools.format_axis(axis, sampling=a, cbar_mappable=cbar_mappable, + tight_layout=tight, **kwargs) # TODO: Histogram plots for magnetisation (see thesis!) diff --git a/pyramid/forwardmodel.py b/pyramid/forwardmodel.py index 529160ce3c854dd98e55b777306424f4c54eef34..bbb2e2579e4d13826863ea2cd7361f647dfd9f8a 100644 --- a/pyramid/forwardmodel.py +++ b/pyramid/forwardmodel.py @@ -21,6 +21,19 @@ __all__ = ['ForwardModel', 'ForwardModelCharge', 'DistributedForwardModel'] # TODO: Ramp should be a forward model itself! Instead of hookpoints, each ForwardModel should # TODO: keep track of start and end in vector x (defaults: 0 and -1)! # TODO: Write CombinedForwardModel class! +# TODO: DataSet should be an argument, but should ONLY contain phasemaps! +# TODO: Maybe a list of PhaseMaps is even better and no DataSet class is needed? +# TODO: But what about the convenience functions? +# TODO: PhaseMaps should contain info about their projection direction! +# TODO: The ForwardModel should then setup the projectors accordingly from this info! +# TODO: Se_inv should be a class of its own and should be constructed by the ForwardModel init! +# TODO: Same goes for the mask! Se_inv and mask should contain all constructers and functions! +# TODO: This way, everything is set up and given AFTER all data are collected, because +# TODO: Se_inv and mask can't be set up before... +# TODO: Hook points belong to the forward models (or better the CombinedForwardModel) +# TODO: Maybe have one ForwardModel per image? (maybe not a good idea...?) +# TODO: Build factory convenience functions for constructing CombinedForwardModels! +# TODO: DistributedForwardModel and CombinedForwardModel could be the same thing?! class ForwardModel(object): """Class for mapping 3D magnetic distributions to 2D phase maps. @@ -82,6 +95,8 @@ class ForwardModel(object): return 'ForwardModel(data_set=%s)' % self.data_set def __call__(self, x): + # TODO: Have an extra forward model without the projector part? + # TODO: Which also corrects for the thickness? Would be nice! # Extract ramp parameters if necessary (x will be shortened!): x = self.ramp.extract_ramp_params(x) # Reset magdata and fill with vector: diff --git a/pyramid/kernel.py b/pyramid/kernel.py index dd6ef5a0d759b2cee8631d8f1bb22f1282d0cea9..d0a04aed3ce86bc32412f08584bf14d3ae61a6a7 100644 --- a/pyramid/kernel.py +++ b/pyramid/kernel.py @@ -11,7 +11,7 @@ import numpy as np from jutil import fft -__all__ = ['Kernel', 'PHI_0'] +__all__ = ['Kernel', 'PHI_0', 'KernelCharge'] PHI_0 = 2067.83 # magnetic flux in T*nm² H_BAR = 6.626E-34 # Planck constant in J*s diff --git a/pyramid/magcreator/magcreator.py b/pyramid/magcreator/magcreator.py index d74da1aa9c93ab68223d84811e5d9ee5e402d2c5..cd165040d34c2f7f2ba8d0940a0605f6d02f89ad 100644 --- a/pyramid/magcreator/magcreator.py +++ b/pyramid/magcreator/magcreator.py @@ -25,6 +25,7 @@ __all__ = ['create_mag_dist_homog', 'create_mag_dist_vortex', 'create_mag_dist_s 'create_mag_dist_smooth_vortex'] _log = logging.getLogger(__name__) +# TODO: generalise for scalar data? rename to fieldcreator? have subclasses vector, scalar? def create_mag_dist_homog(mag_shape, phi, theta=pi / 2): """Create a 3-dimensional magnetic distribution of a homogeneously magnetized object. diff --git a/pyramid/phasemap.py b/pyramid/phasemap.py index c666f4850584b9424241cdaaefa884a689b55960..72665d7ebb7c590ae772e27b00cee38821154b8a 100644 --- a/pyramid/phasemap.py +++ b/pyramid/phasemap.py @@ -659,14 +659,14 @@ class PhaseMap(object): vmin = np.min(phase_lim) if vmax is None: vmax = np.max(phase_lim) - # Configure colormap, to fix white to zero if colormap is symmetric: + # Configure colormap and fix white to zero if colormap is symmetric: + if cmap is None: + cmap = plt.get_cmap('RdBu') # TODO: use cmocean.cm.balance (flipped colours!) + # TODO: get default from "colors" or "plots" package + # TODO: make flexible, cmocean and matplotlib... + elif isinstance(cmap, str): # Get colormap if given as string: + cmap = plt.get_cmap(cmap) if symmetric: - if cmap is None: - cmap = plt.get_cmap('RdBu') # TODO: use cmocean.cm.balance (flipped colours!) - # TODO: get default from "colors" or "plots" package - # TODO: make flexible, cmocean and matplotlib... - elif isinstance(cmap, str): # Get colormap if given as string: - cmap = plt.get_cmap(cmap) vmin, vmax = np.min([vmin, -0]), np.max([0, vmax]) # Ensure zero is present! limit = np.max(np.abs([vmin, vmax])) start = (vmin + limit) / (2 * limit) diff --git a/pyramid/utils/pm.py b/pyramid/utils/pm.py index e438125cfcb6e669bde3d8c10d073497f321aa5b..c29379da6424b59ecd72b0b781db7a01f505ab18 100644 --- a/pyramid/utils/pm.py +++ b/pyramid/utils/pm.py @@ -13,6 +13,7 @@ from ..projector import RotTiltProjector, XTiltProjector, YTiltProjector, Simple __all__ = ['pm'] _log = logging.getLogger(__name__) +# TODO: rename magdata to vecdata everywhere! def pm(magdata, mode='z', b_0=1, mapper='RDFC', **kwargs): """Convenience function for fast magnetic phase mapping.