diff --git a/pyramid/__init__.py b/pyramid/__init__.py index e655bb04f334db0f59d8feb6c4fe3df79347e1b1..533f2bc8b163658a85f9bbc1c0175d1f03a74d61 100644 --- a/pyramid/__init__.py +++ b/pyramid/__init__.py @@ -46,7 +46,7 @@ from . import reconstruction from . import fieldconverter from . import magcreator from . import colors -from . import plottools +from . import plottools # TODO: colors and plottools into "plots" package (maybe with examples?) from . import utils from .costfunction import * from .dataset import * diff --git a/pyramid/analytic.py b/pyramid/analytic.py index 6e3131353b3887f8c5390f92e39fe8f1f2fe3eac..28b8abd717ccb43d00cda0b4ffe71391662f4db2 100644 --- a/pyramid/analytic.py +++ b/pyramid/analytic.py @@ -164,6 +164,7 @@ def phase_mag_sphere(dim, a, phi, center, radius, b_0=1): r = np.hypot(x - x0, y - y0) result = coeff * R ** 3 / (r + 1E-30) ** 2 * ( (y - y0) * np.cos(phi) - (x - x0) * np.sin(phi)) + # TODO: During testing: "RuntimeWarning: invalid value encountered in power": result *= np.where(r > R, 1, (1 - (1 - (r / R) ** 2) ** (3. / 2.))) return result diff --git a/pyramid/colors.py b/pyramid/colors.py index 6ff3918c5c70c03d2d14248bb6407a0ab992fa0f..a2fc3821b54bf6bd9d3d5af9e503d074dcc2b186 100644 --- a/pyramid/colors.py +++ b/pyramid/colors.py @@ -2,6 +2,9 @@ # Copyright 2014 by Forschungszentrum Juelich GmbH # Author: J. Caron # + +# TODO: Own small package? Use viscm (with colorspacious)? +# TODO: Also add cmoceaon "phase" colormap? Make optional (try importing, fall back to RdBu!) """This module provides a number of custom colormaps, which also have capabilities for 3D plotting. If this is the case, the :class:`~.Colormap3D` colormap class is a parent class. In `cmaps`, a number of specialised colormaps is available for convenience. If the default for circular colormaps @@ -15,7 +18,7 @@ import logging import matplotlib.pyplot as plt from matplotlib.ticker import FuncFormatter as FuncForm -from matplotlib.ticker import MaxNLocator +from matplotlib.ticker import MaxNLocator, IndexLocator, FixedLocator from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.axes_grid1 import ImageGrid @@ -32,6 +35,9 @@ import colorsys import abc +from . import plottools + + __all__ = ['Colormap3D', 'ColormapCubehelix', 'ColormapPerception', 'ColormapHLS', 'ColormapClassic', 'ColormapTransparent', 'cmaps', 'CMAP_CIRCULAR_DEFAULT', 'ColorspaceCIELab', 'ColorspaceCIELuv', 'ColorspaceCIExyY', 'ColorspaceYPbPr', @@ -65,7 +71,7 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta): ---------- vector: tuple (N=3) or :class:`~numpy.ndarray` Vector containing the x, y and z component, or a numpy array encompassing the - components as three lists.z-coordinate of the desired direction to encode. + components as three lists. Returns ------- @@ -83,7 +89,7 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta): # Calculate color deterministics: hue = phi / (2 * np.pi) lum = 1 - theta / np.pi - sat = r / r.max() + sat = r / (r.max() + 1E-30) # Calculate RGB from hue with colormap: rgba = np.asarray(self(hue)) r, g, b = rgba[..., 0], rgba[..., 1], rgba[..., 2] @@ -97,7 +103,8 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta): # Return RGB: return np.asarray(255 * np.stack((r, g, b), axis=-1), dtype=np.uint8) - def make_colorwheel(self, size=256, alpha=1): + def make_colorwheel(self, size=256, alpha=1, bgcolor=None): + # TODO: Strange arrows are not straight... self._log.debug('Calling make_colorwheel') # Construct the colorwheel: yy, xx = (np.indices((size, size)) - size/2 + 0.5) @@ -107,11 +114,20 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta): zz = np.where(rr <= size/2-2, 0, -1) # color inside, black outside aa = np.where(rr >= size/2-2, 255*alpha, 255).astype(dtype=np.uint8) rgba = np.dstack((self.rgb_from_vector(np.asarray((xx, yy, zz))), aa)) + if bgcolor: + if bgcolor == 'w': # TODO: Matplotlib get color tuples from string? + bgcolor = (1, 1, 1) + if len(bgcolor) == 3 and not isinstance(bgcolor, str): # Only you have tuple! + r, g, b = rgba[..., 0], rgba[..., 1], rgba[..., 2] + r = np.where(rr <= size / 2 - 2, r, 255*bgcolor[0]).astype(dtype=np.uint8) + g = np.where(rr <= size / 2 - 2, g, 255*bgcolor[1]).astype(dtype=np.uint8) + b = np.where(rr <= size / 2 - 2, b, 255*bgcolor[2]).astype(dtype=np.uint8) + rgba[..., 0], rgba[..., 1], rgba[..., 2] = r, g, b # Create color wheel: return Image.fromarray(rgba) - def plot_colorwheel(self, axis=None, size=512, alpha=1, arrows=False, figsize=(4, 4), - **kwargs): + def plot_colorwheel(self, axis=None, size=512, alpha=1, arrows=False, greyscale=False, + figsize=(4, 4), bgcolor=None, **kwargs): """Display a color wheel to illustrate the color coding of vector gradient directions. Parameters @@ -126,7 +142,9 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta): """ self._log.debug('Calling plot_colorwheel') # Construct the colorwheel: - color_wheel = self.make_colorwheel(size=size, alpha=alpha) + color_wheel = self.make_colorwheel(size=size, alpha=alpha, bgcolor=bgcolor) + if greyscale: + color_wheel = color_wheel.convert('L') # Plot the color wheel: if axis is None: fig = plt.figure(figsize=figsize) @@ -148,6 +166,12 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta): # Return axis: axis.xaxis.set_visible(False) axis.yaxis.set_visible(False) + for tic in axis.xaxis.get_major_ticks(): + tic.tick1On = tic.tick2On = False + tic.label1On = tic.label2On = False + for tic in axis.yaxis.get_major_ticks(): + tic.tick1On = tic.tick2On = False + tic.label1On = tic.label2On = False return axis @@ -281,7 +305,7 @@ class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D): super().__init__('cubehelix', cdict, N=256) self._log.debug('Created ' + str(self)) - def plot_helix(self, figsize=(8, 8)): + def plot_helix(self, figsize=None, **kwargs): """Display the RGB and luminance plots for the chosen cubehelix. Parameters @@ -295,6 +319,8 @@ class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D): """ self._log.debug('Calling plot_helix') + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT plt.figure(figsize=figsize) gs = gridspec.GridSpec(2, 1, height_ratios=[8, 1]) # Main plot: @@ -306,8 +332,10 @@ class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D): axis.set_xlim(0, self.nlev) axis.set_ylim(0, 1) axis.set_title('Cubehelix', fontsize=18) - axis.set_xlabel('color index', fontsize=15) - axis.set_ylabel('lightness / rgb', fontsize=15) + axis.set_xlabel('Color index', fontsize=15) + axis.set_ylabel('Brightness / RGB', fontsize=15) + axis.xaxis.set_major_locator(FixedLocator(locs=np.linspace(0, self.nlev, 5))) + axis.yaxis.set_major_locator(FixedLocator(locs=[0, 0.5, 1])) # Colorbar horizontal: caxis = plt.subplot(gs[1], sharex=axis) rgb = self(np.linspace(0, 1, 256))[None, ...] @@ -317,6 +345,7 @@ class ColormapCubehelix(colors.LinearSegmentedColormap, Colormap3D): caxis.imshow(im) plt.tick_params(axis='both', which='both', labelleft='off', labelbottom='off', left='off', right='off', top='on', bottom='on') + return plottools.format_axis(axis, scalebar=False, keep_labels=True, **kwargs) class ColormapPerception(colors.LinearSegmentedColormap, Colormap3D): @@ -478,8 +507,10 @@ class ColorspaceCIELab(object): # TODO: Superclass? self.clip = clip self._log.debug('Created ' + str(self)) - def plot(self, L=53.4, axis=None, figsize=(8, 8)): + def plot(self, L=53.4, axis=None, figsize=None, **kwargs): self._log.debug('Calling plot') + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT dim, ext = self.dim, self.extent # Create Lab colorspace: a = np.linspace(ext[0], ext[1], dim[1]) @@ -515,12 +546,13 @@ class ColorspaceCIELab(object): # TODO: Superclass? axis.set_xlabel('a', fontsize=15) axis.set_ylabel('b', fontsize=15) axis.set_title('CIELab (L = {:g})'.format(L), fontsize=18) + axis.xaxis.set_major_locator(FixedLocator(np.linspace(0, dim[1], 5))) + axis.yaxis.set_major_locator(FixedLocator(np.linspace(0, dim[0], 5))) fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) axis.xaxis.set_major_formatter(fx) fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) axis.yaxis.set_major_formatter(fy) - axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + plottools.format_axis(axis, scalebar=False, keep_labels=True, **kwargs) def plot_colormap(self, cmap, N=256, L='auto', figsize=None, cbar_lim=None, brightness=True, input_rec=None): @@ -614,8 +646,10 @@ class ColorspaceCIELuv(object): self.clip = clip self._log.debug('Created ' + str(self)) - def plot(self, L=53.4, axis=None, figsize=(8, 8)): + def plot(self, L=53.4, axis=None, figsize=None, **kwargs): self._log.debug('Calling plot') + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT dim, ext = self.dim, self.extent # Create Lab colorspace: u = np.linspace(ext[0], ext[1], dim[1]) @@ -651,12 +685,13 @@ class ColorspaceCIELuv(object): axis.set_xlabel('u', fontsize=15) axis.set_ylabel('v', fontsize=15) axis.set_title('CIELuv (L = {:g})'.format(L), fontsize=18) + axis.xaxis.set_major_locator(FixedLocator(np.linspace(0, dim[1], 5))) + axis.yaxis.set_major_locator(FixedLocator(np.linspace(0, dim[0], 5))) fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) axis.xaxis.set_major_formatter(fx) fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) axis.yaxis.set_major_formatter(fy) - axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + plottools.format_axis(axis, scalebar=False, keep_labels=True, **kwargs) def plot_colormap(self, cmap, N=256, L='auto', figsize=None, cbar_lim=None, brightness=True, input_rec=None): @@ -750,8 +785,10 @@ class ColorspaceCIExyY(object): self.clip = clip self._log.debug('Created ' + str(self)) - def plot(self, Y=0.214, axis=None, figsize=(8, 8)): + def plot(self, Y=0.214, axis=None, figsize=None, **kwargs): self._log.debug('Calling plot') + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT dim, ext = self.dim, self.extent # Create Lab colorspace: x = np.linspace(ext[0], ext[1], dim[1]) @@ -788,12 +825,13 @@ class ColorspaceCIExyY(object): axis.set_xlabel('x', fontsize=15) axis.set_ylabel('y', fontsize=15) axis.set_title('CIExyY (Y = {:g})'.format(Y), fontsize=18) + axis.xaxis.set_major_locator(FixedLocator(np.linspace(0, dim[1], 5))) + axis.yaxis.set_major_locator(FixedLocator(np.linspace(0, dim[0], 5))) fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) axis.xaxis.set_major_formatter(fx) fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) axis.yaxis.set_major_formatter(fy) - axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + plottools.format_axis(axis, scalebar=False, keep_labels=True, **kwargs) def plot_colormap(self, cmap, N=256, Y='auto', figsize=None, cbar_lim=None, brightness=True, input_rec=None): @@ -891,8 +929,10 @@ class ColorspaceYPbPr(object): self.clip = clip self._log.debug('Created ' + str(self)) - def plot(self, Y=0.5, axis=None, figsize=(8, 8)): + def plot(self, Y=0.5, axis=None, figsize=None, **kwargs): self._log.debug('Calling plot') + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT dim, ext = self.dim, self.extent # Create YPbPr colorspace: pb = np.linspace(ext[0], ext[1], dim[1]) @@ -925,14 +965,13 @@ class ColorspaceYPbPr(object): axis.set_xlabel('Pb', fontsize=15) axis.set_ylabel('Pr', fontsize=15) axis.set_title("Y'PbPr (Y' = {:g})".format(Y), fontsize=18) - fx = FuncForm( - lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) + axis.xaxis.set_major_locator(FixedLocator(np.linspace(0, dim[1], 5))) + axis.yaxis.set_major_locator(FixedLocator(np.linspace(0, dim[0], 5))) + fx = FuncForm(lambda x, pos: '{:.3g}'.format(x / dim[1] * (ext[1] - ext[0]) + ext[0])) axis.xaxis.set_major_formatter(fx) - fy = FuncForm( - lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) + fy = FuncForm(lambda y, pos: '{:.3g}'.format(y / dim[0] * (ext[3] - ext[2]) + ext[2])) axis.yaxis.set_major_formatter(fy) - axis.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=12, integer=True)) + plottools.format_axis(axis, scalebar=False, keep_labels=True, **kwargs) def plot_colormap(self, cmap, N=256, Y='auto', figsize=None, cbar_lim=None, brightness=True, input_rec=None): diff --git a/pyramid/dataset.py b/pyramid/dataset.py index 613817e4eeeb493bf9f3e316012d7a595fd6f633..408ac18973bf848dd2f754ec9254be4b1c0b3565 100644 --- a/pyramid/dataset.py +++ b/pyramid/dataset.py @@ -124,6 +124,7 @@ class DataSet(object): @property def phasemappers(self): + # TODO: get rid, only use phasemapper_dict!! """List of all PhaseMappers in the DataSet.""" return self._phasemappers @@ -166,6 +167,8 @@ class DataSet(object): assert phasemap.dim_uv == dim_uv, 'Projection dimensions (dim_uv) must match!' assert phasemap.a == self.a, 'Grid spacing must match!' # Create lookup key: + # TODO: Think again if phasemappers should be given as attribute (seems to be faulty + # TODO: currently... Also not very expensive, so keep outside? if phasemapper is not None: key = dim_uv # Create standard phasemapper, dim_uv is enough for identification! else: @@ -177,7 +180,7 @@ class DataSet(object): pass else: # Create new standard (RDFC) phasemapper: phasemapper = PhaseMapperRDFC(Kernel(self.a, dim_uv, self.b_0)) - self._phasemapper_dict[key] = phasemapper + self._phasemapper_dict[key] = phasemapper # Append everything to the lists (just contain pointers to objects!): self._phasemaps.append(phasemap) self._projectors.append(projector) @@ -215,7 +218,7 @@ class DataSet(object): # Reset the Se_inv matrix from phasemaps confidence matrices: self.set_Se_inv_diag_with_conf() - def create_phasemaps(self, magdata, difference=True, ramp=None): + def create_phasemaps(self, magdata, difference=False, ramp=None): """Create a list of phasemaps with the projectors in the dataset for a given :class:`~.VectorData` object. @@ -223,6 +226,13 @@ class DataSet(object): ---------- magdata : :class:`~.VectorData` Magnetic distribution to which the projectors of the dataset should be applied. + difference : bool, optional + If `True`, the phasemaps of the dataset are subtracted from the created ones to view + difference images. Default is False. + ramp : :class:`~.Ramp` + A ramp object, which can be specified to add a ramp to the generated phasemaps. + If `difference` is `True`, this can be interpreted as ramp correcting the phasemaps + saved in the dataset. Returns ------- @@ -238,8 +248,8 @@ class DataSet(object): if difference: phasemap -= self.phasemaps[i] if ramp is not None: - assert type(ramp) == Ramp, 'correct_ramp has to be a Ramp object!' - phasemap += ramp(index=i) + assert type(ramp) == Ramp, 'ramp has to be a Ramp object!' + phasemap += ramp(index=i) # Full formula: phasemap -= phasemap_dataset - ramp phasemap.mask = mag_proj.get_mask()[0, ...] phasemaps.append(phasemap) return phasemaps @@ -282,6 +292,7 @@ class DataSet(object): self.set_Se_inv_block_diag(cov_list) def set_3d_mask(self, mask_list=None, threshold=0.9): + # TODO: This function should be in a separate module and not here (maybe?)! """Set the 3D mask from a list of 2D masks. Parameters @@ -328,7 +339,7 @@ class DataSet(object): from .file_io.io_dataset import save_dataset save_dataset(self, filename, overwrite) - def plot_mask(self, **kwargs): + def plot_mask(self): """If it exists, display the 3D mask of the magnetization distribution. Returns @@ -338,7 +349,7 @@ class DataSet(object): """ self._log.debug('Calling plot_mask') if self.mask is not None: - return ScalarData(self.a, self.mask).plot_mask(**kwargs) + return ScalarData(self.a, self.mask).plot_mask() def plot_phasemaps(self, magdata=None, title='Phase Map', difference=False, ramp=None, **kwargs): @@ -352,6 +363,13 @@ class DataSet(object): title : string, optional The main part of the title of the plots. The default is 'Phase Map'. Additional projector info is appended to this. + difference : bool, optional + If `True`, the phasemaps of the dataset are subtracted from the created ones to view + difference images. Default is False. + ramp : :class:`~.Ramp` + A ramp object, which can be specified to add a ramp to the generated phasemaps. + If `magdata` is not given, this will instead just ramp correct the phasemaps saved + in the dataset. Returns ------- @@ -359,12 +377,17 @@ class DataSet(object): """ self._log.debug('Calling plot_phasemaps') - if magdata is not None: + if magdata is not None: # Plot phasemaps of the given magnetisation distribution: phasemaps = self.create_phasemaps(magdata, difference=difference, ramp=ramp) - else: + else: # Plot phasemaps saved in the DataSet (default): phasemaps = self.phasemaps - [phasemap.plot_phase('{} ({})'.format(title, self.projectors[i].get_info()), **kwargs) - for (i, phasemap) in enumerate(phasemaps)] + if ramp is not None: + for i, phasemap in enumerate(phasemaps): + assert type(ramp) == Ramp, 'ramp has to be a Ramp object!' + phasemap -= ramp(index=i) # Ramp correction + for (i, phasemap) in enumerate(phasemaps): + phasemap.plot_phase(note='{} ({})'.format(title, self.projectors[i].get_info()), + **kwargs) def plot_phasemaps_combined(self, magdata=None, title='Combined Plot', difference=False, ramp=None, **kwargs): @@ -377,9 +400,13 @@ class DataSet(object): given, the phasemaps in the dataset are used. title : string, optional The title of the plot. The default is 'Combined Plot'. - cmap : string, optional - The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. - The default is 'RdBu'. + difference : bool, optional + If `True`, the phasemaps of the dataset are subtracted from the created ones to view + difference images. Default is False. + ramp : :class:`~.Ramp` + A ramp object, which can be specified to add a ramp to the generated phasemaps. + If `magdata` is not given, this will instead just ramp correct the phasemaps saved + in the dataset. Returns ------- @@ -391,7 +418,11 @@ class DataSet(object): phasemaps = self.create_phasemaps(magdata, difference=difference, ramp=ramp) else: phasemaps = self.phasemaps + if ramp is not None: + for i, phasemap in enumerate(phasemaps): + assert type(ramp) == Ramp, 'ramp has to be a Ramp object!' + phasemap -= ramp(index=i) # Ramp correction for (i, phasemap) in enumerate(phasemaps): - phasemap.plot_combined('{} ({})'.format(title, self.projectors[i].get_info()), + phasemap.plot_combined(note='{} ({})'.format(title, self.projectors[i].get_info()), **kwargs) plt.show() diff --git a/pyramid/diagnostics.py b/pyramid/diagnostics.py index c714042635905a4380202548865382242e3c9bfa..047b45aff795205733770f68c59f246dfae28d04 100644 --- a/pyramid/diagnostics.py +++ b/pyramid/diagnostics.py @@ -247,6 +247,8 @@ class Diagnostics(object): lx, rx = _calc_lr(0) ly, ry = _calc_lr(1) lz, rz = _calc_lr(2) + + # TODO: Test if FWHM is really calculated with a in mind... didn't seem so... fwhm_x = (rx - lx) * a fwhm_y = (ry - ly) * a fwhm_z = (rz - lz) * a @@ -381,6 +383,87 @@ class Diagnostics(object): artist.set_path_effects([patheffects.withStroke(linewidth=4, foreground='k', alpha=0.5)]) + def plot_avrg_kern_field3d(self, pos=None, mask=True, ellipsoid=True, **kwargs): + avrg_kern_field = self.get_avrg_kern_field(pos) + avrg_kern_field.plot_mask(color=(1, 1, 1), opacity=0.15, labels=False, grid=False, + orientation=False) + avrg_kern_field.plot_quiver3d(**kwargs, new_fig=False) + fwhm = self.calculate_fwhm()[0] + from mayavi.sources.api import ParametricSurface + from mayavi.modules.api import Surface + from mayavi import mlab + engine = mlab.get_engine() + scene = engine.scenes[0] + scene.scene.disable_render = True # for speed # TODO: EVERYWHERE WITH MAYAVI! + # TODO: from enthought.mayavi import mlab + # TODO: f = mlab.figure() # returns the current scene + # TODO: engine = mlab.get_engine() # returns the running mayavi engine + source = ParametricSurface() + source.function = 'ellipsoid' + engine.add_source(source) + surface = Surface() + source.add_module(surface) + + actor = surface.actor # mayavi actor, actor.actor is tvtk actor + # actor.property.ambient = 1 # defaults to 0 for some reason, ah don't need it, turn off scalar visibility instead + actor.property.opacity = 0.5 + actor.property.color = (0, 0, 0) + actor.mapper.scalar_visibility = False # don't colour ellipses by their scalar indices into colour map + actor.property.backface_culling = True # gets rid of rendering artifact when opacity is < 1 + # actor.property.frontface_culling = True + actor.actor.orientation = [0, 0, 0] # in degrees + actor.actor.origin = (0, 0, 0) + actor.actor.position = (self.pos[1]+0.5, self.pos[2]+0.5, self.pos[3]+0.5) + a = self.magdata.a + actor.actor.scale = [0.5*fwhm[0]/a, 0.5*fwhm[1]/a, 0.5*fwhm[2]/a] + + #surface.append(surface) + + + scene.scene.disable_render = False # now turn it on # TODO: EVERYWHERE WITH MAYAVI! + + + def plot_avrg_kern_field_3d_to_2d(self, dim_uv=None, axis=None, figsize=None, high_res=False, + **kwargs): + # TODO: 3d_to_2d into plottools and make available for all 3D plots if possible! + import tempfile + from PIL import Image + import os + from . import plottools + from mayavi import mlab + if figsize is None: + figsize = plottools.FIGSIZE_DEFAULT + if axis is None: + self._log.debug('axis is None') + fig = plt.figure(figsize=figsize) + axis = fig.add_subplot(1, 1, 1) + axis.set_axis_bgcolor('gray') + kwargs.setdefault('labels', 'False') + #avrg_kern_field = self.get_avrg_kern_field() + #avrg_kern_field.plot_quiver3d(**kwargs) + self.plot_avrg_kern_field3d(**kwargs) + if high_res: # Use temp files: + tmpdir = tempfile.mkdtemp() + temp_path = os.path.join(tmpdir, 'temp.png') + try: + mlab.savefig(temp_path, size=(2000, 2000)) + imgmap = np.asarray(Image.open(temp_path)) + except Exception as e: + raise e + finally: + os.remove(temp_path) + os.rmdir(tmpdir) + else: # Use screenshot (returns array WITH alpha!): + imgmap = mlab.screenshot(mode='rgba', antialiased=True) + mlab.close(mlab.gcf()) + if dim_uv is None: + dim_uv = self.dim[1:] + axis.imshow(imgmap, extent=[0, dim_uv[0], 0, dim_uv[1]], origin='upper') + kwargs.setdefault('scalebar', False) + kwargs.setdefault('hideaxes', True) + return plottools.format_axis(axis, hideaxes=True, scalebar=False) + + def get_vector_field_errors(vector_data, vector_data_ref): """After Kemp et. al.: Analysis of noise-induced errors in vector-field electron tomography""" v, vr = vector_data.field, vector_data_ref.field diff --git a/pyramid/fielddata.py b/pyramid/fielddata.py index d4a57ce1765b7df1adf82f461856d476fb0bf971..f295b4840326367ff9e2c89aafda8dac337c1c33 100644 --- a/pyramid/fielddata.py +++ b/pyramid/fielddata.py @@ -228,7 +228,8 @@ class FieldData(object, metaclass=abc.ABCMeta): self._log.debug('Calling get_mask') return np.where(self.field_amp > threshold, True, False) - def plot_mask(self, title='Mask', threshold=0, **kwargs): + def plot_mask(self, title='Mask', threshold=0, grid=True, labels=True, + orientation=True, figsize=None, new_fig=True, **kwargs): """Plot the mask as a 3D-contour plot. Parameters @@ -246,20 +247,31 @@ class FieldData(object, metaclass=abc.ABCMeta): """ self._log.debug('Calling plot_mask') from mayavi import mlab - mlab.figure(size=(750, 700)) + if figsize is None: + figsize = (750, 700) + if new_fig: + mlab.figure(size=figsize, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.)) zzz, yyy, xxx = (np.indices(self.dim) + self.a / 2) zzz, yyy, xxx = zzz.T, yyy.T, xxx.T mask = self.get_mask(threshold=threshold).astype(int).T # Transpose because of VTK order! extent = np.ravel(list(zip((0, 0, 0), mask.shape))) cont = mlab.contour3d(xxx, yyy, zzz, mask, contours=[1], **kwargs) - mlab.outline(cont, extent=extent) - mlab.axes(cont, extent=extent) - mlab.title(title, height=0.95, size=0.35) - mlab.orientation_axes() - cont.scene.isometric_view() + if grid: + mlab.outline(cont, extent=extent) + if labels: + mlab.axes(cont, extent=extent) + mlab.title(title, height=0.95, size=0.35) + if orientation: + oa = mlab.orientation_axes() + oa.marker.set_viewport(0, 0, 0.4, 0.4) + mlab.draw() + engine = mlab.get_engine() + scene = engine.scenes[0] + scene.scene.isometric_view() return cont - def plot_contour3d(self, title='Field Distribution', contours=10, opacity=0.25, **kwargs): + def plot_contour3d(self, title='Field Distribution', contours=10, opacity=0.25, + new_fig=True, **kwargs): # TODO: new_fig or hold in every mayavi plot! """Plot the field as a 3D-contour plot. Parameters @@ -279,7 +291,8 @@ class FieldData(object, metaclass=abc.ABCMeta): """ self._log.debug('Calling plot_contour3d') from mayavi import mlab - mlab.figure(size=(750, 700)) + if new_fig: + mlab.figure(size=(750, 700)) zzz, yyy, xxx = (np.indices(self.dim) + self.a / 2) zzz, yyy, xxx = zzz.T, yyy.T, xxx.T field_amp = self.field_amp.T # Transpose because of VTK order! @@ -455,6 +468,9 @@ class VectorData(FieldData): """ _log = logging.getLogger(__name__ + '.VectorData') + def __getitem__(self, item): + return self.__class__(self.a, self.field[item]) + def scale_down(self, n=1): """Scale down the field distribution by averaging over two pixels along each axis. @@ -685,6 +701,7 @@ class VectorData(FieldData): return VectorData(self.a, field_rot) def get_slice(self, ax_slice=None, proj_axis='z'): + # TODO: return x y z instead of u v w (to color fields consistent with xyz!) """Extract a slice from the :class:`~.VectorData` object. Parameters @@ -719,9 +736,14 @@ class VectorData(FieldData): w_mag = np.copy(self.field[1][:, ax_slice, :]) # y-component elif proj_axis == 'x': # Slice of the zy-plane with x = ax_slice self._log.debug('proj_axis == x') - u_mag = np.swapaxes(np.copy(self.field[2][..., ax_slice]), 0, 1) # z-component - v_mag = np.swapaxes(np.copy(self.field[1][..., ax_slice]), 0, 1) # y-component - w_mag = np.swapaxes(np.copy(self.field[0][..., ax_slice]), 0, 1) # x-component + # TODO: Strange swapaxes, really necessary? Get rid EVERYWHERE if possible! + #u_mag = np.swapaxes(np.copy(self.field[2][..., ax_slice]), 0, 1) # z-component + #v_mag = np.swapaxes(np.copy(self.field[1][..., ax_slice]), 0, 1) # y-component + #w_mag = np.swapaxes(np.copy(self.field[0][..., ax_slice]), 0, 1) # x-component + # TODO: z should be special and always along y in 2D if possible!! + u_mag = np.copy(self.field[1][..., ax_slice]) # y-component + v_mag = np.copy(self.field[2][..., ax_slice]) # z-component + w_mag = np.copy(self.field[0][..., ax_slice]) # x-component else: raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) return u_mag, v_mag, w_mag @@ -776,7 +798,7 @@ class VectorData(FieldData): from .file_io.io_vectordata import save_vectordata save_vectordata(self, filename, **kwargs) - def plot_quiver(self, ar_dens=1, log=False, scaled=True, scale=1., b_0=None, # Only used here! + def plot_quiver(self, ar_dens=1, log=False, scaled=True, scale=1., b_0=None, qkey_unit='T', coloring='angle', cmap=None, # Used here and plot_streamlines! proj_axis='z', ax_slice=None, show_mask=True, bgcolor=None, axis=None, figsize=None, **kwargs): @@ -848,6 +870,7 @@ class VectorData(FieldData): u_mag = u_mag[::ar_dens, ::ar_dens] v_mag = v_mag[::ar_dens, ::ar_dens] amplitudes = np.hypot(u_mag, v_mag) + # TODO: Delete if only used in log: angles = np.angle(u_mag + 1j * v_mag, deg=True).tolist() # Calculate the arrow colors: if bgcolor is None: @@ -888,6 +911,7 @@ class VectorData(FieldData): tight = False axis.set_aspect('equal') # Take the logarithm of the arrows to clearly show directions (if specified): + # TODO: get rid of log!! (only problems...) if log and np.any(amplitudes): # If the slice is empty, skip! cutoff = 10 amp = np.round(amplitudes, decimals=cutoff) @@ -903,7 +927,7 @@ class VectorData(FieldData): v_mag /= amplitudes.max() + 1E-30 # Plot quiver: # TODO: quiver does not work with matplotlib 2.0! FIX! - quiv = axis.quiver(uu, vv, u_mag, v_mag, hue, cmap=cmap, clim=(0, 1), angles=angles, + quiv = axis.quiver(uu, vv, u_mag, v_mag, hue, cmap=cmap, clim=(0, 1), #angles=angles, pivot='middle', units='xy', scale_units='xy', scale=scale / ar_dens, minlength=0.05, width=1*ar_dens, headlength=2, headaxislength=2, headwidth=2, minshaft=2) @@ -915,7 +939,7 @@ class VectorData(FieldData): else: cbar_mappable, cbar_label = None, None # Change background color: - axis.set_axis_bgcolor(bgcolor) + axis.set_facecolor(bgcolor) # Show mask: if show_mask and not np.all(submask): # Plot mask if desired and not trivial! vv, uu = np.indices(dim_uv) + 0.5 # shift to center of pixel @@ -924,7 +948,7 @@ class VectorData(FieldData): linestyles='dotted', linewidths=2) # Plot quiverkey if B_0 is specified): if b_0 and not log: # The angles needed for log would break the quiverkey! - label = '{:.3g} T'.format(amplitudes.max() * b_0) + label = '{:.3g} {}'.format(amplitudes.max() * b_0, qkey_unit) quiv.angles = 'uv' # With a list of angles, the quiverkey would break! stroke = plottools.STROKE_DEFAULT txtcolor = 'w' if stroke == 'k' else 'k' @@ -985,6 +1009,17 @@ class VectorData(FieldData): # Extract slice and mask: u_mag, v_mag, w_mag = self.get_slice(ax_slice, proj_axis) submask = np.where(np.hypot(u_mag, v_mag) > 0, True, False) + # TODO: Before you fix this, fix get_slice!!! (should be easy...) + # TODO: return x y z instead of u v w (to color fields consistent with xyz!) + # TODO: maybe best to get colors of slice separately from u and v components!!! + # TODO: vector_to_rgb does already exist!! + if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice + u_mag, v_mag, w_mag = u_mag, v_mag, w_mag + elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice + u_mag, v_mag, w_mag = u_mag, w_mag, v_mag + elif proj_axis == 'x': # Slice of the zy-plane with x = ax_slice + u_mag, v_mag, w_mag = w_mag, u_mag, v_mag + # TODO: END! # If no axis is specified, a new figure is created: if axis is None: self._log.debug('axis is None') @@ -1007,7 +1042,7 @@ class VectorData(FieldData): extent=(0, dim_uv[1], 0, dim_uv[0])) # Change background color: if bgcolor is not None: - axis.set_axis_bgcolor(bgcolor) + axis.set_facecolor(bgcolor) # Show mask: if show_mask and not np.all(submask): # Plot mask if desired and not trivial! vv, uu = np.indices(dim_uv) + 0.5 # shift to center of pixel @@ -1164,7 +1199,8 @@ class VectorData(FieldData): def plot_quiver3d(self, title='Vector Field', limit=None, cmap='jet', mode='2darrow', coloring='angle', ar_dens=1, opacity=1.0, grid=True, labels=True, - orientation=True, figsize=None): + orientation=True, figsize=None, new_fig=True, view='isometric', + position=None, bgcolor=(0.5, 0.5, 0.5)): """Plot the vector field as 3D-vectors in a quiverplot. Parameters @@ -1198,7 +1234,7 @@ class VectorData(FieldData): limit = np.max(np.nan_to_num(self.field_amp)) ad = ar_dens # Create points and vector components as lists: - zzz, yyy, xxx = (np.indices(self.dim) + self.a / 2) + zzz, yyy, xxx = (np.indices(self.dim) + 1 / 2) zzz = zzz[::ad, ::ad, ::ad].ravel() yyy = yyy[::ad, ::ad, ::ad].ravel() xxx = xxx[::ad, ::ad, ::ad].ravel() @@ -1208,7 +1244,8 @@ class VectorData(FieldData): # Plot them as vectors: if figsize is None: figsize = (750, 700) - mlab.figure(size=figsize, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.)) + if new_fig: + mlab.figure(size=figsize, bgcolor=bgcolor, fgcolor=(0., 0., 0.)) extent = np.ravel(list(zip((0, 0, 0), (self.dim[2], self.dim[1], self.dim[0])))) if coloring == 'angle': # Encodes the full angle via colorwheel and saturation: self._log.debug('Encoding full 3D angles') @@ -1241,10 +1278,21 @@ class VectorData(FieldData): mlab.draw() engine = mlab.get_engine() scene = engine.scenes[0] - scene.scene.isometric_view() + if view == 'isometric': + scene.scene.isometric_view() + elif view == 'x_plus_view': + scene.scene.x_plus_view() + elif view == 'y_plus_view': + scene.scene.y_plus_view() + if position: + scene.scene.camera.position = position return vecs def plot_quiver3d_to_2d(self, dim_uv=None, axis=None, figsize=None, high_res=False, **kwargs): + # TODO: into plottools and make available for all 3D plots if possible! + kwargs.setdefault('labels', False) + kwargs.setdefault('orientation', False) + kwargs.setdefault('bgcolor', (0.7, 0.7, 0.7)) from mayavi import mlab if figsize is None: figsize = plottools.FIGSIZE_DEFAULT @@ -1252,8 +1300,7 @@ class VectorData(FieldData): self._log.debug('axis is None') fig = plt.figure(figsize=figsize) axis = fig.add_subplot(1, 1, 1) - axis.set_axis_bgcolor('gray') - kwargs.setdefault('labels', 'False') + axis.set_axis_bgcolor(kwargs['bgcolor']) self.plot_quiver3d(figsize=(800, 800), **kwargs) if high_res: # Use temp files: tmpdir = tempfile.mkdtemp() @@ -1471,3 +1518,5 @@ class ScalarData(FieldData): """ from .file_io.io_scalardata import save_scalardata save_scalardata(self, filename, **kwargs) + +# TODO: Histogram plots for magnetisation (see thesis!) diff --git a/pyramid/file_io/io_dataset.py b/pyramid/file_io/io_dataset.py index 3c4be310df07f02b8df06a0c7672c2a74a427447..041ed38c5b435382ce63d788984030a3086045aa 100644 --- a/pyramid/file_io/io_dataset.py +++ b/pyramid/file_io/io_dataset.py @@ -12,6 +12,8 @@ import h5py import numpy as np +import scipy as sp + from ..dataset import DataSet from ..file_io.io_projector import load_projector from ..file_io.io_phasemap import load_phasemap @@ -38,7 +40,7 @@ def save_dataset(dataset, filename, overwrite=True): if dataset.mask is not None: f.create_dataset('mask', data=dataset.mask) if dataset.Se_inv is not None: - f.create_dataset('Se_inv', data=dataset.Se_inv) + f.create_dataset('Se_inv', data=dataset.Se_inv.diagonal()) # Save only diagonal! # PhaseMaps and Projectors: for i, projector in enumerate(dataset.projectors): projector_name = 'projector_{}_{}_{}{}'.format(name, i, projector.get_info(), extension) @@ -82,7 +84,8 @@ def load_dataset(filename): dim = f.attrs.get('dim') b_0 = f.attrs.get('b_0') mask = np.copy(f.get('mask', None)) - Se_inv = np.copy(f.get('Se_inv', None)) + Se_inv_diag = np.copy(f.get('Se_inv', None)) + Se_inv = sp.sparse.diags(Se_inv_diag).tocsr() dataset = DataSet(a, dim, b_0, mask, Se_inv) # Projectors: projectors = [] diff --git a/pyramid/file_io/io_phasemap.py b/pyramid/file_io/io_phasemap.py index 9a98ae5845740f4c980ed43192e67201576d16a8..bb5af68f66499063f3b8165967d64d74f237d3f0 100644 --- a/pyramid/file_io/io_phasemap.py +++ b/pyramid/file_io/io_phasemap.py @@ -18,7 +18,8 @@ __all__ = ['load_phasemap'] _log = logging.getLogger(__name__) -def load_phasemap(filename, mask=None, confidence=None, a=None, **kwargs): +def load_phasemap(filename, mask=None, confidence=None, a=None, threshold=0, + print_mask_limits=False, **kwargs): """Load supported file into a :class:`~pyramid.phasemap.PhaseMap` instance. The function loads the file according to the extension: @@ -59,7 +60,11 @@ def load_phasemap(filename, mask=None, confidence=None, a=None, **kwargs): phasemap = _load(filename, as_phasemap=True, a=a, **kwargs) if mask is not None: filemask, kwargs_mask = _parse_add_param(mask) - phasemap.mask = _load(filemask, **kwargs_mask) + mask_raw = _load(filemask, **kwargs_mask) + if print_mask_limits: + print('[Mask] min:', mask_raw.min(), 'max:', mask_raw.max(), 'threshold:', threshold) + mask = np.where(mask_raw > threshold, True, False) + phasemap.mask = mask if confidence is not None: fileconf, kwargs_conf = _parse_add_param(confidence) phasemap.confidence = _load(fileconf, **kwargs_conf) diff --git a/pyramid/file_io/io_vectordata.py b/pyramid/file_io/io_vectordata.py index ca02f828116c823cee34f83ce2448776b13fc9d5..1c2daa11f87354f6a25d347027c4488af7c31cc0 100644 --- a/pyramid/file_io/io_vectordata.py +++ b/pyramid/file_io/io_vectordata.py @@ -194,8 +194,8 @@ def _save_to_llg(vectordata, filename): data = np.array([xx, yy, zz, x_vec, y_vec, z_vec]).T # Save data to file: with open(filename, 'w') as mag_file: - mag_file.write('LLGFileCreator: %s\n'.format(filename)) - mag_file.write(' %d %d %d\n'.format(np.asarray(dim)[::-1])) + mag_file.write('LLGFileCreator: {:s}\n'.format(filename)) + mag_file.write(' {:d} {:d} {:d}\n'.format(*dim)) mag_file.writelines('\n'.join(' '.join('{:7.6e}'.format(cell) for cell in row) for row in data)) diff --git a/pyramid/magcreator/examples.py b/pyramid/magcreator/examples.py index 68cf9517775049f3342985b5dea85759a1c050c7..987a7cba00bdfa92f5253c85479aef7912831ec3 100644 --- a/pyramid/magcreator/examples.py +++ b/pyramid/magcreator/examples.py @@ -73,7 +73,7 @@ def homog_slab(a=1., dim=(32, 32, 32), center=None, width=None, phi=np.pi/4, the def homog_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, - phi=np.pi / 4, theta=np.pi / 4): + phi=np.pi / 4, theta=np.pi / 4, axis='z'): """Create homogeneous disc magnetisation distribution.""" _log.debug('Calling homog_disc') if center is None: @@ -82,7 +82,7 @@ def homog_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, radius = dim[2] // 4 if height is None: height = np.max((dim[0] // 2, 1)) - mag_shape = shapes.disc(dim, center, radius, height) + mag_shape = shapes.disc(dim, center, radius, height, axis) return VectorData(a, mc.create_mag_dist_homog(mag_shape, phi, theta)) @@ -183,7 +183,7 @@ def vortex_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, a radius = dim[2] // 4 if height is None: height = np.max((dim[0] // 2, 1)) - mag_shape = shapes.disc(dim, center, radius, height, axis) + mag_shape = shapes.disc(dim, center, radius, height, axis.replace('-', '')) magnitude = mc.create_mag_dist_vortex(mag_shape, center, axis) return VectorData(a, magnitude) @@ -236,7 +236,7 @@ def vortex_horseshoe(a=1., dim=(16, 64, 64), center=None, radius_core=None, def smooth_vortex_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height=None, axis='z', - vortex_radius=None): + core_r=0, vortex_radius=None): """Create smooth vortex disc magnetisation distribution.""" _log.debug('Calling vortex_disc') if center is None: @@ -247,8 +247,8 @@ def smooth_vortex_disc(a=1., dim=(32, 32, 32), center=None, radius=None, height= height = np.max((dim[0] // 2, 1)) if vortex_radius is None: vortex_radius = radius // 2 - mag_shape = shapes.disc(dim, center, radius, height, axis) - magnitude = mc.create_mag_dist_smooth_vortex(mag_shape, center, vortex_radius, axis) + mag_shape = shapes.disc(dim, center, radius, height, axis.replace('-', '')) # same for +/- + magnitude = mc.create_mag_dist_smooth_vortex(mag_shape, center, vortex_radius, core_r, axis) return VectorData(a, magnitude) diff --git a/pyramid/magcreator/magcreator.py b/pyramid/magcreator/magcreator.py index 1168b34c23fdc02cdf64ea5dcb864f02d1594d5a..d74da1aa9c93ab68223d84811e5d9ee5e402d2c5 100644 --- a/pyramid/magcreator/magcreator.py +++ b/pyramid/magcreator/magcreator.py @@ -129,7 +129,7 @@ def create_mag_dist_vortex(mag_shape, center=None, axis='z'): return np.array([x_mag, y_mag, z_mag]) -def create_mag_dist_smooth_vortex(mag_shape, center=None, vort_r=None, axis='z'): +def create_mag_dist_smooth_vortex(mag_shape, center=None, vort_r=None, core_r=0, axis='z'): """Create a 3-dimensional magnetic distribution of a homogeneous magnetized object. Parameters @@ -160,7 +160,8 @@ def create_mag_dist_smooth_vortex(mag_shape, center=None, vort_r=None, axis='z') def core(r): """Function describing the smooth vortex core.""" - return 1 - 2/np.pi * np.arcsin(np.tanh(np.pi*r/vort_r)) + r_clip = np.clip(r - core_r, a_min=0, a_max=None) + return 1 - 2/np.pi * np.arcsin(np.tanh(np.pi*r_clip/vort_r)) _log.debug('Calling create_mag_dist_vortex') dim = mag_shape.shape diff --git a/pyramid/phasemap.py b/pyramid/phasemap.py index c443c782c7c12caebf2178749062218adbc968e9..222cc96ef8d1dc52b834faf3efe3f60e0dc18b1b 100644 --- a/pyramid/phasemap.py +++ b/pyramid/phasemap.py @@ -19,6 +19,8 @@ from matplotlib.ticker import MaxNLocator from mpl_toolkits.mplot3d import Axes3D +import cmocean + from scipy import ndimage import warnings @@ -134,7 +136,7 @@ class PhaseMap(object): assert confidence.shape == self.phase.shape, \ 'Confidence and phase dimensions must match!' confidence = confidence.astype(dtype=np.float32) - confidence /= confidence.max() # Normalise! + confidence /= confidence.max() + 1E-30 # Normalise! else: confidence = np.ones_like(self.phase, dtype=np.float32) self._confidence = confidence @@ -170,7 +172,7 @@ class PhaseMap(object): assert other.phase.shape == self.dim_uv, \ 'Added field has to have the same dimensions!' mask_comb = np.logical_or(self.mask, other.mask) # masks combine - conf_comb = (self.confidence + other.confidence) / 2 # confidence averaged! + conf_comb = np.minimum(self.confidence, other.confidence) # use minimum confidence! return PhaseMap(self.a, self.phase + other.phase, mask_comb, conf_comb) else: # other is a Number self._log.debug('Adding an offset') @@ -660,7 +662,9 @@ class PhaseMap(object): # Configure colormap, to fix white to zero if colormap is symmetric: if symmetric: if cmap is None: - cmap = plt.get_cmap('RdBu') + cmap = plt.get_cmap('RdBu') # TODO: use cmocean.cm.balance (flipped colours!) + # TODO: get default from "colors" or "plots" package + # TODO: make flexible, cmocean and matplotlib... elif isinstance(cmap, str): # Get colormap if given as string: cmap = plt.get_cmap(cmap) vmin, vmax = np.min([vmin, -0]), np.max([0, vmax]) # Ensure zero is present! diff --git a/pyramid/phasemapper.py b/pyramid/phasemapper.py index ed3f9071ecd8da65f97651795d6f66b860ed8f58..9393b3f95aaa977b1a25e41d6caf70b2e484d18f 100644 --- a/pyramid/phasemapper.py +++ b/pyramid/phasemapper.py @@ -411,7 +411,7 @@ class PhaseMapperMIP(PhaseMapper): class PhaseMapperCharge(PhaseMapper): - """""" + """""" # TODO: Write Docstring! def __init__(self, a, dim_uv, electrode_vec, v_acc=300000): self._log.debug('Calling __init__') diff --git a/pyramid/plottools.py b/pyramid/plottools.py index e656283df493f867f7a9d70b431ab551a4a49949..896feb64d4e21ea8e43ae7649b55fefdf06003ed 100644 --- a/pyramid/plottools.py +++ b/pyramid/plottools.py @@ -14,6 +14,7 @@ from matplotlib.patches import Rectangle from matplotlib import patheffects from matplotlib.ticker import MaxNLocator, FuncFormatter +from mpl_toolkits.axes_grid.inset_locator import inset_axes from mpl_toolkits.axes_grid1 import make_axes_locatable import warnings @@ -23,7 +24,7 @@ from . import colors __all__ = ['format_axis', 'pretty_plots', 'add_scalebar', 'add_annotation', 'add_colorwheel', 'add_cbar'] -FIGSIZE_DEFAULT = (6.7, 5) +FIGSIZE_DEFAULT = (6.7, 5) # TODO: Apparently does not fit as well as before... FONTSIZE_DEFAULT = 20 STROKE_DEFAULT = None @@ -70,6 +71,7 @@ def add_scalebar(axis, sampling=1, fontsize=None, stroke=None): The box containing the scalebar. """ + # TODO: Background (black outline) not visible... if fontsize is None: fontsize = FONTSIZE_DEFAULT if stroke is None: @@ -166,7 +168,8 @@ def add_colorwheel(axis): inset_axes = inset_axes(axis, width=0.75, height=0.75, loc=1) inset_axes.axis('off') cmap = colors.CMAP_CIRCULAR_DEFAULT - return cmap.plot_colorwheel(size=100, axis=inset_axes, alpha=0, arrows=True) + bgcolor = axis.get_facecolor() + return cmap.plot_colorwheel(size=100, axis=inset_axes, alpha=0, bgcolor=bgcolor, arrows=True) def add_cbar(axis, mappable, label='', fontsize=None): @@ -219,10 +222,37 @@ def add_cbar(axis, mappable, label='', fontsize=None): plt.sca(axis) return cbar +def add_coords(axis, coords=('x', 'y')): + ins_ax = inset_axes(axis, width="5%", height="5%", loc=3, borderpad=2.2) + if coords == 3: + coords = ('x', 'y', 'z') + elif coords == 2: + coords = ('x', 'y') + if len(coords) == 3: + ins_ax.arrow(0.5, 0.45, -1.05, -0.75, fc="k", ec="k", + head_width=0.2, head_length=0.3, linewidth=3, clip_on=False) + ins_ax.arrow(0.5, 0.45, 0.96, -0.75, fc="k", ec="k", + head_width=0.2, head_length=0.3, linewidth=3, clip_on=False) + ins_ax.arrow(0.5, 0.45, 0, 1.35, fc="k", ec="k", + head_width=0.2, head_length=0.3, linewidth=3, clip_on=False) + ins_ax.annotate(coords[0], xy=(0, 0), xytext=(-0.9, 0), fontsize=20, clip_on=False) + ins_ax.annotate(coords[1], xy=(0, 0), xytext=(1.4, 0.1), fontsize=20, clip_on=False) + ins_ax.annotate(coords[2], xy=(0, 0), xytext=(0.7, 1.5), fontsize=20, clip_on=False) + elif len(coords) == 2: + ins_ax.arrow(-0.5, -0.5, 1.5, 0, fc="k", ec="k", + head_width=0.2, head_length=0.3, linewidth=3, clip_on=False) + ins_ax.arrow(-0.5, -0.5, 0, 1.5, fc="k", ec="k", + head_width=0.2, head_length=0.3, linewidth=3, clip_on=False) + ins_ax.annotate(coords[0], xy=(0, 0), xytext=(1.3, -0.05), fontsize=20, clip_on=False) + ins_ax.annotate(coords[1], xy=(0, 0), xytext=(-0.2, 1.3), fontsize=20, clip_on=False) + ins_ax.axis('off') + plt.sca(axis) + +# TODO: These parameters in other plot functions belong in a dedicated dictionary!!! def format_axis(axis, format_axis=True, title='', fontsize=None, stroke=None, scalebar=True, - hideaxes=None, sampling=1, note=None, colorwheel=False, cbar_mappable=None, - cbar_label='', tight_layout=True, **_): + hideaxes=None, sampling=1, note=None, colorwheel=False, cbar_mappable=None, + cbar_label='', tight_layout=True, keep_labels=False, coords=None, **_): """Format an axis and add a lot of nice features. Parameters @@ -269,26 +299,28 @@ def format_axis(axis, format_axis=True, title='', fontsize=None, stroke=None, sc fontsize = FONTSIZE_DEFAULT if stroke is None: stroke = STROKE_DEFAULT - # Get dimensions: - bb0 = axis.transLimits.inverted().transform((0, 0)) - bb1 = axis.transLimits.inverted().transform((1, 1)) - dim_uv = (int(abs(bb1[1] - bb0[1])), int(abs(bb1[0] - bb0[0]))) - # Set the title and the axes labels: - axis.set_xlim(0, dim_uv[1]) - axis.set_ylim(0, dim_uv[0]) - axis.set_title(title, fontsize=fontsize) # Add scalebar: if scalebar: add_scalebar(axis, sampling=sampling, fontsize=fontsize, stroke=stroke) if hideaxes is None: hideaxes = True - # Determine major tick locations (useful for grid, even if ticks will not be used): - if dim_uv[0] >= dim_uv[1]: - u_bin, v_bin = np.max((2, np.floor(9 * dim_uv[1] / dim_uv[0]))), 9 - else: - u_bin, v_bin = 9, np.max((2, np.floor(9 * dim_uv[0] / dim_uv[1]))) - axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) - axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) + if not keep_labels: + # Set_title + axis.set_title(title, fontsize=fontsize) + # Get dimensions: + bb0 = axis.transLimits.inverted().transform((0, 0)) + bb1 = axis.transLimits.inverted().transform((1, 1)) + dim_uv = (int(abs(bb1[1] - bb0[1])), int(abs(bb1[0] - bb0[0]))) + # Set the title and the axes labels: + axis.set_xlim(0, dim_uv[1]) + axis.set_ylim(0, dim_uv[0]) + # Determine major tick locations (useful for grid, even if ticks will not be used): + if dim_uv[0] >= dim_uv[1]: + u_bin, v_bin = np.max((2, np.floor(9 * dim_uv[1] / dim_uv[0]))), 9 + else: + u_bin, v_bin = 9, np.max((2, np.floor(9 * dim_uv[0] / dim_uv[1]))) + axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) + axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) # Hide axes label and ticks if wanted: if hideaxes: for tic in axis.xaxis.get_major_ticks(): @@ -298,14 +330,21 @@ def format_axis(axis, format_axis=True, title='', fontsize=None, stroke=None, sc tic.tick1On = tic.tick2On = False tic.label1On = tic.label2On = False else: # Set the axes ticks and labels: - axis.set_xlabel('u-axis [nm]', fontsize=fontsize) - axis.set_ylabel('v-axis [nm]', fontsize=fontsize) - axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * sampling))) - axis.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * sampling))) + if not keep_labels: + axis.set_xlabel('u-axis [nm]') + axis.set_ylabel('v-axis [nm]') + func_formatter = FuncFormatter(lambda x, pos: '{:.3g}'.format(x * sampling)) + axis.xaxis.set_major_formatter(func_formatter) + axis.yaxis.set_major_formatter(func_formatter) axis.tick_params(axis='both', which='major', labelsize=fontsize) + axis.xaxis.label.set_size(fontsize) + axis.yaxis.label.set_size(fontsize) # Add annotation: if note: add_annotation(axis, label=note, fontsize=fontsize, stroke=stroke) + # Add coords: + if coords: + add_coords(axis, coords=coords) # Add colorhweel: if colorwheel: add_colorwheel(axis) diff --git a/pyramid/projector.py b/pyramid/projector.py index eddcebbc4b3cbcee4b25124a707692469b21ace0..a1c58d223b69e85ef64b2726f8dd13e9c02cd4ad 100644 --- a/pyramid/projector.py +++ b/pyramid/projector.py @@ -660,6 +660,7 @@ class SimpleProjector(Projector): for row in range(size_2d)]).reshape(-1) elif axis == 'x': self._log.debug('Projection along the x-axis') + # TODO: is coordinate switch really necessary? Better other way??? coeff = [[0, 0, 1], [0, 1, 0]] # Caution, coordinate switch: u, v --> z, y (not y, z!) indices = np.array( [np.arange(dim_x) + (row % dim_z) * dim_x * dim_y + row // dim_z * dim_x diff --git a/pyramid/ramp.py b/pyramid/ramp.py index 2d08b8ff9fa81e6ed0ba0135c3ea32cce5d28f8d..5b4f9b54d984f48ce344ddca9098d627697f829d 100644 --- a/pyramid/ramp.py +++ b/pyramid/ramp.py @@ -207,7 +207,7 @@ class Ramp(object): Dimensions of the 2D mesh that should be created. params : list List of ramp parameters. The first entry corresponds to a simple offset, the second - and third correspond to a linear ramp in u- and v-direction, respectively and so on. + and third correspond to a linear ramp in v- and u-direction, respectively and so on. Returns ------- diff --git a/pyramid/tests/test_phasemapper.py b/pyramid/tests/test_phasemapper.py index 6da41ecce874692b7d41353ecffc9abf44b0ee74..14870c5b3b5ae968dce96644946a2dbcae84be3d 100644 --- a/pyramid/tests/test_phasemapper.py +++ b/pyramid/tests/test_phasemapper.py @@ -180,20 +180,21 @@ class TestCasePhaseMapperMIP(unittest.TestCase): class TestCasePhaseMapperCharge(unittest.TestCase): def setUp(self): self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_phasemapper') - self.Charge_proj = load_scalardata(os.path.join(self.path, 'Charge_proj.hdf5')) - self.mapper = PhaseMapperCharge(self.Charge_proj.a, self.Charge_proj.dim[1:]) + self.charge_proj = load_scalardata(os.path.join(self.path, 'charge_proj.hdf5')) + self.mapper = PhaseMapperCharge(self.charge_proj.a, self.charge_proj.dim[1:], + electrode_vec=(8, 8), v_acc=300000) def tearDown(self): self.path = None - self.Charge_proj = None + self.charge_proj = None self.mapper = None def test_call(self): - Charge_phase_ref = load_phasemap(os.path.join(self.path, 'Charge_phase_ref.hdf5')) - phasemap = self.mapper(self.Charge_proj) - assert_allclose(phasemap.phase, Charge_phase_ref.phase, atol=1E-7, + charge_phase_ref = load_phasemap(os.path.join(self.path, 'charge_phase_ref.hdf5')) + phasemap = self.mapper(self.charge_proj) + assert_allclose(phasemap.phase, charge_phase_ref.phase, atol=1E-7, err_msg='Unexpected behavior in __call__()!') - assert_allclose(phasemap.a, Charge_phase_ref.a, err_msg='Unexpected behavior in __call__()!') + assert_allclose(phasemap.a, charge_phase_ref.a, err_msg='Unexpected behavior in __call__()!') def test_jac_dot(self): self.assertRaises(NotImplementedError, self.mapper.jac_dot, None) diff --git a/pyramid/utils/__init__.py b/pyramid/utils/__init__.py index 1c550a0fc3284040769e6291f7c496e41382355c..54bbf950597e88a8886c11058cb65ed232c94b65 100644 --- a/pyramid/utils/__init__.py +++ b/pyramid/utils/__init__.py @@ -7,8 +7,8 @@ from .pm import pm from .reconstruction_2d_from_phasemap import reconstruction_2d_from_phasemap from .reconstruction_3d_from_magdata import reconstruction_3d_from_magdata -from .phasemap_creator import gui_phasemap_creator -from .mag_slicer import gui_mag_slicer +#from .phasemap_creator import gui_phasemap_creator +#from .mag_slicer import gui_mag_slicer -__all__ = ['pm', 'reconstruction_2d_from_phasemap', 'reconstruction_3d_from_magdata', - 'gui_phasemap_creator', 'gui_mag_slicer'] +__all__ = ['pm', 'reconstruction_2d_from_phasemap', 'reconstruction_3d_from_magdata']#, + #'gui_phasemap_creator', 'gui_mag_slicer'] diff --git a/pyramid/utils/mag_slicer.py b/pyramid/utils/mag_slicer.py index 52d480b9e82a3b8cc536612bd8bf557881f718d0..b44e8dd139bd35d246f0083f67dbc78c997a7012 100644 --- a/pyramid/utils/mag_slicer.py +++ b/pyramid/utils/mag_slicer.py @@ -12,9 +12,6 @@ import logging import os import sys -from PyQt4 import QtGui, QtCore -from PyQt4.uic import loadUiType - from matplotlib.figure import Figure from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas from matplotlib.backends.backend_qt4agg import NavigationToolbar2QT as NavigationToolbar @@ -24,6 +21,13 @@ from ..kernel import Kernel from ..phasemapper import PhaseMapperRDFC from ..file_io.io_vectordata import load_vectordata +try: + from PyQt5 import QtGui, QtCore + from PyQt5.uic import loadUiType +except ImportError: + from PyQt4 import QtGui, QtCore + from PyQt4.uic import loadUiType + __all__ = ['gui_mag_slicer'] _log = logging.getLogger(__name__) diff --git a/pyramid/utils/phasemap_creator.py b/pyramid/utils/phasemap_creator.py index ac49c08698d6b204db656b538f2d516c2349bb9d..9058e2189d8a910f7b8ce6659b655cab02d7417f 100644 --- a/pyramid/utils/phasemap_creator.py +++ b/pyramid/utils/phasemap_creator.py @@ -13,9 +13,6 @@ import logging import os import sys -from PyQt4 import QtGui, QtCore -from PyQt4.uic import loadUiType - from matplotlib.figure import Figure from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas from matplotlib.backends.backend_qt4agg import NavigationToolbar2QT as NavigationToolbar @@ -24,10 +21,17 @@ from PIL import Image import numpy as np -import hyperspy.api as hs +#import hyperspy.api as hs # TODO: Necessary? import pyramid as pr +try: + from PyQt5 import QtGui, QtCore + from PyQt5.uic import loadUiType +except ImportError: + from PyQt4 import QtGui, QtCore + from PyQt4.uic import loadUiType + __all__ = ['gui_phasemap_creator'] _log = logging.getLogger(__name__) diff --git a/pyramid/utils/pm.py b/pyramid/utils/pm.py index 5c4232b25e6030953bfdd541ea07964ced64fbfb..e438125cfcb6e669bde3d8c10d073497f321aa5b 100644 --- a/pyramid/utils/pm.py +++ b/pyramid/utils/pm.py @@ -37,6 +37,8 @@ def pm(magdata, mode='z', b_0=1, mapper='RDFC', **kwargs): """ _log.debug('Calling pm') + # In case of FDFC: + padding = kwargs.pop('padding', 0) # Determine projection mode: if mode == 'rot-tilt': projector = RotTiltProjector(magdata.dim, **kwargs) @@ -54,7 +56,6 @@ def pm(magdata, mode='z', b_0=1, mapper='RDFC', **kwargs): if mapper == 'RDFC': phasemapper = PhaseMapperRDFC(Kernel(magdata.a, projector.dim_uv, b_0=b_0)) elif mapper == 'FDFC': - padding = kwargs.get('padding', 0) phasemapper = PhaseMapperFDFC(magdata.a, projector.dim_uv, b_0=b_0, padding=padding) else: raise ValueError("Invalid mapper (use 'RDFC' or 'FDFC'") diff --git a/pyramid/utils/reconstruction_2d_from_phasemap.py b/pyramid/utils/reconstruction_2d_from_phasemap.py index 0272303b0a4e84a77eb62f2db2e3bc22f8ff5800..14ad6868a1c174b337cef6d4b90558a18bfcaa9e 100644 --- a/pyramid/utils/reconstruction_2d_from_phasemap.py +++ b/pyramid/utils/reconstruction_2d_from_phasemap.py @@ -92,10 +92,10 @@ def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ram if ramp_order is not None: if ramp_order >= 0: print('offset:', offset) - title += ', fitted Offset: {:.2g} [rad]'.format(offset) + # title += ', fitted Offset: {:.2g} [rad]'.format(offset) if ramp_order >= 1: print('ramp:', ramp) - title += ', (Fitted Ramp: (u:{:.2g}, v:{:.2g}) [rad/nm]'.format(*ramp) + # title += ', (Fitted Ramp: (u:{:.2g}, v:{:.2g}) [rad/nm]'.format(*ramp) phasemap_rec.plot_combined(note=title, gain=gain, vmin=vmin, vmax=vmax) diff = (phasemap_rec - phasemap) diff_name = 'Difference (RMS: {:.2g} rad)'.format(np.sqrt(np.mean(diff.phase) ** 2)) diff --git a/pyramid/utils/reconstruction_3d_from_magdata.py b/pyramid/utils/reconstruction_3d_from_magdata.py index db10f9253e55a9cdb93b7907fbc3b2da2394f52c..fcb649d2da5f7b2c0909d3957ba978b41195a79e 100644 --- a/pyramid/utils/reconstruction_3d_from_magdata.py +++ b/pyramid/utils/reconstruction_3d_from_magdata.py @@ -57,7 +57,7 @@ def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_ default. noise: float, optional If this is not zero, random gaussian noise with this as a maximum value will be applied - to all calculated phasemaps. The default is 0. + to all calculated phasemaps. The default is 0. The unit is radians. offset_max: float, optional if this is not zero, a random offset with this as a maximum value will be applied to all calculated phasemaps. The default is 0. @@ -115,6 +115,7 @@ def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_ ramp_u = np.random.uniform(-ramp_max, ramp_max) ramp_v = np.random.uniform(-ramp_max, ramp_max) phasemap += Ramp.create_ramp(phasemap.a, phasemap.dim_uv, (offset, ramp_u, ramp_v)) + data.phasemaps[i] = phasemap # Add noise if necessary: if noise != 0: for i, phasemap in enumerate(data.phasemaps): @@ -142,7 +143,7 @@ def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_ data.plot_phasemaps() # Plot results: if plot_results: - data.plot_mask(ar_dens=ar_dens) + data.plot_mask() magdata.plot_quiver3d('Original Distribution', ar_dens=ar_dens) magdata_rec.plot_quiver3d('Reconstructed Distribution (angle)', ar_dens=ar_dens) magdata_rec.plot_quiver3d('Reconstructed Distribution (amplitude)',